feat(runner): use per-benchmark TLS mode instead of global mode

This commit is contained in:
2026-02-25 16:18:20 +02:00
parent 07cae6df55
commit ea2a07d5aa
3 changed files with 22 additions and 57 deletions

View File

@@ -12,7 +12,7 @@ use std::{fs::read_to_string, net::SocketAddr, path::Path};
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct BenchmarkConfig { pub struct BenchmarkConfig {
pub mode: String, pub mode: KeyExchangeMode,
pub payload: u32, pub payload: u32,
pub iters: u32, pub iters: u32,
pub warmup: u32, pub warmup: u32,
@@ -61,7 +61,7 @@ pub fn load_from_file(path: &Path) -> error::Result<Config> {
pub fn load_from_cli(args: &Args) -> error::Result<Config> { pub fn load_from_cli(args: &Args) -> error::Result<Config> {
Ok(Config { Ok(Config {
benchmarks: vec![BenchmarkConfig { benchmarks: vec![BenchmarkConfig {
mode: args.mode.to_string(), mode: args.mode,
payload: args.payload_bytes, payload: args.payload_bytes,
iters: args.iters, iters: args.iters,
warmup: args.warmup, warmup: args.warmup,
@@ -73,21 +73,10 @@ pub fn load_from_cli(args: &Args) -> error::Result<Config> {
}) })
} }
impl Config {
/// Get the key exchange mode from the first benchmark configuration.
#[must_use]
pub fn server_mode(&self) -> KeyExchangeMode {
self.benchmarks
.first()
.and_then(|b| b.mode.parse().ok())
.unwrap_or(KeyExchangeMode::X25519)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use claims::assert_ok; use claims::{assert_err, assert_ok, assert_some};
const VALID_CONFIG: &str = r#" const VALID_CONFIG: &str = r#"
[[benchmarks]] [[benchmarks]]
@@ -124,7 +113,7 @@ server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
assert_eq!(config.benchmarks.len(), 1); assert_eq!(config.benchmarks.len(), 1);
assert_eq!(config.benchmarks[0].mode, "x25519"); assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519);
assert_eq!(config.benchmarks[0].payload, 1024); assert_eq!(config.benchmarks[0].payload, 1024);
} }
@@ -132,8 +121,8 @@ server = "127.0.0.1:4433"
fn valid_multiple_benchmarks() { fn valid_multiple_benchmarks() {
let config = get_config_from_str(VALID_CONFIG); let config = get_config_from_str(VALID_CONFIG);
assert_eq!(config.benchmarks.len(), 2); assert_eq!(config.benchmarks.len(), 2);
assert_eq!(config.benchmarks[0].mode, "x25519"); assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519);
assert_eq!(config.benchmarks[1].mode, "x25519mlkem768"); assert_eq!(config.benchmarks[1].mode, KeyExchangeMode::X25519Mlkem768);
} }
#[test] #[test]
@@ -147,8 +136,7 @@ warmup = 10
concurrency = 1 concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); assert_err!(toml::from_str::<Config>(toml));
assert!(config.server_mode() == KeyExchangeMode::X25519); // fallback
} }
#[test] #[test]
@@ -163,8 +151,7 @@ concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
let result = validate_config(&config, toml, std::path::Path::new("test.toml")); assert_err!(validate_config(&config, toml, Path::new("test.toml")));
assert!(result.is_err());
} }
#[test] #[test]
@@ -179,8 +166,7 @@ concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
let result = validate_config(&config, toml, std::path::Path::new("test.toml")); assert_err!(validate_config(&config, toml, Path::new("test.toml")));
assert!(result.is_err());
} }
#[test] #[test]
@@ -195,8 +181,7 @@ concurrency = 0
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
let result = validate_config(&config, toml, std::path::Path::new("test.toml")); assert_err!(validate_config(&config, toml, Path::new("test.toml")));
assert!(result.is_err());
} }
#[test] #[test]
@@ -218,7 +203,8 @@ concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
assert_eq!(config.server_mode(), KeyExchangeMode::X25519); let benchmark = assert_some!(config.benchmarks.first());
assert_eq!(benchmark.mode, KeyExchangeMode::X25519);
} }
#[test] #[test]
@@ -233,6 +219,7 @@ concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
"#; "#;
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
assert_eq!(config.server_mode(), KeyExchangeMode::X25519Mlkem768); let benchmark = assert_some!(config.benchmarks.first());
assert_eq!(benchmark.mode, KeyExchangeMode::X25519Mlkem768);
} }
} }

View File

