feat(runner): implement concurrent benchmark execution

- Add tokio::Semaphore for bounded parallelism
- Split benchmark into sequential warmup and parallel measured phases
- Refactor to spawn_benchmark_tasks() and spawn_single_iteration()
- Add write_results() to aggregate and output records
- Use u32 for CLI arguments (iters, warmup, concurrency, payload_bytes)
This commit is contained in:
2026-02-04 10:27:37 +02:00
parent 2cd2edc61d
commit 0da83f231b

View File

@@ -30,7 +30,7 @@ use std::{
sync::Arc, sync::Arc,
time::Instant, time::Instant,
}; };
use tokio::net::TcpStream; use tokio::{net::TcpStream, sync::Semaphore, task::JoinHandle};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
/// TLS benchmark runner. /// TLS benchmark runner.
@@ -47,19 +47,19 @@ struct Args {
/// Payload size in bytes to request from server. /// Payload size in bytes to request from server.
#[arg(long, default_value = "1024")] #[arg(long, default_value = "1024")]
payload_bytes: u64, payload_bytes: u32,
/// Number of benchmark iterations (excluding warmup). /// Number of benchmark iterations (excluding warmup).
#[arg(long, default_value = "100")] #[arg(long, default_value = "100")]
iters: u64, iters: u32,
/// Number of warmup iterations (not recorded). /// Number of warmup iterations (not recorded).
#[arg(long, default_value = "10")] #[arg(long, default_value = "10")]
warmup: u64, warmup: u32,
/// Number of concurrent connections. /// Number of concurrent connections.
#[arg(long, default_value = "1")] #[arg(long, default_value = "1")]
concurrency: u64, concurrency: u32,
/// Output file for NDJSON records (stdout if not specified). /// Output file for NDJSON records (stdout if not specified).
#[arg(long)] #[arg(long)]
@@ -146,7 +146,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>>
#[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations #[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations
async fn run_iteration( async fn run_iteration(
server: SocketAddr, server: SocketAddr,
payload_bytes: u64, payload_bytes: u32,
tls_connector: &TlsConnector, tls_connector: &TlsConnector,
server_name: &ServerName<'static>, server_name: &ServerName<'static>,
) -> miette::Result<IterationResult> { ) -> miette::Result<IterationResult> {
@@ -165,15 +165,15 @@ async fn run_iteration(
let handshake_ns = start.elapsed().as_nanos() as u64; let handshake_ns = start.elapsed().as_nanos() as u64;
// Send request // Send request
write_request(&mut tls_stream, payload_bytes) write_request(&mut tls_stream, u64::from(payload_bytes))
.await .await
.map_err(|e| miette!("write request failed: {e}"))?; .map_err(|e| miette!("write request failed: {e}"))?;
// Read response // Read response
read_payload(&mut tls_stream, payload_bytes) read_payload(&mut tls_stream, u64::from(payload_bytes))
.await .await
.map_err(|e| miette!("read payload failed: {e}"))?; .map_err(|e| miette!("read payload failed: {e}"))?;
let ttlb_ns = start.elapsed().as_nanos() as u64; let ttlb_ns = start.elapsed().as_nanos() as u64;
@@ -188,8 +188,6 @@ async fn run_benchmark(
tls_connector: TlsConnector, tls_connector: TlsConnector,
server_name: ServerName<'static>, server_name: ServerName<'static>,
) -> miette::Result<()> { ) -> miette::Result<()> {
let total_iters = args.warmup + args.iters;
// Open output file or use stdout // Open output file or use stdout
let mut output: Box<dyn Write + Send> = match &args.out { let mut output: Box<dyn Write + Send> = match &args.out {
Some(path) => { Some(path) => {
@@ -206,34 +204,22 @@ async fn run_benchmark(
); );
eprintln!(); eprintln!();
// TODO: Implement concurrency for _ in 0..args.warmup {
for i in 0..total_iters { run_iteration(
let is_warmup = i < args.warmup;
let result = run_iteration(
args.server, args.server,
args.payload_bytes, args.payload_bytes,
&tls_connector, &tls_connector,
&server_name, &server_name,
) )
.await?; .await?;
if !is_warmup {
let record = BenchRecord {
iteration: i - args.warmup,
mode: args.mode,
payload_bytes: args.payload_bytes,
handshake_ns: result.handshake_ns,
ttlb_ns: result.ttlb_ns,
};
writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?;
}
if is_warmup && i == args.warmup.saturating_sub(1) {
eprintln!("Warmup complete.");
}
} }
eprintln!("Warmup complete.");
#[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values
let semaphore = Arc::new(Semaphore::new(args.concurrency as usize));
let tasks = spawn_benchmark_tasks(&args, &semaphore, &tls_connector, &server_name);
write_results(&mut output, tasks).await?;
output output
.flush() .flush()
@@ -243,6 +229,78 @@ async fn run_benchmark(
Ok(()) Ok(())
} }
type ReturnHandle = JoinHandle<(IterationResult, Option<BenchRecord>)>;
fn spawn_benchmark_tasks(
args: &Args,
semaphore: &Arc<Semaphore>,
tls_connector: &TlsConnector,
server_name: &ServerName<'static>,
) -> Vec<ReturnHandle> {
let server = args.server;
let payload_bytes = args.payload_bytes;
let mode = args.mode;
(0..args.iters)
.map(|i| {
spawn_single_iteration(
i,
payload_bytes,
mode,
server,
semaphore.clone(),
tls_connector.clone(),
server_name.clone(),
)
})
.collect()
}
#[allow(clippy::too_many_arguments)]
fn spawn_single_iteration(
i: u32,
payload_bytes: u32,
mode: KeyExchangeMode,
server: SocketAddr,
semaphore: Arc<Semaphore>,
tls_connector: TlsConnector,
server_name: ServerName<'static>,
) -> ReturnHandle {
tokio::spawn(async move {
let _permit = semaphore
.acquire()
.await
.expect("semaphore should not be closed");
let result = run_iteration(server, payload_bytes, &tls_connector, &server_name)
.await
.expect("iteration should not fail");
let record = BenchRecord {
iteration: u64::from(i),
mode,
payload_bytes: u64::from(payload_bytes),
handshake_ns: result.handshake_ns,
ttlb_ns: result.ttlb_ns,
};
(result, Some(record))
})
}
async fn write_results(
output: &mut Box<dyn Write + Send>,
tasks: Vec<ReturnHandle>,
) -> miette::Result<()> {
for task in tasks {
let (_result, record) = task.await.expect("task should not panic");
if let Some(record) = record {
writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?;
}
}
Ok(())
}
#[tokio::main] #[tokio::main]
async fn main() -> miette::Result<()> { async fn main() -> miette::Result<()> {
let args = Args::parse(); let args = Args::parse();