mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-22 00:36:21 +00:00
feat(runner): add TOML config file support for matrix benchmarks
- Add toml and serde dependencies - Create config module with Config and BenchmarkConfig structs - Add --config CLI option for matrix benchmarks - Refactor `run_benchmark()` to accept BenchmarkConfig
This commit is contained in:
@@ -24,11 +24,8 @@ use rustls::{
|
||||
};
|
||||
use std::{
|
||||
env,
|
||||
fmt::Debug,
|
||||
fs::File,
|
||||
io::{BufWriter, Write, stdout},
|
||||
io::{Write, stdout},
|
||||
net::SocketAddr,
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
@@ -38,38 +35,8 @@ use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// TLS benchmark runner.
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(name = "runner", version, about)]
|
||||
struct Args {
|
||||
/// Key exchange mode.
|
||||
#[arg(long, default_value = "x25519")]
|
||||
mode: KeyExchangeMode,
|
||||
|
||||
/// Server address to connect to.
|
||||
#[arg(long)]
|
||||
server: SocketAddr,
|
||||
|
||||
/// Payload size in bytes to request from server.
|
||||
#[arg(long, default_value = "1024")]
|
||||
payload_bytes: u32,
|
||||
|
||||
/// Number of benchmark iterations (excluding warmup).
|
||||
#[arg(long, default_value = "100")]
|
||||
iters: u32,
|
||||
|
||||
/// Number of warmup iterations (not recorded).
|
||||
#[arg(long, default_value = "10")]
|
||||
warmup: u32,
|
||||
|
||||
/// Number of concurrent connections.
|
||||
#[arg(long, default_value = "1")]
|
||||
concurrency: u32,
|
||||
|
||||
/// Output file for NDJSON records (stdout if not specified).
|
||||
#[arg(long)]
|
||||
out: Option<PathBuf>,
|
||||
}
|
||||
use runner::args::Args;
|
||||
use runner::config::{load_from_cli, load_from_file};
|
||||
|
||||
/// Result of a single benchmark iteration.
|
||||
struct IterationResult {
|
||||
@@ -183,44 +150,32 @@ async fn run_iteration(
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::future_not_send)] // References held across await points
|
||||
async fn run_benchmark(
|
||||
args: Args,
|
||||
tls_connector: TlsConnector,
|
||||
server_name: ServerName<'static>,
|
||||
config: &runner::config::BenchmarkConfig,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<()> {
|
||||
let mut output: Box<dyn Write + Send> = match &args.out {
|
||||
Some(path) => {
|
||||
let file =
|
||||
File::create(path).map_err(|e| miette!("failed to create output file: {e}"))?;
|
||||
Box::new(BufWriter::new(file))
|
||||
}
|
||||
None => Box::new(stdout()),
|
||||
};
|
||||
let server = config.server;
|
||||
|
||||
info!(
|
||||
warmup = args.warmup,
|
||||
iters = args.iters,
|
||||
concurrency = args.concurrency,
|
||||
"runnning benchmark iterations"
|
||||
warmup = config.warmup,
|
||||
iters = config.iters,
|
||||
concurrency = config.concurrency,
|
||||
"running benchmark iterations"
|
||||
);
|
||||
|
||||
for _ in 0..args.warmup {
|
||||
run_iteration(
|
||||
args.server,
|
||||
args.payload_bytes,
|
||||
&tls_connector,
|
||||
&server_name,
|
||||
)
|
||||
.await?;
|
||||
for _ in 0..config.warmup {
|
||||
run_iteration(server, config.payload, tls_connector, server_name).await?;
|
||||
}
|
||||
info!("warmup complete");
|
||||
|
||||
let test_conn = tls_connector
|
||||
.connect(
|
||||
server_name.clone(),
|
||||
TcpStream::connect(args.server)
|
||||
TcpStream::connect(server)
|
||||
.await
|
||||
.map_err(|e| miette!("failed to connect to server {}: {e}", args.server))?,
|
||||
.map_err(|e| miette!("failed to connect to server {}: {e}", server))?,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| miette!("TLS handshake failed: {e}"))?;
|
||||
@@ -229,14 +184,17 @@ async fn run_benchmark(
|
||||
info!(cipher = ?cipher, "TLS handshake complete");
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values
|
||||
let semaphore = Arc::new(Semaphore::new(args.concurrency as usize));
|
||||
let tasks = spawn_benchmark_tasks(&args, &semaphore, &tls_connector, &server_name);
|
||||
let semaphore = Arc::new(Semaphore::new(config.concurrency as usize));
|
||||
let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name);
|
||||
|
||||
write_results(&mut output, tasks).await?;
|
||||
|
||||
output
|
||||
.flush()
|
||||
.map_err(|e| miette!("failed to flush output: {e}"))?;
|
||||
// Output to stdout for now
|
||||
{
|
||||
let mut output = stdout();
|
||||
write_results(&mut output, tasks).await?;
|
||||
output
|
||||
.flush()
|
||||
.map_err(|e| miette!("failed to flush output: {e}"))?;
|
||||
}
|
||||
|
||||
info!("benchmark complete");
|
||||
Ok(())
|
||||
@@ -245,16 +203,19 @@ async fn run_benchmark(
|
||||
type ReturnHandle = JoinHandle<(IterationResult, Option<BenchRecord>)>;
|
||||
|
||||
fn spawn_benchmark_tasks(
|
||||
args: &Args,
|
||||
config: &runner::config::BenchmarkConfig,
|
||||
semaphore: &Arc<Semaphore>,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> Vec<ReturnHandle> {
|
||||
let server = args.server;
|
||||
let payload_bytes = args.payload_bytes;
|
||||
let mode = args.mode;
|
||||
let server = config.server;
|
||||
let payload_bytes = config.payload;
|
||||
let mode = config
|
||||
.mode
|
||||
.parse::<KeyExchangeMode>()
|
||||
.expect("mode should be valid");
|
||||
|
||||
(0..args.iters)
|
||||
(0..config.iters)
|
||||
.map(|i| {
|
||||
spawn_single_iteration(
|
||||
i,
|
||||
@@ -300,10 +261,8 @@ fn spawn_single_iteration(
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_results(
|
||||
output: &mut Box<dyn Write + Send>,
|
||||
tasks: Vec<ReturnHandle>,
|
||||
) -> miette::Result<()> {
|
||||
#[allow(clippy::future_not_send)] // dyn Write is not Send
|
||||
async fn write_results(output: &mut dyn Write, tasks: Vec<ReturnHandle>) -> miette::Result<()> {
|
||||
for task in tasks {
|
||||
let (_result, record) = task.await.expect("task should not panic");
|
||||
if let Some(record) = record {
|
||||
@@ -332,22 +291,33 @@ async fn main() -> miette::Result<()> {
|
||||
);
|
||||
|
||||
let args = Args::parse();
|
||||
info!(
|
||||
mode=%args.mode,
|
||||
server=%args.server,
|
||||
payload_bytes=%args.payload_bytes,
|
||||
iters=%args.iters,
|
||||
warmup=%args.warmup,
|
||||
concurrency=%args.concurrency,
|
||||
out=%args.out.as_ref().map_or("stdout", |p| p.to_str().unwrap_or("invalid")),
|
||||
"runner configuration"
|
||||
);
|
||||
|
||||
let tls_config = build_tls_config(args.mode)?;
|
||||
let config = if let Some(config_path) = &args.config {
|
||||
info!(config_file = %config_path.display(), "loading config from file");
|
||||
load_from_file(config_path)?
|
||||
} else {
|
||||
info!("using CLI arguments");
|
||||
load_from_cli(&args)?
|
||||
};
|
||||
|
||||
let tls_config = build_tls_config(config.server_mode())?;
|
||||
let tls_connector = TlsConnector::from(tls_config);
|
||||
|
||||
let server_name = ServerName::try_from("localhost".to_string())
|
||||
.map_err(|e| miette!("invalid server name: {e}"))?;
|
||||
|
||||
run_benchmark(args, tls_connector, server_name).await
|
||||
for benchmark in &config.benchmarks {
|
||||
info!(
|
||||
mode = %benchmark.mode,
|
||||
payload = benchmark.payload,
|
||||
iters = benchmark.iters,
|
||||
warmup = benchmark.warmup,
|
||||
concurrency = benchmark.concurrency,
|
||||
"running benchmark"
|
||||
);
|
||||
|
||||
run_benchmark(benchmark, &tls_connector, &server_name).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user