diff --git a/runner/src/config/mod.rs b/runner/src/config/mod.rs index 0e625b7..84ae57a 100644 --- a/runner/src/config/mod.rs +++ b/runner/src/config/mod.rs @@ -12,7 +12,7 @@ use std::{fs::read_to_string, net::SocketAddr, path::Path}; #[derive(Debug, Clone, Deserialize)] pub struct BenchmarkConfig { - pub mode: String, + pub mode: KeyExchangeMode, pub payload: u32, pub iters: u32, pub warmup: u32, @@ -61,7 +61,7 @@ pub fn load_from_file(path: &Path) -> error::Result { pub fn load_from_cli(args: &Args) -> error::Result { Ok(Config { benchmarks: vec![BenchmarkConfig { - mode: args.mode.to_string(), + mode: args.mode, payload: args.payload_bytes, iters: args.iters, warmup: args.warmup, @@ -73,21 +73,10 @@ pub fn load_from_cli(args: &Args) -> error::Result { }) } -impl Config { - /// Get the key exchange mode from the first benchmark configuration. - #[must_use] - pub fn server_mode(&self) -> KeyExchangeMode { - self.benchmarks - .first() - .and_then(|b| b.mode.parse().ok()) - .unwrap_or(KeyExchangeMode::X25519) - } -} - #[cfg(test)] mod tests { use super::*; - use claims::assert_ok; + use claims::{assert_err, assert_ok, assert_some}; const VALID_CONFIG: &str = r#" [[benchmarks]] @@ -124,7 +113,7 @@ server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); assert_eq!(config.benchmarks.len(), 1); - assert_eq!(config.benchmarks[0].mode, "x25519"); + assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519); assert_eq!(config.benchmarks[0].payload, 1024); } @@ -132,8 +121,8 @@ server = "127.0.0.1:4433" fn valid_multiple_benchmarks() { let config = get_config_from_str(VALID_CONFIG); assert_eq!(config.benchmarks.len(), 2); - assert_eq!(config.benchmarks[0].mode, "x25519"); - assert_eq!(config.benchmarks[1].mode, "x25519mlkem768"); + assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519); + assert_eq!(config.benchmarks[1].mode, KeyExchangeMode::X25519Mlkem768); } #[test] @@ -147,8 +136,7 @@ warmup = 10 concurrency = 1 server = "127.0.0.1:4433" "#; - let config = get_config_from_str(toml); - assert!(config.server_mode() == KeyExchangeMode::X25519); // fallback + assert_err!(toml::from_str::(toml)); } #[test] @@ -163,8 +151,7 @@ concurrency = 1 server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); - let result = validate_config(&config, toml, std::path::Path::new("test.toml")); - assert!(result.is_err()); + assert_err!(validate_config(&config, toml, Path::new("test.toml"))); } #[test] @@ -179,8 +166,7 @@ concurrency = 1 server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); - let result = validate_config(&config, toml, std::path::Path::new("test.toml")); - assert!(result.is_err()); + assert_err!(validate_config(&config, toml, Path::new("test.toml"))); } #[test] @@ -195,8 +181,7 @@ concurrency = 0 server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); - let result = validate_config(&config, toml, std::path::Path::new("test.toml")); - assert!(result.is_err()); + assert_err!(validate_config(&config, toml, Path::new("test.toml"))); } #[test] @@ -218,7 +203,8 @@ concurrency = 1 server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); - assert_eq!(config.server_mode(), KeyExchangeMode::X25519); + let benchmark = assert_some!(config.benchmarks.first()); + assert_eq!(benchmark.mode, KeyExchangeMode::X25519); } #[test] @@ -233,6 +219,7 @@ concurrency = 1 server = "127.0.0.1:4433" "#; let config = get_config_from_str(toml); - assert_eq!(config.server_mode(), KeyExchangeMode::X25519Mlkem768); + let benchmark = assert_some!(config.benchmarks.first()); + assert_eq!(benchmark.mode, KeyExchangeMode::X25519Mlkem768); } } diff --git a/runner/src/config/utils.rs b/runner/src/config/utils.rs index 08698fa..7ab50f4 100644 --- a/runner/src/config/utils.rs +++ b/runner/src/config/utils.rs @@ -2,7 +2,6 @@ use crate::{ config::{BenchmarkConfig, Config}, error::{self, ConfigError}, }; -use common::{self, KeyExchangeMode}; use miette::{NamedSource, SourceSpan}; use std::path::Path; @@ -31,21 +30,6 @@ fn validate_benchmark( ) -> error::Result<()> { let src = NamedSource::new(path.display().to_string(), content.to_string()); - // Validate mode - if benchmark.mode.parse::().is_err() { - return Err(ConfigError::ValidationError { - src, - span: find_field_span(content, idx, "mode"), - field: "mode".into(), - idx, - message: format!( - "Invalid key exchange mode '{}'. Valid values: 'x25519', 'x25519mlkem768'", - benchmark.mode - ), - } - .into()); - } - validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?; validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?; validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?; diff --git a/runner/src/main.rs b/runner/src/main.rs index 713ebf3..7ba88d0 100644 --- a/runner/src/main.rs +++ b/runner/src/main.rs @@ -14,7 +14,7 @@ use common::{ use miette::{Context, IntoDiagnostic}; use runner::{ args::Args, - config::{load_from_cli, load_from_file}, + config::{BenchmarkConfig, load_from_cli, load_from_file}, }; use rustls::{ ClientConfig, DigitallySignedStruct, SignatureScheme, @@ -97,7 +97,7 @@ impl ServerCertVerifier for NoVerifier { } /// Build TLS client config for the given key exchange mode. -fn build_tls_config(mode: KeyExchangeMode) -> miette::Result> { +fn build_tls_config(mode: KeyExchangeMode) -> miette::Result { let mut provider = aws_lc_rs::default_provider(); provider.kx_groups = match mode { KeyExchangeMode::X25519 => vec![X25519], @@ -112,7 +112,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result> .with_custom_certificate_verifier(Arc::new(NoVerifier)) .with_no_client_auth(); - Ok(Arc::new(config)) + Ok(config) } /// Run a single benchmark iteration over TLS. @@ -156,9 +156,8 @@ async fn run_iteration( }) } -#[allow(clippy::future_not_send)] // References held across await points async fn run_benchmark( - config: &runner::config::BenchmarkConfig, + config: &BenchmarkConfig, tls_connector: &TlsConnector, server_name: &ServerName<'static>, ) -> miette::Result<()> { @@ -195,7 +194,6 @@ async fn run_benchmark( let semaphore = Arc::new(Semaphore::new(config.concurrency as usize)); let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name); - // Output to stdout for now { let mut output = stdout(); write_results(&mut output, tasks).await?; @@ -219,17 +217,13 @@ fn spawn_benchmark_tasks( ) -> Vec { let server = config.server; let payload_bytes = config.payload; - let mode = config - .mode - .parse::() - .expect("mode should be valid"); (0..config.iters) .map(|i| { spawn_single_iteration( i, payload_bytes, - mode, + config.mode, server, semaphore.clone(), tls_connector.clone(), @@ -314,9 +308,6 @@ async fn main() -> miette::Result<()> { load_from_cli(&args)? }; - let tls_config = build_tls_config(config.server_mode())?; - let tls_connector = TlsConnector::from(tls_config); - let server_name = ServerName::try_from("localhost".to_string()) .into_diagnostic() .context("invalid server name")?; @@ -331,6 +322,9 @@ async fn main() -> miette::Result<()> { "running benchmark" ); + let tls_config = build_tls_config(benchmark.mode)?; + let tls_connector = TlsConnector::from(Arc::new(tls_config)); + run_benchmark(benchmark, &tls_connector, &server_name).await?; }