From 1c6625a04c65d270ccf80cdc3c14d0038db93a22 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 6 Feb 2026 17:43:46 +0200 Subject: [PATCH] feat(runner): add TOML config file support for matrix benchmarks - Add toml and serde dependencies - Create config module with Config and BenchmarkConfig structs - Add --config CLI option for matrix benchmarks - Refactor `run_benchmark()` to accept BenchmarkConfig --- Cargo.lock | 78 +++++++++++++++++++++++ Cargo.toml | 1 + runner/Cargo.toml | 2 + runner/src/args.rs | 40 ++++++++++++ runner/src/config.rs | 64 +++++++++++++++++++ runner/src/lib.rs | 2 + runner/src/main.rs | 148 +++++++++++++++++-------------------------- 7 files changed, 246 insertions(+), 89 deletions(-) create mode 100644 runner/src/args.rs create mode 100644 runner/src/config.rs create mode 100644 runner/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 16ed252..ac19e9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -328,6 +328,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.14" @@ -379,12 +385,28 @@ version = "0.32.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "is_ci" version = "1.2.0" @@ -747,8 +769,10 @@ dependencies = [ "common", "miette", "rustls", + "serde", "tokio", "tokio-rustls", + "toml", "tracing", "tracing-subscriber", "uuid", @@ -872,6 +896,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_spanned" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +dependencies = [ + "serde_core", +] + [[package]] name = "server" version = "0.1.0" @@ -1123,6 +1156,45 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.9.11+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3afc9a848309fe1aaffaed6e1546a7a14de1f935dc9d89d32afd9a44bab7c46" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned", + "toml_datetime", + "toml_parser", + "toml_writer", + "winnow", +] + +[[package]] +name = "toml_datetime" +version = "0.7.5+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.0.6+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.6+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" + [[package]] name = "tracing" version = "0.1.44" @@ -1459,6 +1531,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 863bd31..696d2e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" strum = { version = "0.27", features = ["derive"] } thiserror = "2" +toml = "0.9" tokio = { version = "1", features = ["full"] } tokio-rustls = { version = "0.26", default-features = false, features = [ "tls12", diff --git a/runner/Cargo.toml b/runner/Cargo.toml index df32b2b..fba7cb6 100644 --- a/runner/Cargo.toml +++ b/runner/Cargo.toml @@ -14,6 +14,8 @@ tokio.workspace = true tracing.workspace = true tracing-subscriber.workspace = true uuid.workspace = true +serde.workspace = true +toml.workspace = true [lints] workspace = true diff --git a/runner/src/args.rs b/runner/src/args.rs new file mode 100644 index 0000000..f641f65 --- /dev/null +++ b/runner/src/args.rs @@ -0,0 +1,40 @@ +use clap::Parser; +use common::KeyExchangeMode; +use std::{net::SocketAddr, path::PathBuf}; + +/// TLS benchmark runner. +#[derive(Debug, Parser)] +#[command(name = "runner", version, about)] +pub struct Args { + /// Key exchange mode. + #[arg(long, default_value = "x25519")] + pub mode: KeyExchangeMode, + + /// Server address to connect to. + #[arg(long)] + pub server: SocketAddr, + + /// Payload size in bytes to request from server. + #[arg(long, default_value = "1024")] + pub payload_bytes: u32, + + /// Number of benchmark iterations (excluding warmup). + #[arg(long, default_value = "100")] + pub iters: u32, + + /// Number of warmup iterations (not recorded). + #[arg(long, default_value = "10")] + pub warmup: u32, + + /// Number of concurrent connections. + #[arg(long, default_value = "1")] + pub concurrency: u32, + + /// Output file for NDJSON records (stdout if not specified). + #[arg(long)] + pub out: Option, + + /// Config file for matrix benchmarks (TOML). + #[arg(long)] + pub config: Option, +} diff --git a/runner/src/config.rs b/runner/src/config.rs new file mode 100644 index 0000000..35acc62 --- /dev/null +++ b/runner/src/config.rs @@ -0,0 +1,64 @@ +use miette::{Context, IntoDiagnostic}; +use serde::Deserialize; +use std::{fs::read_to_string, net::SocketAddr, path::PathBuf}; + +#[derive(Debug, Clone, Deserialize)] +pub struct BenchmarkConfig { + pub mode: String, + pub payload: u32, + pub iters: u32, + pub warmup: u32, + pub concurrency: u32, + pub server: SocketAddr, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub benchmarks: Vec, +} + +/// Load benchmark configuration from a TOML file. +/// +/// # Errors +/// Returns an error if the file cannot be read or parsed. +pub fn load_from_file(path: &PathBuf) -> miette::Result { + let content = read_to_string(path) + .into_diagnostic() + .context(format!("failed to read config file: {}", path.display()))?; + let config: Config = toml::from_str(&content).into_diagnostic().context(format!( + "failed to parse TOML config from file {}", + path.display() + ))?; + Ok(config) +} + +/// Create benchmark configuration from CLI arguments. +/// +/// # Errors +/// Never returns an error, but returns Result for consistency. +pub fn load_from_cli(args: &crate::args::Args) -> miette::Result { + let mode = args.mode.to_string(); + Ok(Config { + benchmarks: vec![BenchmarkConfig { + mode, + payload: args.payload_bytes, + iters: args.iters, + warmup: args.warmup, + concurrency: args.concurrency, + server: args.server, + }], + }) +} + +impl Config { + /// Get the key exchange mode from the first benchmark configuration. + #[must_use] + pub fn server_mode(&self) -> KeyExchangeMode { + self.benchmarks + .first() + .and_then(|b| b.mode.parse().ok()) + .unwrap_or(KeyExchangeMode::X25519) + } +} + +use common::KeyExchangeMode; diff --git a/runner/src/lib.rs b/runner/src/lib.rs new file mode 100644 index 0000000..459d432 --- /dev/null +++ b/runner/src/lib.rs @@ -0,0 +1,2 @@ +pub mod args; +pub mod config; diff --git a/runner/src/main.rs b/runner/src/main.rs index c57fd49..7ef1cae 100644 --- a/runner/src/main.rs +++ b/runner/src/main.rs @@ -24,11 +24,8 @@ use rustls::{ }; use std::{ env, - fmt::Debug, - fs::File, - io::{BufWriter, Write, stdout}, + io::{Write, stdout}, net::SocketAddr, - path::PathBuf, sync::Arc, time::Instant, }; @@ -38,38 +35,8 @@ use tracing::info; use tracing_subscriber::EnvFilter; use uuid::Uuid; -/// TLS benchmark runner. -#[derive(Debug, Parser)] -#[command(name = "runner", version, about)] -struct Args { - /// Key exchange mode. - #[arg(long, default_value = "x25519")] - mode: KeyExchangeMode, - - /// Server address to connect to. - #[arg(long)] - server: SocketAddr, - - /// Payload size in bytes to request from server. - #[arg(long, default_value = "1024")] - payload_bytes: u32, - - /// Number of benchmark iterations (excluding warmup). - #[arg(long, default_value = "100")] - iters: u32, - - /// Number of warmup iterations (not recorded). - #[arg(long, default_value = "10")] - warmup: u32, - - /// Number of concurrent connections. - #[arg(long, default_value = "1")] - concurrency: u32, - - /// Output file for NDJSON records (stdout if not specified). - #[arg(long)] - out: Option, -} +use runner::args::Args; +use runner::config::{load_from_cli, load_from_file}; /// Result of a single benchmark iteration. struct IterationResult { @@ -183,44 +150,32 @@ async fn run_iteration( }) } +#[allow(clippy::future_not_send)] // References held across await points async fn run_benchmark( - args: Args, - tls_connector: TlsConnector, - server_name: ServerName<'static>, + config: &runner::config::BenchmarkConfig, + tls_connector: &TlsConnector, + server_name: &ServerName<'static>, ) -> miette::Result<()> { - let mut output: Box = match &args.out { - Some(path) => { - let file = - File::create(path).map_err(|e| miette!("failed to create output file: {e}"))?; - Box::new(BufWriter::new(file)) - } - None => Box::new(stdout()), - }; + let server = config.server; info!( - warmup = args.warmup, - iters = args.iters, - concurrency = args.concurrency, - "runnning benchmark iterations" + warmup = config.warmup, + iters = config.iters, + concurrency = config.concurrency, + "running benchmark iterations" ); - for _ in 0..args.warmup { - run_iteration( - args.server, - args.payload_bytes, - &tls_connector, - &server_name, - ) - .await?; + for _ in 0..config.warmup { + run_iteration(server, config.payload, tls_connector, server_name).await?; } info!("warmup complete"); let test_conn = tls_connector .connect( server_name.clone(), - TcpStream::connect(args.server) + TcpStream::connect(server) .await - .map_err(|e| miette!("failed to connect to server {}: {e}", args.server))?, + .map_err(|e| miette!("failed to connect to server {}: {e}", server))?, ) .await .map_err(|e| miette!("TLS handshake failed: {e}"))?; @@ -229,14 +184,17 @@ async fn run_benchmark( info!(cipher = ?cipher, "TLS handshake complete"); #[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values - let semaphore = Arc::new(Semaphore::new(args.concurrency as usize)); - let tasks = spawn_benchmark_tasks(&args, &semaphore, &tls_connector, &server_name); + let semaphore = Arc::new(Semaphore::new(config.concurrency as usize)); + let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name); - write_results(&mut output, tasks).await?; - - output - .flush() - .map_err(|e| miette!("failed to flush output: {e}"))?; + // Output to stdout for now + { + let mut output = stdout(); + write_results(&mut output, tasks).await?; + output + .flush() + .map_err(|e| miette!("failed to flush output: {e}"))?; + } info!("benchmark complete"); Ok(()) @@ -245,16 +203,19 @@ async fn run_benchmark( type ReturnHandle = JoinHandle<(IterationResult, Option)>; fn spawn_benchmark_tasks( - args: &Args, + config: &runner::config::BenchmarkConfig, semaphore: &Arc, tls_connector: &TlsConnector, server_name: &ServerName<'static>, ) -> Vec { - let server = args.server; - let payload_bytes = args.payload_bytes; - let mode = args.mode; + let server = config.server; + let payload_bytes = config.payload; + let mode = config + .mode + .parse::() + .expect("mode should be valid"); - (0..args.iters) + (0..config.iters) .map(|i| { spawn_single_iteration( i, @@ -300,10 +261,8 @@ fn spawn_single_iteration( }) } -async fn write_results( - output: &mut Box, - tasks: Vec, -) -> miette::Result<()> { +#[allow(clippy::future_not_send)] // dyn Write is not Send +async fn write_results(output: &mut dyn Write, tasks: Vec) -> miette::Result<()> { for task in tasks { let (_result, record) = task.await.expect("task should not panic"); if let Some(record) = record { @@ -332,22 +291,33 @@ async fn main() -> miette::Result<()> { ); let args = Args::parse(); - info!( - mode=%args.mode, - server=%args.server, - payload_bytes=%args.payload_bytes, - iters=%args.iters, - warmup=%args.warmup, - concurrency=%args.concurrency, - out=%args.out.as_ref().map_or("stdout", |p| p.to_str().unwrap_or("invalid")), - "runner configuration" - ); - let tls_config = build_tls_config(args.mode)?; + let config = if let Some(config_path) = &args.config { + info!(config_file = %config_path.display(), "loading config from file"); + load_from_file(config_path)? + } else { + info!("using CLI arguments"); + load_from_cli(&args)? + }; + + let tls_config = build_tls_config(config.server_mode())?; let tls_connector = TlsConnector::from(tls_config); let server_name = ServerName::try_from("localhost".to_string()) .map_err(|e| miette!("invalid server name: {e}"))?; - run_benchmark(args, tls_connector, server_name).await + for benchmark in &config.benchmarks { + info!( + mode = %benchmark.mode, + payload = benchmark.payload, + iters = benchmark.iters, + warmup = benchmark.warmup, + concurrency = benchmark.concurrency, + "running benchmark" + ); + + run_benchmark(benchmark, &tls_connector, &server_name).await?; + } + + Ok(()) }