diff --git a/common/src/lib.rs b/common/src/lib.rs index 0861211..5e97e8e 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -43,6 +43,8 @@ pub enum ProtocolMode { pub struct BenchRecord { /// Iteration number (0-indexed, excludes warmup) pub iteration: u64, + /// Protocol carrier mode + pub proto: ProtocolMode, /// Key exchange mode used pub mode: KeyExchangeMode, /// Payload size in bytes @@ -84,6 +86,7 @@ mod tests { fn bench_record_serializes_to_ndjson() { let record = BenchRecord { iteration: 0, + proto: ProtocolMode::Raw, mode: KeyExchangeMode::X25519, payload_bytes: 1024, tcp_ns: 500_000, @@ -92,6 +95,7 @@ mod tests { }; let json = assert_ok!(record.to_ndjson()); assert!(json.contains(r#""iteration":0"#)); + assert!(json.contains(r#""proto":"raw""#)); assert!(json.contains(r#""mode":"x25519""#)); assert!(json.contains(r#""payload_bytes":1024"#)); } @@ -100,6 +104,7 @@ mod tests { fn bench_record_roundtrip() { let original = BenchRecord { iteration: 42, + proto: ProtocolMode::Http1, mode: KeyExchangeMode::X25519Mlkem768, payload_bytes: 4096, tcp_ns: 1_000_000, @@ -110,21 +115,13 @@ mod tests { let deserialized = assert_ok!(serde_json::from_str::(&json)); assert_eq!(original.iteration, deserialized.iteration); + assert_eq!(original.proto, deserialized.proto); assert_eq!(original.mode, deserialized.mode); assert_eq!(original.payload_bytes, deserialized.payload_bytes); assert_eq!(original.handshake_ns, deserialized.handshake_ns); assert_eq!(original.ttlb_ns, deserialized.ttlb_ns); } - #[test] - fn key_exchange_mode_from_str() { - let mode = assert_ok!(KeyExchangeMode::from_str("x25519", true)); - assert_eq!(mode, KeyExchangeMode::X25519); - - let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768", true)); - assert_eq!(mode, KeyExchangeMode::X25519Mlkem768); - } - #[test] fn key_exchange_mode_parse_error() { assert_err!(KeyExchangeMode::from_str("invalid", true)); @@ -152,4 +149,38 @@ mod tests { )); assert_eq!(mode_mlkem_lower, KeyExchangeMode::X25519Mlkem768); } + + #[test] + fn key_protocol_mod_from_str() { + let proto = assert_ok!(ProtocolMode::from_str("raw", true)); + assert_eq!(proto, ProtocolMode::Raw); + + let proto = assert_ok!(ProtocolMode::from_str("http1", true)); + assert_eq!(proto, ProtocolMode::Http1); + } + + #[test] + fn key_protocol_mode_parse_error() { + assert_err!(ProtocolMode::from_str("invalid", true)); + assert_err!(ProtocolMode::from_str("", true)); + } + + #[test] + fn key_exchange_mode_from_str() { + let mode = assert_ok!(KeyExchangeMode::from_str("x25519", true)); + assert_eq!(mode, KeyExchangeMode::X25519); + + let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768", true)); + assert_eq!(mode, KeyExchangeMode::X25519Mlkem768); + } + + #[test] + fn key_protocol_mode_serde() { + let json = r#"{"proto":"http1"}"#; + let value = assert_ok!(serde_json::from_str::(json)); + let proto = assert_ok!(serde_json::from_value::( + value["proto"].clone() + )); + assert_eq!(proto, ProtocolMode::Http1); + } } diff --git a/runner/src/bench.rs b/runner/src/bench.rs index 2b42f6c..0a0ca83 100644 --- a/runner/src/bench.rs +++ b/runner/src/bench.rs @@ -1,3 +1,4 @@ +use crate::config::BenchmarkConfig; use common::prelude::*; use futures::{StreamExt, stream::FuturesUnordered}; use miette::{Context, IntoDiagnostic}; @@ -7,12 +8,13 @@ use std::{ net::SocketAddr, time::Instant, }; -use tokio::net::TcpStream; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; use tokio_rustls::TlsConnector; use tracing::info; -use crate::config::BenchmarkConfig; - /// Result of a single benchmark iteration. struct IterationResult { tcp: u128, @@ -20,9 +22,103 @@ struct IterationResult { ttlb: u128, } +pub async fn run_benchmark( + config: &BenchmarkConfig, + tls_connector: &TlsConnector, + server_name: &ServerName<'static>, +) -> miette::Result<()> { + let server = config.server; + + info!( + warmup = config.warmup, + iters = config.iters, + concurrency = config.concurrency, + "running benchmark iterations" + ); + + for _ in 0..config.warmup { + run_iteration( + server, + config.proto, + config.payload, + tls_connector, + server_name, + ) + .await?; + } + info!("warmup complete"); + + let mut output = stdout(); + run_and_write(config, tls_connector, server_name, &mut output).await?; + output + .flush() + .into_diagnostic() + .context("failed to flush output")?; + + info!("benchmark complete"); + Ok(()) +} + +async fn run_and_write( + config: &BenchmarkConfig, + tls_connector: &TlsConnector, + server_name: &ServerName<'static>, + output: &mut W, +) -> miette::Result<()> { + let mut in_flight = FuturesUnordered::new(); + let mut issued = 0; + + loop { + while issued < config.iters && in_flight.len() < config.concurrency as usize { + in_flight.push(run_single_iteration( + issued, + config.payload, + config.proto, + config.mode, + config.server, + tls_connector.clone(), + server_name.clone(), + )); + issued += 1; + } + + match in_flight.next().await { + Some(record) => writeln!(output, "{}", record?) + .into_diagnostic() + .context("failed to write record")?, + None => break, + } + } + + Ok(()) +} + +async fn run_single_iteration( + i: u32, + payload_bytes: u32, + proto: ProtocolMode, + mode: KeyExchangeMode, + server: SocketAddr, + tls_connector: TlsConnector, + server_name: ServerName<'static>, +) -> miette::Result { + let result = run_iteration(server, proto, payload_bytes, &tls_connector, &server_name).await?; + + Ok(BenchRecord { + iteration: u64::from(i), + proto, + mode, + payload_bytes: u64::from(payload_bytes), + tcp_ns: result.tcp, + handshake_ns: result.handshake, + ttlb_ns: result.ttlb, + }) +} + /// Run a single benchmark iteration over TLS. async fn run_iteration( server: SocketAddr, + proto: ProtocolMode, payload_bytes: u32, tls_connector: &TlsConnector, server_name: &ServerName<'static>, @@ -46,17 +142,9 @@ async fn run_iteration( let handshake_ns = hs_start.elapsed().as_nanos(); let ttlb_start = Instant::now(); - write_request(&mut tls_stream, u64::from(payload_bytes)) - .await - .into_diagnostic() - .context("write request failed")?; - - read_payload(&mut tls_stream, u64::from(payload_bytes)) - .await - .into_diagnostic() - .context("read payload failed")?; let ttlb_ns = tcp_ns + handshake_ns + ttlb_start.elapsed().as_nanos(); + run_exchange(&mut tls_stream, proto, payload_bytes).await?; Ok(IterationResult { tcp: tcp_ns, @@ -65,86 +153,212 @@ async fn run_iteration( }) } -pub async fn run_benchmark( - config: &BenchmarkConfig, - tls_connector: &TlsConnector, - server_name: &ServerName<'static>, -) -> miette::Result<()> { - let server = config.server; - - info!( - warmup = config.warmup, - iters = config.iters, - concurrency = config.concurrency, - "running benchmark iterations" - ); - - for _ in 0..config.warmup { - run_iteration(server, config.payload, tls_connector, server_name).await?; - } - info!("warmup complete"); - - #[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values - let mut output = stdout(); - run_and_write(config, tls_connector, server_name, &mut output).await?; - output - .flush() - .into_diagnostic() - .context("failed to flush output")?; - - info!("benchmark complete"); - Ok(()) -} - -async fn run_single_iteration( - i: u32, +async fn run_exchange( + tls_stream: &mut S, + proto: ProtocolMode, payload_bytes: u32, - mode: KeyExchangeMode, - server: SocketAddr, - tls_connector: TlsConnector, - server_name: ServerName<'static>, -) -> miette::Result { - let result = run_iteration(server, payload_bytes, &tls_connector, &server_name).await?; - - Ok(BenchRecord { - iteration: u64::from(i), - mode, - payload_bytes: u64::from(payload_bytes), - tcp_ns: result.tcp, - handshake_ns: result.handshake, - ttlb_ns: result.ttlb, - }) +) -> miette::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + match proto { + ProtocolMode::Raw => run_raw_exchange(tls_stream, payload_bytes).await, + ProtocolMode::Http1 => run_http1_exchange(tls_stream, payload_bytes).await, + } } -async fn run_and_write( - config: &BenchmarkConfig, - tls_connector: &TlsConnector, - server_name: &ServerName<'static>, - output: &mut W, -) -> miette::Result<()> { - let mut in_flight = FuturesUnordered::new(); - let mut issued = 0; +async fn run_raw_exchange(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + write_request(tls_stream, u64::from(payload_bytes)) + .await + .into_diagnostic() + .context("write request failed")?; - loop { - while issued < config.iters && in_flight.len() < config.concurrency as usize { - in_flight.push(run_single_iteration( - issued, - config.payload, - config.mode, - config.server, - tls_connector.clone(), - server_name.clone(), - )); - issued += 1; + read_payload(tls_stream, u64::from(payload_bytes)) + .await + .into_diagnostic() + .context("read payload failed")?; + Ok(()) +} + +async fn run_http1_exchange(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let request = build_http1_request(payload_bytes); + + tls_stream + .write_all(&request) + .await + .into_diagnostic() + .context("write http1 request failed")?; + + tls_stream + .flush() + .await + .into_diagnostic() + .context("flush http1 request failed")?; + + let mut response_buf = Vec::with_capacity(1024); + let mut chunk = [0; 1024]; + + let (content_length, body_start) = loop { + let n = tls_stream + .read(&mut chunk) + .await + .into_diagnostic() + .context("read http1 response failed")?; + + if n == 0 { + return Err(common::Error::protocol("unexpected EOF before http1 headers").into()); } - match in_flight.next().await { - Some(record) => writeln!(output, "{}", record?) + response_buf.extend_from_slice(&chunk[..n]); + + if let Some(pos) = find_headers_end(&response_buf) { + let headers = str::from_utf8(&response_buf[..pos]) .into_diagnostic() - .context("failed to write record")?, - None => break, + .context("http1 headers are not valid UTF-8")?; + let content_length = parse_content_length(headers)?; + break (content_length, pos + 4); } + }; + + let body_already_read = response_buf.len() - body_start; + if body_already_read > content_length { + return Err(common::Error::protocol("http1 body exceeded content-lenght").into()); + } + + let mut remaining = content_length - body_already_read; + let mut body_buf = vec![0; 64 * 1024]; + + while remaining > 0 { + let to_read = remaining.min(body_buf.len()); + tls_stream + .read_exact(&mut body_buf[..to_read]) + .await + .into_diagnostic() + .context("read http1 body failed")?; + remaining -= to_read; } Ok(()) } + +fn build_http1_request(payload_bytes: u32) -> Vec { + format!("GET /bytes/{payload_bytes} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n") + .into_bytes() +} + +fn parse_content_length(headers: &str) -> miette::Result { + let mut lines = headers.lines(); + + let status_line = lines + .next() + .ok_or_else(|| common::Error::protocol("missing http1 status line"))?; + + let mut parts = status_line.split_whitespace(); + let version = parts + .next() + .ok_or_else(|| common::Error::protocol("missing http1 version"))?; + let status = parts + .next() + .ok_or_else(|| common::Error::protocol("missing http1 status"))?; + + if version != "HTTP/1.1" { + return Err(common::Error::protocol(format!("unsupported http version: {version}")).into()); + } + if status != "200" { + return Err(common::Error::protocol(format!("unsupported http status: {status}")).into()); + } + + for line in lines { + if let Some((name, value)) = line.split_once(':') + && name.trim().eq_ignore_ascii_case("content-length") + { + return value + .trim() + .parse::() + .into_diagnostic() + .context("invalid content-length header"); + } + } + Err(common::Error::protocol("missing content-length header").into()) +} + +fn find_headers_end(buf: &[u8]) -> Option { + buf.windows(4).position(|window| window == b"\r\n\r\n") +} + +#[cfg(test)] +mod tests { + use super::*; + use claims::{assert_err, assert_none, assert_ok, assert_some}; + + #[test] + fn build_http1_request_formats_get_requests() { + let request = build_http1_request(16); + let request_string = String::from_utf8(request).expect("valid string"); + assert_eq!( + request_string, + "GET /bytes/16 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + ); + } + + #[test] + fn parse_content_length_accepts_200() { + let headers = "HTTP/1.1 200 OK\r\nContent-Length: 16\r\nConnection: close\r\n"; + let len = assert_ok!(parse_content_length(headers)); + assert_eq!(len, 16); + } + + #[test] + fn parse_content_length_rejects_missing_header() { + let headers = "HTTP/1.1 200 OK\r\nConnection: close\r\n"; + assert_err!(parse_content_length(headers)); + } + + #[test] + fn parse_content_length_accepts_mixed_case_header_name() { + let headers = "HTTP/1.1 200 OK\r\nContent-Length: 8\r\nConnection: close\r\n"; + let len = assert_ok!(parse_content_length(headers)); + assert_eq!(len, 8); + + let headers = "HTTP/1.1 200 OK\r\ncontent-length: 9\r\nConnection: close\r\n"; + let len = assert_ok!(parse_content_length(headers)); + assert_eq!(len, 9); + } + + #[test] + fn parse_content_length_rejects_non_200_status() { + let headers = "HTTP/1.1 404 Not Found\r\nContent-Length: 3\r\nConnection: close\r\n"; + assert_err!(parse_content_length(headers)); + } + + #[test] + fn parse_content_length_rejects_unsupported_http_version() { + let headers = "HTTP/1.0 200 OK\r\nContent-Length: 3\r\nConnection: close\r\n"; + assert_err!(parse_content_length(headers)); + } + + #[test] + fn parse_content_length_rejects_invalid_numeric_value() { + let headers = "HTTP/1.1 200 OK\r\nContent-Length: nope\r\nConnection: close\r\n"; + assert_err!(parse_content_length(headers)); + } + + #[test] + fn find_headers_end_returns_none_when_separator_missing() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n"; + assert_none!(find_headers_end(response)); + } + + #[test] + fn find_headers_end_returns_separator_offset() { + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nbody"; + let pos = assert_some!(find_headers_end(response)); + assert_eq!(pos, 34); + } +} diff --git a/runner/src/config/mod.rs b/runner/src/config/mod.rs index 1451fd5..76a1f2f 100644 --- a/runner/src/config/mod.rs +++ b/runner/src/config/mod.rs @@ -86,7 +86,7 @@ impl TryFrom for Config { #[cfg(test)] mod tests { use super::*; - use claims::{assert_err, assert_ok, assert_some}; + use claims::{assert_err, assert_ok}; const VALID_CONFIG: &str = r#" [[benchmarks]] @@ -134,8 +134,13 @@ server = "127.0.0.1:4433" fn valid_multiple_benchmarks() { let config = get_config_from_str(VALID_CONFIG); assert_eq!(config.benchmarks.len(), 2); - assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519); - assert_eq!(config.benchmarks[1].mode, KeyExchangeMode::X25519Mlkem768); + let bench_0 = config.benchmarks[0].clone(); + let bench_1 = config.benchmarks[1].clone(); + + assert_eq!(bench_0.mode, KeyExchangeMode::X25519); + assert_eq!(bench_0.proto, ProtocolMode::Raw); + assert_eq!(bench_1.mode, KeyExchangeMode::X25519Mlkem768); + assert_eq!(bench_1.proto, ProtocolMode::Http1); } #[test] @@ -222,38 +227,4 @@ server = "127.0.0.1:4433" let config = get_config_from_str(toml); assert!(config.benchmarks.is_empty()); } - - #[test] - fn server_mode_fallback() { - let toml = r#" -[[benchmarks]] -proto = "raw" -mode = "x25519" -payload = 1024 -iters = 100 -warmup = 10 -concurrency = 1 -server = "127.0.0.1:4433" -"#; - let config = get_config_from_str(toml); - let benchmark = assert_some!(config.benchmarks.first()); - assert_eq!(benchmark.mode, KeyExchangeMode::X25519); - } - - #[test] - fn server_mode_mlkem() { - let toml = r#" -[[benchmarks]] -proto = "raw" -mode = "x25519mlkem768" -payload = 1024 -iters = 100 -warmup = 10 -concurrency = 1 -server = "127.0.0.1:4433" -"#; - let config = get_config_from_str(toml); - let benchmark = assert_some!(config.benchmarks.first()); - assert_eq!(benchmark.mode, KeyExchangeMode::X25519Mlkem768); - } }