mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-22 00:36:21 +00:00
feat(runner): implement concurrent benchmark execution
- Add tokio::Semaphore for bounded parallelism - Split benchmark into sequential warmup and parallel measured phases - Refactor to spawn_benchmark_tasks() and spawn_single_iteration() - Add write_results() to aggregate and output records - Use u32 for CLI arguments (iters, warmup, concurrency, payload_bytes)
This commit is contained in:
@@ -30,7 +30,7 @@ use std::{
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::{net::TcpStream, sync::Semaphore, task::JoinHandle};
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
/// TLS benchmark runner.
|
||||
@@ -47,19 +47,19 @@ struct Args {
|
||||
|
||||
/// Payload size in bytes to request from server.
|
||||
#[arg(long, default_value = "1024")]
|
||||
payload_bytes: u64,
|
||||
payload_bytes: u32,
|
||||
|
||||
/// Number of benchmark iterations (excluding warmup).
|
||||
#[arg(long, default_value = "100")]
|
||||
iters: u64,
|
||||
iters: u32,
|
||||
|
||||
/// Number of warmup iterations (not recorded).
|
||||
#[arg(long, default_value = "10")]
|
||||
warmup: u64,
|
||||
warmup: u32,
|
||||
|
||||
/// Number of concurrent connections.
|
||||
#[arg(long, default_value = "1")]
|
||||
concurrency: u64,
|
||||
concurrency: u32,
|
||||
|
||||
/// Output file for NDJSON records (stdout if not specified).
|
||||
#[arg(long)]
|
||||
@@ -146,7 +146,7 @@ fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>>
|
||||
#[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations
|
||||
async fn run_iteration(
|
||||
server: SocketAddr,
|
||||
payload_bytes: u64,
|
||||
payload_bytes: u32,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<IterationResult> {
|
||||
@@ -165,15 +165,15 @@ async fn run_iteration(
|
||||
|
||||
let handshake_ns = start.elapsed().as_nanos() as u64;
|
||||
|
||||
// Send request
|
||||
write_request(&mut tls_stream, payload_bytes)
|
||||
.await
|
||||
.map_err(|e| miette!("write request failed: {e}"))?;
|
||||
// Send request
|
||||
write_request(&mut tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.map_err(|e| miette!("write request failed: {e}"))?;
|
||||
|
||||
// Read response
|
||||
read_payload(&mut tls_stream, payload_bytes)
|
||||
.await
|
||||
.map_err(|e| miette!("read payload failed: {e}"))?;
|
||||
// Read response
|
||||
read_payload(&mut tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.map_err(|e| miette!("read payload failed: {e}"))?;
|
||||
|
||||
let ttlb_ns = start.elapsed().as_nanos() as u64;
|
||||
|
||||
@@ -188,8 +188,6 @@ async fn run_benchmark(
|
||||
tls_connector: TlsConnector,
|
||||
server_name: ServerName<'static>,
|
||||
) -> miette::Result<()> {
|
||||
let total_iters = args.warmup + args.iters;
|
||||
|
||||
// Open output file or use stdout
|
||||
let mut output: Box<dyn Write + Send> = match &args.out {
|
||||
Some(path) => {
|
||||
@@ -206,34 +204,22 @@ async fn run_benchmark(
|
||||
);
|
||||
eprintln!();
|
||||
|
||||
// TODO: Implement concurrency
|
||||
for i in 0..total_iters {
|
||||
let is_warmup = i < args.warmup;
|
||||
|
||||
let result = run_iteration(
|
||||
for _ in 0..args.warmup {
|
||||
run_iteration(
|
||||
args.server,
|
||||
args.payload_bytes,
|
||||
&tls_connector,
|
||||
&server_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if !is_warmup {
|
||||
let record = BenchRecord {
|
||||
iteration: i - args.warmup,
|
||||
mode: args.mode,
|
||||
payload_bytes: args.payload_bytes,
|
||||
handshake_ns: result.handshake_ns,
|
||||
ttlb_ns: result.ttlb_ns,
|
||||
};
|
||||
|
||||
writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?;
|
||||
}
|
||||
|
||||
if is_warmup && i == args.warmup.saturating_sub(1) {
|
||||
eprintln!("Warmup complete.");
|
||||
}
|
||||
}
|
||||
eprintln!("Warmup complete.");
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values
|
||||
let semaphore = Arc::new(Semaphore::new(args.concurrency as usize));
|
||||
let tasks = spawn_benchmark_tasks(&args, &semaphore, &tls_connector, &server_name);
|
||||
|
||||
write_results(&mut output, tasks).await?;
|
||||
|
||||
output
|
||||
.flush()
|
||||
@@ -243,6 +229,78 @@ async fn run_benchmark(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type ReturnHandle = JoinHandle<(IterationResult, Option<BenchRecord>)>;
|
||||
|
||||
fn spawn_benchmark_tasks(
|
||||
args: &Args,
|
||||
semaphore: &Arc<Semaphore>,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> Vec<ReturnHandle> {
|
||||
let server = args.server;
|
||||
let payload_bytes = args.payload_bytes;
|
||||
let mode = args.mode;
|
||||
|
||||
(0..args.iters)
|
||||
.map(|i| {
|
||||
spawn_single_iteration(
|
||||
i,
|
||||
payload_bytes,
|
||||
mode,
|
||||
server,
|
||||
semaphore.clone(),
|
||||
tls_connector.clone(),
|
||||
server_name.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn spawn_single_iteration(
|
||||
i: u32,
|
||||
payload_bytes: u32,
|
||||
mode: KeyExchangeMode,
|
||||
server: SocketAddr,
|
||||
semaphore: Arc<Semaphore>,
|
||||
tls_connector: TlsConnector,
|
||||
server_name: ServerName<'static>,
|
||||
) -> ReturnHandle {
|
||||
tokio::spawn(async move {
|
||||
let _permit = semaphore
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore should not be closed");
|
||||
|
||||
let result = run_iteration(server, payload_bytes, &tls_connector, &server_name)
|
||||
.await
|
||||
.expect("iteration should not fail");
|
||||
|
||||
let record = BenchRecord {
|
||||
iteration: u64::from(i),
|
||||
mode,
|
||||
payload_bytes: u64::from(payload_bytes),
|
||||
handshake_ns: result.handshake_ns,
|
||||
ttlb_ns: result.ttlb_ns,
|
||||
};
|
||||
|
||||
(result, Some(record))
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_results(
|
||||
output: &mut Box<dyn Write + Send>,
|
||||
tasks: Vec<ReturnHandle>,
|
||||
) -> miette::Result<()> {
|
||||
for task in tasks {
|
||||
let (_result, record) = task.await.expect("task should not panic");
|
||||
if let Some(record) = record {
|
||||
writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> miette::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
Reference in New Issue
Block a user