diff --git a/Cargo.lock b/Cargo.lock index 521ad55..f081297 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,6 +150,7 @@ dependencies = [ "serde", "serde_json", "strum", + "tokio", ] [[package]] diff --git a/bench-common/Cargo.toml b/bench-common/Cargo.toml index 2d04ad8..229f124 100644 --- a/bench-common/Cargo.toml +++ b/bench-common/Cargo.toml @@ -9,6 +9,7 @@ rcgen.workspace = true serde.workspace = true serde_json.workspace = true strum.workspace = true +tokio.workspace = true [lints] workspace = true diff --git a/bench-common/src/lib.rs b/bench-common/src/lib.rs index 98cfc62..389f263 100644 --- a/bench-common/src/lib.rs +++ b/bench-common/src/lib.rs @@ -1,6 +1,7 @@ //! Common types and utilities for the TLS benchmark harness. pub mod cert; +pub mod protocol; use serde::{Deserialize, Serialize}; use std::fmt; diff --git a/bench-common/src/protocol.rs b/bench-common/src/protocol.rs new file mode 100644 index 0000000..e0468cf --- /dev/null +++ b/bench-common/src/protocol.rs @@ -0,0 +1,145 @@ +//! Benchmark protocol implementation. +//! +//! Protocol specification: +//! 1. Client sends 8-byte little-endian u64: requested payload size N +//! 2. Server responds with exactly N bytes (deterministic pattern) +//! +//! The deterministic pattern is a repeating sequence of bytes 0x00..0xFF. + +// Casts are intentional: MAX_PAYLOAD_SIZE (16 MiB) fits in usize on 64-bit, +// and byte patterns are explicitly masked to 0xFF before casting. +#![allow(clippy::cast_possible_truncation)] + +use std::io; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// Size of the request header (u64 payload size). +pub const REQUEST_SIZE: usize = 8; + +/// Maximum allowed payload size (16 MiB). +pub const MAX_PAYLOAD_SIZE: u64 = 16 * 1024 * 1024; + +/// Read the payload size request from a stream. +/// +/// # Errors +/// Returns an error if reading fails or payload size exceeds maximum. +pub async fn read_request(reader: &mut R) -> io::Result { + let mut buf = [0u8; REQUEST_SIZE]; + reader.read_exact(&mut buf).await?; + let size = u64::from_le_bytes(buf); + + if size > MAX_PAYLOAD_SIZE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("payload size {size} exceeds maximum {MAX_PAYLOAD_SIZE}"), + )); + } + + Ok(size) +} + +/// Write a payload size request to a stream. +/// +/// # Errors +/// Returns an error if writing fails. +pub async fn write_request(writer: &mut W, size: u64) -> io::Result<()> { + let buf = size.to_le_bytes(); + writer.write_all(&buf).await +} + +/// Generate deterministic payload of the given size. +/// +/// The pattern is a repeating sequence: 0x00, 0x01, ..., 0xFF, 0x00, ... +#[must_use] +pub fn generate_payload(size: u64) -> Vec { + let size = size as usize; + let mut payload = Vec::with_capacity(size); + for i in 0..size { + payload.push((i & 0xFF) as u8); + } + payload +} + +/// Write deterministic payload to a stream. +/// +/// Writes in chunks to avoid allocating large buffers. +/// +/// # Errors +/// Returns an error if writing fails. +pub async fn write_payload(writer: &mut W, size: u64) -> io::Result<()> { + const CHUNK_SIZE: usize = 64 * 1024; + let mut remaining = size as usize; + let mut offset = 0usize; + + while remaining > 0 { + let chunk_len = remaining.min(CHUNK_SIZE); + let chunk: Vec = (0..chunk_len) + .map(|i| ((offset + i) & 0xFF) as u8) + .collect(); + writer.write_all(&chunk).await?; + remaining -= chunk_len; + offset += chunk_len; + } + + Ok(()) +} + +/// Read and discard payload from a stream, returning the number of bytes read. +/// +/// # Errors +/// Returns an error if reading fails. +pub async fn read_payload( + reader: &mut R, + expected_size: u64, +) -> io::Result { + const CHUNK_SIZE: usize = 64 * 1024; + let mut buf = vec![0u8; CHUNK_SIZE]; + let mut total_read = 0u64; + + while total_read < expected_size { + let to_read = ((expected_size - total_read) as usize).min(CHUNK_SIZE); + reader.read_exact(&mut buf[..to_read]).await?; + total_read += to_read as u64; + } + + Ok(total_read) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[test] + fn generate_payload_pattern() { + let payload = generate_payload(300); + assert_eq!(payload.len(), 300); + assert_eq!(payload[0], 0x00); + assert_eq!(payload[255], 0xFF); + assert_eq!(payload[256], 0x00); + assert_eq!(payload[299], 43); + } + + #[tokio::test] + async fn roundtrip_request() { + let mut buf = Vec::new(); + write_request(&mut buf, 12345) + .await + .expect("write should succeed"); + assert_eq!(buf.len(), REQUEST_SIZE); + + let mut cursor = Cursor::new(buf); + let size = read_request(&mut cursor) + .await + .expect("read should succeed"); + assert_eq!(size, 12345); + } + + #[tokio::test] + async fn reject_oversized_request() { + let buf = (MAX_PAYLOAD_SIZE + 1).to_le_bytes(); + let mut cursor = Cursor::new(buf); + let result = read_request(&mut cursor).await; + assert!(result.is_err()); + } +} diff --git a/bench-runner/src/main.rs b/bench-runner/src/main.rs index fedbd46..ec539dd 100644 --- a/bench-runner/src/main.rs +++ b/bench-runner/src/main.rs @@ -6,10 +6,16 @@ //! //! Outputs NDJSON records to stdout or a file. -use bench_common::KeyExchangeMode; +use bench_common::protocol::{read_payload, write_request}; +use bench_common::{BenchRecord, KeyExchangeMode}; use clap::Parser; +use miette::miette; +use std::fs::File; +use std::io::{BufWriter, Write, stdout}; use std::net::SocketAddr; use std::path::PathBuf; +use std::time::Instant; +use tokio::net::TcpStream; /// TLS benchmark runner. #[derive(Debug, Parser)] @@ -44,6 +50,92 @@ struct Args { out: Option, } +/// Result of a single benchmark iteration. +struct IterationResult { + handshake_ns: u64, + ttlb_ns: u64, +} + +/// Run a single benchmark iteration over plain TCP. +#[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations +async fn run_iteration(server: SocketAddr, payload_bytes: u64) -> miette::Result { + let start = Instant::now(); + + // Connect (this is the "handshake" for plain TCP) + let mut stream = TcpStream::connect(server) + .await + .map_err(|e| miette!("connection failed: {e}"))?; + + let handshake_ns = start.elapsed().as_nanos() as u64; + + // Send request + write_request(&mut stream, payload_bytes) + .await + .map_err(|e| miette!("write request failed: {e}"))?; + + // Read response + read_payload(&mut stream, payload_bytes) + .await + .map_err(|e| miette!("read payload failed: {e}"))?; + + let ttlb_ns = start.elapsed().as_nanos() as u64; + + Ok(IterationResult { + handshake_ns, + ttlb_ns, + }) +} + +async fn run_benchmark(args: Args) -> miette::Result<()> { + let total_iters = args.warmup + args.iters; + + // Open output file or use stdout + let mut output: Box = match &args.out { + Some(path) => { + let file = + File::create(path).map_err(|e| miette!("failed to create output file: {e}"))?; + Box::new(BufWriter::new(file)) + } + None => Box::new(stdout()), + }; + + eprintln!( + "Running {} warmup + {} measured iterations (concurrency: {}, TLS disabled)", + args.warmup, args.iters, args.concurrency + ); + eprintln!(); + + // TODO: Implement concurrency + for i in 0..total_iters { + let is_warmup = i < args.warmup; + + let result = run_iteration(args.server, args.payload_bytes).await?; + + if !is_warmup { + let record = BenchRecord { + iteration: i - args.warmup, + mode: args.mode, + payload_bytes: args.payload_bytes, + handshake_ns: result.handshake_ns, + ttlb_ns: result.ttlb_ns, + }; + + writeln!(output, "{record}").map_err(|e| miette!("failed to write record: {e}"))?; + } + + if is_warmup && i == args.warmup.saturating_sub(1) { + eprintln!("Warmup complete."); + } + } + + output + .flush() + .map_err(|e| miette!("failed to flush output: {e}"))?; + + eprintln!("Benchmark complete."); + Ok(()) +} + #[tokio::main] async fn main() -> miette::Result<()> { let args = Args::parse(); @@ -61,9 +153,7 @@ async fn main() -> miette::Result<()> { .as_ref() .map_or_else(|| "stdout".to_string(), |p| p.display().to_string()) ); + eprintln!(); - // TODO: Implement TLS client and benchmark loop - eprintln!("\nRunner not yet implemented."); - - Ok(()) + run_benchmark(args).await } diff --git a/bench-server/src/main.rs b/bench-server/src/main.rs index a1a09da..ccf3f14 100644 --- a/bench-server/src/main.rs +++ b/bench-server/src/main.rs @@ -1,12 +1,14 @@ //! TLS benchmark server. //! -//! Listens for TLS connections and serves the benchmark protocol: +//! Listens for connections and serves the benchmark protocol: //! - Reads 8-byte little-endian u64 (requested payload size N) //! - Responds with exactly N bytes (deterministic pattern) +use bench_common::protocol::{read_request, write_payload}; use bench_common::KeyExchangeMode; use clap::Parser; use std::net::SocketAddr; +use tokio::net::{TcpListener, TcpStream}; /// TLS benchmark server. #[derive(Debug, Parser)] @@ -21,6 +23,48 @@ struct Args { listen: SocketAddr, } +async fn handle_connection(mut stream: TcpStream, peer: SocketAddr) { + loop { + let payload_size = match read_request(&mut stream).await { + Ok(size) => size, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + // Client closed connection + break; + } + Err(e) => { + eprintln!("[{peer}] read error: {e}"); + break; + } + }; + + if let Err(e) = write_payload(&mut stream, payload_size).await { + eprintln!("[{peer}] write error: {e}"); + break; + } + } +} + +async fn run_server(args: Args) -> miette::Result<()> { + let listener = TcpListener::bind(args.listen) + .await + .map_err(|e| miette::miette!("failed to bind to {}: {e}", args.listen))?; + + eprintln!("Listening on {} (TCP, TLS disabled)", args.listen); + eprintln!("Mode: {} (not yet implemented)", args.mode); + + loop { + let (stream, peer) = match listener.accept().await { + Ok(conn) => conn, + Err(e) => { + eprintln!("accept error: {e}"); + continue; + } + }; + + tokio::spawn(handle_connection(stream, peer)); + } +} + #[tokio::main] async fn main() -> miette::Result<()> { let args = Args::parse(); @@ -28,9 +72,7 @@ async fn main() -> miette::Result<()> { eprintln!("bench-server configuration:"); eprintln!(" mode: {}", args.mode); eprintln!(" listen: {}", args.listen); + eprintln!(); - // TODO: Implement TLS server - eprintln!("\nServer not yet implemented."); - - Ok(()) + run_server(args).await }