From 0da83f231bc316e2a71f44c7abd4ee0baa993d6e Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Wed, 4 Feb 2026 10:27:37 +0200 Subject: [PATCH] feat(runner): implement concurrent benchmark execution - Add tokio::Semaphore for bounded parallelism - Split benchmark into sequential warmup and parallel measured phases - Refactor to spawn_benchmark_tasks() and spawn_single_iteration() - Add write_results() to aggregate and output records - Use u32 for CLI arguments (iters, warmup, concurrency, payload_bytes) --- runner/src/main.rs | 132 ++++++++++++++++++++++++++++++++------------- 1 file changed, 95 insertions(+), 37 deletions(-) diff --git a/runner/src/main.rs b/runner/src/main.rs index 083aea2..c0660d3 100644 --- a/runner/src/main.rs +++ b/runner/src/main.rs @@ -30,7 +30,7 @@ use std::{ sync::Arc, time::Instant, }; -use tokio::net::TcpStream; +use tokio::{net::TcpStream, sync::Semaphore, task::JoinHandle}; use tokio_rustls::TlsConnector; /// TLS benchmark runner. @@ -47,19 +47,19 @@ struct Args { /// Payload size in bytes to request from server. #[arg(long, default_value = "1024")] - payload_bytes: u64, + payload_bytes: u32, /// Number of benchmark iterations (excluding warmup). #[arg(long, default_value = "100")] - iters: u64, + iters: u32, /// Number of warmup iterations (not recorded). #[arg(long, default_value = "10")] - warmup: u64, + warmup: u32, /// Number of concurrent connections. #[arg(long, default_value = "1")] - concurrency: u64, + concurrency: u32, /// Output file for NDJSON records (stdout if not specified). #[arg(long)] @@ -146,7 +146,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result> #[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations async fn run_iteration( server: SocketAddr, - payload_bytes: u64, + payload_bytes: u32, tls_connector: &TlsConnector, server_name: &ServerName<'static>, ) -> miette::Result { @@ -165,15 +165,15 @@ async fn run_iteration( let handshake_ns = start.elapsed().as_nanos() as u64; - // Send request - write_request(&mut tls_stream, payload_bytes) - .await - .map_err(|e| miette!("write request failed: {e}"))?; + // Send request + write_request(&mut tls_stream, u64::from(payload_bytes)) + .await + .map_err(|e| miette!("write request failed: {e}"))?; - // Read response - read_payload(&mut tls_stream, payload_bytes) - .await - .map_err(|e| miette!("read payload failed: {e}"))?; + // Read response + read_payload(&mut tls_stream, u64::from(payload_bytes)) + .await + .map_err(|e| miette!("read payload failed: {e}"))?; let ttlb_ns = start.elapsed().as_nanos() as u64; @@ -188,8 +188,6 @@ async fn run_benchmark( tls_connector: TlsConnector, server_name: ServerName<'static>, ) -> miette::Result<()> { - let total_iters = args.warmup + args.iters; - // Open output file or use stdout let mut output: Box = match &args.out { Some(path) => { @@ -206,34 +204,22 @@ async fn run_benchmark( ); eprintln!(); - // TODO: Implement concurrency - for i in 0..total_iters { - let is_warmup = i < args.warmup; - - let result = run_iteration( + for _ in 0..args.warmup { + run_iteration( args.server, args.payload_bytes, &tls_connector, &server_name, ) .await?; - - if !is_warmup { - let record = BenchRecord { - iteration: i - args.warmup, - mode: args.mode, - payload_bytes: args.payload_bytes, - handshake_ns: result.handshake_ns, - ttlb_ns: result.ttlb_ns, - }; - - writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?; - } - - if is_warmup && i == args.warmup.saturating_sub(1) { - eprintln!("Warmup complete."); - } } + eprintln!("Warmup complete."); + + #[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values + let semaphore = Arc::new(Semaphore::new(args.concurrency as usize)); + let tasks = spawn_benchmark_tasks(&args, &semaphore, &tls_connector, &server_name); + + write_results(&mut output, tasks).await?; output .flush() @@ -243,6 +229,78 @@ async fn run_benchmark( Ok(()) } +type ReturnHandle = JoinHandle<(IterationResult, Option)>; + +fn spawn_benchmark_tasks( + args: &Args, + semaphore: &Arc, + tls_connector: &TlsConnector, + server_name: &ServerName<'static>, +) -> Vec { + let server = args.server; + let payload_bytes = args.payload_bytes; + let mode = args.mode; + + (0..args.iters) + .map(|i| { + spawn_single_iteration( + i, + payload_bytes, + mode, + server, + semaphore.clone(), + tls_connector.clone(), + server_name.clone(), + ) + }) + .collect() +} + +#[allow(clippy::too_many_arguments)] +fn spawn_single_iteration( + i: u32, + payload_bytes: u32, + mode: KeyExchangeMode, + server: SocketAddr, + semaphore: Arc, + tls_connector: TlsConnector, + server_name: ServerName<'static>, +) -> ReturnHandle { + tokio::spawn(async move { + let _permit = semaphore + .acquire() + .await + .expect("semaphore should not be closed"); + + let result = run_iteration(server, payload_bytes, &tls_connector, &server_name) + .await + .expect("iteration should not fail"); + + let record = BenchRecord { + iteration: u64::from(i), + mode, + payload_bytes: u64::from(payload_bytes), + handshake_ns: result.handshake_ns, + ttlb_ns: result.ttlb_ns, + }; + + (result, Some(record)) + }) +} + +async fn write_results( + output: &mut Box, + tasks: Vec, +) -> miette::Result<()> { + for task in tasks { + let (_result, record) = task.await.expect("task should not panic"); + if let Some(record) = record { + writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?; + } + } + Ok(()) +} + #[tokio::main] async fn main() -> miette::Result<()> { let args = Args::parse();