mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-22 00:36:21 +00:00
feat(runner): use per-benchmark TLS mode instead of global mode
This commit is contained in:
@@ -12,7 +12,7 @@ use std::{fs::read_to_string, net::SocketAddr, path::Path};
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct BenchmarkConfig {
|
pub struct BenchmarkConfig {
|
||||||
pub mode: String,
|
pub mode: KeyExchangeMode,
|
||||||
pub payload: u32,
|
pub payload: u32,
|
||||||
pub iters: u32,
|
pub iters: u32,
|
||||||
pub warmup: u32,
|
pub warmup: u32,
|
||||||
@@ -61,7 +61,7 @@ pub fn load_from_file(path: &Path) -> error::Result<Config> {
|
|||||||
pub fn load_from_cli(args: &Args) -> error::Result<Config> {
|
pub fn load_from_cli(args: &Args) -> error::Result<Config> {
|
||||||
Ok(Config {
|
Ok(Config {
|
||||||
benchmarks: vec![BenchmarkConfig {
|
benchmarks: vec![BenchmarkConfig {
|
||||||
mode: args.mode.to_string(),
|
mode: args.mode,
|
||||||
payload: args.payload_bytes,
|
payload: args.payload_bytes,
|
||||||
iters: args.iters,
|
iters: args.iters,
|
||||||
warmup: args.warmup,
|
warmup: args.warmup,
|
||||||
@@ -73,21 +73,10 @@ pub fn load_from_cli(args: &Args) -> error::Result<Config> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
|
||||||
/// Get the key exchange mode from the first benchmark configuration.
|
|
||||||
#[must_use]
|
|
||||||
pub fn server_mode(&self) -> KeyExchangeMode {
|
|
||||||
self.benchmarks
|
|
||||||
.first()
|
|
||||||
.and_then(|b| b.mode.parse().ok())
|
|
||||||
.unwrap_or(KeyExchangeMode::X25519)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use claims::assert_ok;
|
use claims::{assert_err, assert_ok, assert_some};
|
||||||
|
|
||||||
const VALID_CONFIG: &str = r#"
|
const VALID_CONFIG: &str = r#"
|
||||||
[[benchmarks]]
|
[[benchmarks]]
|
||||||
@@ -124,7 +113,7 @@ server = "127.0.0.1:4433"
|
|||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
assert_eq!(config.benchmarks.len(), 1);
|
assert_eq!(config.benchmarks.len(), 1);
|
||||||
assert_eq!(config.benchmarks[0].mode, "x25519");
|
assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519);
|
||||||
assert_eq!(config.benchmarks[0].payload, 1024);
|
assert_eq!(config.benchmarks[0].payload, 1024);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,8 +121,8 @@ server = "127.0.0.1:4433"
|
|||||||
fn valid_multiple_benchmarks() {
|
fn valid_multiple_benchmarks() {
|
||||||
let config = get_config_from_str(VALID_CONFIG);
|
let config = get_config_from_str(VALID_CONFIG);
|
||||||
assert_eq!(config.benchmarks.len(), 2);
|
assert_eq!(config.benchmarks.len(), 2);
|
||||||
assert_eq!(config.benchmarks[0].mode, "x25519");
|
assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519);
|
||||||
assert_eq!(config.benchmarks[1].mode, "x25519mlkem768");
|
assert_eq!(config.benchmarks[1].mode, KeyExchangeMode::X25519Mlkem768);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -147,8 +136,7 @@ warmup = 10
|
|||||||
concurrency = 1
|
concurrency = 1
|
||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
assert_err!(toml::from_str::<Config>(toml));
|
||||||
assert!(config.server_mode() == KeyExchangeMode::X25519); // fallback
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -163,8 +151,7 @@ concurrency = 1
|
|||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -179,8 +166,7 @@ concurrency = 1
|
|||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -195,8 +181,7 @@ concurrency = 0
|
|||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -218,7 +203,8 @@ concurrency = 1
|
|||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
assert_eq!(config.server_mode(), KeyExchangeMode::X25519);
|
let benchmark = assert_some!(config.benchmarks.first());
|
||||||
|
assert_eq!(benchmark.mode, KeyExchangeMode::X25519);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -233,6 +219,7 @@ concurrency = 1
|
|||||||
server = "127.0.0.1:4433"
|
server = "127.0.0.1:4433"
|
||||||
"#;
|
"#;
|
||||||
let config = get_config_from_str(toml);
|
let config = get_config_from_str(toml);
|
||||||
assert_eq!(config.server_mode(), KeyExchangeMode::X25519Mlkem768);
|
let benchmark = assert_some!(config.benchmarks.first());
|
||||||
|
assert_eq!(benchmark.mode, KeyExchangeMode::X25519Mlkem768);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ use crate::{
|
|||||||
config::{BenchmarkConfig, Config},
|
config::{BenchmarkConfig, Config},
|
||||||
error::{self, ConfigError},
|
error::{self, ConfigError},
|
||||||
};
|
};
|
||||||
use common::{self, KeyExchangeMode};
|
|
||||||
use miette::{NamedSource, SourceSpan};
|
use miette::{NamedSource, SourceSpan};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -31,21 +30,6 @@ fn validate_benchmark(
|
|||||||
) -> error::Result<()> {
|
) -> error::Result<()> {
|
||||||
let src = NamedSource::new(path.display().to_string(), content.to_string());
|
let src = NamedSource::new(path.display().to_string(), content.to_string());
|
||||||
|
|
||||||
// Validate mode
|
|
||||||
if benchmark.mode.parse::<KeyExchangeMode>().is_err() {
|
|
||||||
return Err(ConfigError::ValidationError {
|
|
||||||
src,
|
|
||||||
span: find_field_span(content, idx, "mode"),
|
|
||||||
field: "mode".into(),
|
|
||||||
idx,
|
|
||||||
message: format!(
|
|
||||||
"Invalid key exchange mode '{}'. Valid values: 'x25519', 'x25519mlkem768'",
|
|
||||||
benchmark.mode
|
|
||||||
),
|
|
||||||
}
|
|
||||||
.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?;
|
validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?;
|
||||||
validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?;
|
validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?;
|
||||||
validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?;
|
validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?;
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use common::{
|
|||||||
use miette::{Context, IntoDiagnostic};
|
use miette::{Context, IntoDiagnostic};
|
||||||
use runner::{
|
use runner::{
|
||||||
args::Args,
|
args::Args,
|
||||||
config::{load_from_cli, load_from_file},
|
config::{BenchmarkConfig, load_from_cli, load_from_file},
|
||||||
};
|
};
|
||||||
use rustls::{
|
use rustls::{
|
||||||
ClientConfig, DigitallySignedStruct, SignatureScheme,
|
ClientConfig, DigitallySignedStruct, SignatureScheme,
|
||||||
@@ -97,7 +97,7 @@ impl ServerCertVerifier for NoVerifier {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build TLS client config for the given key exchange mode.
|
/// Build TLS client config for the given key exchange mode.
|
||||||
fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>> {
|
fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<ClientConfig> {
|
||||||
let mut provider = aws_lc_rs::default_provider();
|
let mut provider = aws_lc_rs::default_provider();
|
||||||
provider.kx_groups = match mode {
|
provider.kx_groups = match mode {
|
||||||
KeyExchangeMode::X25519 => vec![X25519],
|
KeyExchangeMode::X25519 => vec![X25519],
|
||||||
@@ -112,7 +112,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>>
|
|||||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||||
.with_no_client_auth();
|
.with_no_client_auth();
|
||||||
|
|
||||||
Ok(Arc::new(config))
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run a single benchmark iteration over TLS.
|
/// Run a single benchmark iteration over TLS.
|
||||||
@@ -156,9 +156,8 @@ async fn run_iteration(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::future_not_send)] // References held across await points
|
|
||||||
async fn run_benchmark(
|
async fn run_benchmark(
|
||||||
config: &runner::config::BenchmarkConfig,
|
config: &BenchmarkConfig,
|
||||||
tls_connector: &TlsConnector,
|
tls_connector: &TlsConnector,
|
||||||
server_name: &ServerName<'static>,
|
server_name: &ServerName<'static>,
|
||||||
) -> miette::Result<()> {
|
) -> miette::Result<()> {
|
||||||
@@ -195,7 +194,6 @@ async fn run_benchmark(
|
|||||||
let semaphore = Arc::new(Semaphore::new(config.concurrency as usize));
|
let semaphore = Arc::new(Semaphore::new(config.concurrency as usize));
|
||||||
let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name);
|
let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name);
|
||||||
|
|
||||||
// Output to stdout for now
|
|
||||||
{
|
{
|
||||||
let mut output = stdout();
|
let mut output = stdout();
|
||||||
write_results(&mut output, tasks).await?;
|
write_results(&mut output, tasks).await?;
|
||||||
@@ -219,17 +217,13 @@ fn spawn_benchmark_tasks(
|
|||||||
) -> Vec<ReturnHandle> {
|
) -> Vec<ReturnHandle> {
|
||||||
let server = config.server;
|
let server = config.server;
|
||||||
let payload_bytes = config.payload;
|
let payload_bytes = config.payload;
|
||||||
let mode = config
|
|
||||||
.mode
|
|
||||||
.parse::<KeyExchangeMode>()
|
|
||||||
.expect("mode should be valid");
|
|
||||||
|
|
||||||
(0..config.iters)
|
(0..config.iters)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
spawn_single_iteration(
|
spawn_single_iteration(
|
||||||
i,
|
i,
|
||||||
payload_bytes,
|
payload_bytes,
|
||||||
mode,
|
config.mode,
|
||||||
server,
|
server,
|
||||||
semaphore.clone(),
|
semaphore.clone(),
|
||||||
tls_connector.clone(),
|
tls_connector.clone(),
|
||||||
@@ -314,9 +308,6 @@ async fn main() -> miette::Result<()> {
|
|||||||
load_from_cli(&args)?
|
load_from_cli(&args)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let tls_config = build_tls_config(config.server_mode())?;
|
|
||||||
let tls_connector = TlsConnector::from(tls_config);
|
|
||||||
|
|
||||||
let server_name = ServerName::try_from("localhost".to_string())
|
let server_name = ServerName::try_from("localhost".to_string())
|
||||||
.into_diagnostic()
|
.into_diagnostic()
|
||||||
.context("invalid server name")?;
|
.context("invalid server name")?;
|
||||||
@@ -331,6 +322,9 @@ async fn main() -> miette::Result<()> {
|
|||||||
"running benchmark"
|
"running benchmark"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let tls_config = build_tls_config(benchmark.mode)?;
|
||||||
|
let tls_connector = TlsConnector::from(Arc::new(tls_config));
|
||||||
|
|
||||||
run_benchmark(benchmark, &tls_connector, &server_name).await?;
|
run_benchmark(benchmark, &tls_connector, &server_name).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user