@@ -2,7 +2,6 @@ use crate::{
config::{BenchmarkConfig, Config}, config::{BenchmarkConfig, Config},
error::{self, ConfigError}, error::{self, ConfigError},
}; };
use common::{self, KeyExchangeMode};
use miette::{NamedSource, SourceSpan}; use miette::{NamedSource, SourceSpan};
use std::path::Path; use std::path::Path;
@@ -31,21 +30,6 @@ fn validate_benchmark(
) -> error::Result<()> { ) -> error::Result<()> {
let src = NamedSource::new(path.display().to_string(), content.to_string()); let src = NamedSource::new(path.display().to_string(), content.to_string());
// Validate mode
if benchmark.mode.parse::<KeyExchangeMode>().is_err() {
return Err(ConfigError::ValidationError {
src,
span: find_field_span(content, idx, "mode"),
field: "mode".into(),
idx,
message: format!(
"Invalid key exchange mode '{}'. Valid values: 'x25519', 'x25519mlkem768'",
benchmark.mode
),
}
.into());
}
validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?; validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?;
validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?; validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?;
validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?; validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?;

View File

@@ -14,7 +14,7 @@ use common::{
use miette::{Context, IntoDiagnostic}; use miette::{Context, IntoDiagnostic};
use runner::{ use runner::{
args::Args, args::Args,
config::{load_from_cli, load_from_file}, config::{BenchmarkConfig, load_from_cli, load_from_file},
}; };
use rustls::{ use rustls::{
ClientConfig, DigitallySignedStruct, SignatureScheme, ClientConfig, DigitallySignedStruct, SignatureScheme,
@@ -97,7 +97,7 @@ impl ServerCertVerifier for NoVerifier {
} }
/// Build TLS client config for the given key exchange mode. /// Build TLS client config for the given key exchange mode.
fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>> { fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<ClientConfig> {
let mut provider = aws_lc_rs::default_provider(); let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode { provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519], KeyExchangeMode::X25519 => vec![X25519],
@@ -112,7 +112,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>>
.with_custom_certificate_verifier(Arc::new(NoVerifier)) .with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth(); .with_no_client_auth();
Ok(Arc::new(config)) Ok(config)
} }
/// Run a single benchmark iteration over TLS. /// Run a single benchmark iteration over TLS.
@@ -156,9 +156,8 @@ async fn run_iteration(
}) })
} }
#[allow(clippy::future_not_send)] // References held across await points
async fn run_benchmark( async fn run_benchmark(
config: &runner::config::BenchmarkConfig, config: &BenchmarkConfig,
tls_connector: &TlsConnector, tls_connector: &TlsConnector,
server_name: &ServerName<'static>, server_name: &ServerName<'static>,
) -> miette::Result<()> { ) -> miette::Result<()> {
@@ -195,7 +194,6 @@ async fn run_benchmark(
let semaphore = Arc::new(Semaphore::new(config.concurrency as usize)); let semaphore = Arc::new(Semaphore::new(config.concurrency as usize));
let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name); let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name);
// Output to stdout for now
{ {
let mut output = stdout(); let mut output = stdout();
write_results(&mut output, tasks).await?; write_results(&mut output, tasks).await?;
@@ -219,17 +217,13 @@ fn spawn_benchmark_tasks(
) -> Vec<ReturnHandle> { ) -> Vec<ReturnHandle> {
let server = config.server; let server = config.server;
let payload_bytes = config.payload; let payload_bytes = config.payload;
let mode = config
.mode
.parse::<KeyExchangeMode>()
.expect("mode should be valid");
(0..config.iters) (0..config.iters)
.map(|i| { .map(|i| {
spawn_single_iteration( spawn_single_iteration(
i, i,
payload_bytes, payload_bytes,
mode, config.mode,
server, server,
semaphore.clone(), semaphore.clone(),
tls_connector.clone(), tls_connector.clone(),
@@ -314,9 +308,6 @@ async fn main() -> miette::Result<()> {
load_from_cli(&args)? load_from_cli(&args)?
}; };
let tls_config = build_tls_config(config.server_mode())?;
let tls_connector = TlsConnector::from(tls_config);
let server_name = ServerName::try_from("localhost".to_string()) let server_name = ServerName::try_from("localhost".to_string())
.into_diagnostic() .into_diagnostic()
.context("invalid server name")?; .context("invalid server name")?;
@@ -331,6 +322,9 @@ async fn main() -> miette::Result<()> {
"running benchmark" "running benchmark"
); );
let tls_config = build_tls_config(benchmark.mode)?;
let tls_connector = TlsConnector::from(Arc::new(tls_config));
run_benchmark(benchmark, &tls_connector, &server_name).await?; run_benchmark(benchmark, &tls_connector, &server_name).await?;
} }