Compare commits

...

7 Commits

24 changed files with 1051 additions and 343 deletions

97
Cargo.lock generated
View File

@@ -115,6 +115,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "atomic-waker"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.5.0" version = "1.5.0"
@@ -282,6 +288,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"cargo-husky", "cargo-husky",
"claims", "claims",
"clap",
"miette", "miette",
"rcgen", "rcgen",
"rustls", "rustls",
@@ -496,6 +503,86 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "http"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
dependencies = [
"bytes",
"itoa",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
dependencies = [
"atomic-waker",
"bytes",
"futures-channel",
"futures-core",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"pin-utils",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
dependencies = [
"bytes",
"http",
"http-body",
"hyper",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.13.0" version = "2.13.0"
@@ -776,6 +863,12 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]] [[package]]
name = "powerfmt" name = "powerfmt"
version = "0.2.0" version = "0.2.0"
@@ -1012,9 +1105,13 @@ name = "server"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"base64", "base64",
"bytes",
"claims", "claims",
"clap", "clap",
"common", "common",
"http-body-util",
"hyper",
"hyper-util",
"miette", "miette",
"rustls", "rustls",
"thiserror", "thiserror",

View File

@@ -13,8 +13,12 @@ common = { path = "common" }
aws-lc-rs = "1" aws-lc-rs = "1"
base64 = "0.22" base64 = "0.22"
bytes = "1.11"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
futures = "0.3" futures = "0.3"
http-body-util = "0.1"
hyper = { version = "1.8", features = ["http1"] }
hyper-util = { version = "0.1", features = ["tokio"] }
miette = { version = "7", features = ["fancy"] } miette = { version = "7", features = ["fancy"] }
rcgen = "0.14" rcgen = "0.14"
rustls = { version = "0.23", default-features = false, features = [ rustls = { version = "0.23", default-features = false, features = [

View File

@@ -43,22 +43,23 @@ cargo build --release
Terminal 1 - Start server: Terminal 1 - Start server:
```bash ```bash
./target/release/server --mode x25519 --listen 127.0.0.1:4433 ./target/release/server --mode x25519 --proto raw --listen 127.0.0.1:4433
``` ```
Terminal 2 - Run benchmark: Terminal 2 - Run benchmark:
```bash ```bash
./target/release/runner --server 127.0.0.1:4433 --mode x25519 --iters 100 --warmup 10 ./target/release/runner --server 127.0.0.1:4433 --proto raw --mode x25519 --iters 100 --warmup 10
``` ```
### Run Matrix Benchmarks ### Run Matrix Benchmarks
Create a config file (`matrix.toml`): Create a config file (`benchmarks.toml`):
```toml ```toml
[[benchmarks]] [[benchmarks]]
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 1024 payload = 1024
iters = 100 iters = 100
@@ -67,6 +68,7 @@ concurrency = 1
[[benchmarks]] [[benchmarks]]
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
proto = "http1"
mode = "x25519mlkem768" mode = "x25519mlkem768"
payload = 1024 payload = 1024
iters = 100 iters = 100

View File

@@ -5,6 +5,7 @@ authors.workspace = true
edition.workspace = true edition.workspace = true
[dependencies] [dependencies]
clap.workspace = true
miette.workspace = true miette.workspace = true
rcgen.workspace = true rcgen.workspace = true
rustls.workspace = true rustls.workspace = true

View File

@@ -1,8 +1,3 @@
//! Self-signed certificate generation for local testing.
//!
//! Generates a CA certificate and server certificate for TLS benchmarking.
//! These certificates are NOT suitable for production use.
use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, Issuer, KeyPair, SanType}; use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, Issuer, KeyPair, SanType};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

View File

@@ -1,7 +1,7 @@
use miette::Diagnostic; use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
/// Result type using the common's custom error type. /// Result type using the `common`'s custom error type.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]

View File

@@ -1,30 +1,50 @@
//! Common types and utilities for the TLS benchmark harness
pub mod cert; pub mod cert;
pub mod error; pub mod error;
pub mod prelude;
pub mod protocol; pub mod protocol;
use clap::ValueEnum;
pub use error::Error; pub use error::Error;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use strum::{Display, EnumString}; use strum::Display;
/// TLS key exchange mode /// TLS 1.3 key exchange mode used for benchmark runs
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, EnumString, Display)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Display, ValueEnum)]
#[strum(serialize_all = "lowercase")] #[strum(serialize_all = "lowercase")]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum KeyExchangeMode { pub enum KeyExchangeMode {
/// Classical X25519 ECDH. /// Classical X25519 ECDH.
X25519, X25519,
#[value(name = "x25519mlkem768")]
/// Hybrid post-quantum: X25519 + ML-KEM-768. /// Hybrid post-quantum: X25519 + ML-KEM-768.
X25519Mlkem768, X25519Mlkem768,
} }
/// Application protocol carried over TLS in benchmark runs.
///
/// `Raw` is a minimal custom framing protocol (8-byte LE length request, then N payload bytes)
/// used for low-overhead microbenchmarks.
///
/// `Http1` is an HTTP/1.1 request/response mode (`GET /bytes/{n}`) used for realism-oriented
/// comparisons where HTTP parsing and headers are part of measured overhead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Display, ValueEnum)]
#[strum(serialize_all = "lowercase")]
#[serde(rename_all = "lowercase")]
pub enum ProtocolMode {
/// Minimal custom framing protocol for primary microbenchmarks.
Raw,
/// HTTP/1.1 mode for realism-oriented comparisons.
Http1,
}
/// A single benchmark measurement record, output as NDJSON /// A single benchmark measurement record, output as NDJSON
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchRecord { pub struct BenchRecord {
/// Iteration number (0-indexed, excludes warmup) /// Iteration number (0-indexed, excludes warmup)
pub iteration: u64, pub iteration: u64,
/// Protocol carrier mode
pub proto: ProtocolMode,
/// Key exchange mode used /// Key exchange mode used
pub mode: KeyExchangeMode, pub mode: KeyExchangeMode,
/// Payload size in bytes /// Payload size in bytes
@@ -61,12 +81,12 @@ mod tests {
use super::*; use super::*;
use claims::{assert_err, assert_ok}; use claims::{assert_err, assert_ok};
use serde_json::Value; use serde_json::Value;
use std::str::FromStr;
#[test] #[test]
fn bench_record_serializes_to_ndjson() { fn bench_record_serializes_to_ndjson() {
let record = BenchRecord { let record = BenchRecord {
iteration: 0, iteration: 0,
proto: ProtocolMode::Raw,
mode: KeyExchangeMode::X25519, mode: KeyExchangeMode::X25519,
payload_bytes: 1024, payload_bytes: 1024,
tcp_ns: 500_000, tcp_ns: 500_000,
@@ -75,6 +95,7 @@ mod tests {
}; };
let json = assert_ok!(record.to_ndjson()); let json = assert_ok!(record.to_ndjson());
assert!(json.contains(r#""iteration":0"#)); assert!(json.contains(r#""iteration":0"#));
assert!(json.contains(r#""proto":"raw""#));
assert!(json.contains(r#""mode":"x25519""#)); assert!(json.contains(r#""mode":"x25519""#));
assert!(json.contains(r#""payload_bytes":1024"#)); assert!(json.contains(r#""payload_bytes":1024"#));
} }
@@ -83,6 +104,7 @@ mod tests {
fn bench_record_roundtrip() { fn bench_record_roundtrip() {
let original = BenchRecord { let original = BenchRecord {
iteration: 42, iteration: 42,
proto: ProtocolMode::Http1,
mode: KeyExchangeMode::X25519Mlkem768, mode: KeyExchangeMode::X25519Mlkem768,
payload_bytes: 4096, payload_bytes: 4096,
tcp_ns: 1_000_000, tcp_ns: 1_000_000,
@@ -93,26 +115,18 @@ mod tests {
let deserialized = assert_ok!(serde_json::from_str::<BenchRecord>(&json)); let deserialized = assert_ok!(serde_json::from_str::<BenchRecord>(&json));
assert_eq!(original.iteration, deserialized.iteration); assert_eq!(original.iteration, deserialized.iteration);
assert_eq!(original.proto, deserialized.proto);
assert_eq!(original.mode, deserialized.mode); assert_eq!(original.mode, deserialized.mode);
assert_eq!(original.payload_bytes, deserialized.payload_bytes); assert_eq!(original.payload_bytes, deserialized.payload_bytes);
assert_eq!(original.handshake_ns, deserialized.handshake_ns); assert_eq!(original.handshake_ns, deserialized.handshake_ns);
assert_eq!(original.ttlb_ns, deserialized.ttlb_ns); assert_eq!(original.ttlb_ns, deserialized.ttlb_ns);
} }
#[test]
fn key_exchange_mode_from_str() {
let mode = assert_ok!(KeyExchangeMode::from_str("x25519"));
assert_eq!(mode, KeyExchangeMode::X25519);
let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768"));
assert_eq!(mode, KeyExchangeMode::X25519Mlkem768);
}
#[test] #[test]
fn key_exchange_mode_parse_error() { fn key_exchange_mode_parse_error() {
assert_err!(KeyExchangeMode::from_str("invalid")); assert_err!(KeyExchangeMode::from_str("invalid", true));
assert_err!(KeyExchangeMode::from_str("x25519invalid")); assert_err!(KeyExchangeMode::from_str("x25519invalid", true));
assert_err!(KeyExchangeMode::from_str("")); assert_err!(KeyExchangeMode::from_str("", true));
} }
#[test] #[test]
@@ -135,4 +149,38 @@ mod tests {
)); ));
assert_eq!(mode_mlkem_lower, KeyExchangeMode::X25519Mlkem768); assert_eq!(mode_mlkem_lower, KeyExchangeMode::X25519Mlkem768);
} }
#[test]
fn key_protocol_mod_from_str() {
let proto = assert_ok!(ProtocolMode::from_str("raw", true));
assert_eq!(proto, ProtocolMode::Raw);
let proto = assert_ok!(ProtocolMode::from_str("http1", true));
assert_eq!(proto, ProtocolMode::Http1);
}
#[test]
fn key_protocol_mode_parse_error() {
assert_err!(ProtocolMode::from_str("invalid", true));
assert_err!(ProtocolMode::from_str("", true));
}
#[test]
fn key_exchange_mode_from_str() {
let mode = assert_ok!(KeyExchangeMode::from_str("x25519", true));
assert_eq!(mode, KeyExchangeMode::X25519);
let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768", true));
assert_eq!(mode, KeyExchangeMode::X25519Mlkem768);
}
#[test]
fn key_protocol_mode_serde() {
let json = r#"{"proto":"http1"}"#;
let value = assert_ok!(serde_json::from_str::<Value>(json));
let proto = assert_ok!(serde_json::from_value::<ProtocolMode>(
value["proto"].clone()
));
assert_eq!(proto, ProtocolMode::Http1);
}
} }

7
common/src/prelude.rs Normal file
View File

@@ -0,0 +1,7 @@
pub use crate::{
BenchRecord, KeyExchangeMode, ProtocolMode,
protocol::{
MAX_PAYLOAD_SIZE, generate_payload, read_payload, read_request, write_payload,
write_request,
},
};

View File

@@ -1,11 +1,3 @@
//! Benchmark protocol implementation.
//!
//! Protocol specification:
//! 1. Client sends 8-byte little-endian u64: requested payload size N
//! 2. Server responds with exactly N bytes (deterministic pattern)
//!
//! The deterministic pattern is a repeating sequence of bytes 0x00..0xFF.
// Casts are intentional: MAX_PAYLOAD_SIZE (16 MiB) fits in usize on 64-bit, // Casts are intentional: MAX_PAYLOAD_SIZE (16 MiB) fits in usize on 64-bit,
// and byte patterns are explicitly masked to 0xFF before casting. // and byte patterns are explicitly masked to 0xFF before casting.
#![allow(clippy::cast_possible_truncation)] #![allow(clippy::cast_possible_truncation)]
@@ -50,6 +42,7 @@ pub async fn write_request<W: AsyncWriteExt + Unpin>(writer: &mut W, size: u64)
/// Generate deterministic payload of the given size. /// Generate deterministic payload of the given size.
/// ///
/// The pattern is a repeating sequence: 0x00, 0x01, ..., 0xFF, 0x00, ... /// The pattern is a repeating sequence: 0x00, 0x01, ..., 0xFF, 0x00, ...
#[inline]
#[must_use] #[must_use]
pub fn generate_payload(size: u64) -> Vec<u8> { pub fn generate_payload(size: u64) -> Vec<u8> {
(0..size).map(|i| (i & 0xFF) as u8).collect() (0..size).map(|i| (i & 0xFF) as u8).collect()

View File

@@ -61,8 +61,8 @@ setup:
# Run server (default: x25519 on localhost:4433) # Run server (default: x25519 on localhost:4433)
[group("run")] [group("run")]
server mode="x25519" listen="127.0.0.1:4433": server mode="x25519" proto="raw" listen="127.0.0.1:4433":
cargo run --release --bin server -- --mode {{mode}} --listen {{listen}} cargo run --release --bin server -- --mode {{mode}} --proto {{proto}} --listen {{listen}}
# Run benchmark runner # Run benchmark runner
[group("run")] [group("run")]

View File

@@ -1,40 +1,44 @@
use clap::Parser; use clap::Parser;
use common::KeyExchangeMode; use common::prelude::*;
use std::{net::SocketAddr, path::PathBuf}; use std::{net::SocketAddr, path::PathBuf};
/// TLS benchmark runner. /// TLS benchmark runner.
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
#[command(name = "runner", version, about)] #[command(name = "runner", version, about)]
pub struct Args { pub struct Args {
/// Key exchange mode. /// Protocol carrier mode
#[arg(long, default_value = "raw")]
pub proto: ProtocolMode,
/// Key exchange mode
#[arg(long, default_value = "x25519")] #[arg(long, default_value = "x25519")]
pub mode: KeyExchangeMode, pub mode: KeyExchangeMode,
/// Server address to connect to. /// Server address to connect to
#[arg(long, required_unless_present = "config")] #[arg(long, required_unless_present = "config")]
pub server: Option<SocketAddr>, pub server: Option<SocketAddr>,
/// Payload size in bytes to request from server. /// Payload size in bytes to request from server
#[arg(long, default_value = "1024")] #[arg(long, default_value = "1024")]
pub payload_bytes: u32, pub payload_bytes: u32,
/// Number of benchmark iterations (excluding warmup). /// Number of benchmark iterations (excluding warmup)
#[arg(long, default_value = "100")] #[arg(long, default_value = "100")]
pub iters: u32, pub iters: u32,
/// Number of warmup iterations (not recorded). /// Number of warmup iterations (not recorded)
#[arg(long, default_value = "10")] #[arg(long, default_value = "10")]
pub warmup: u32, pub warmup: u32,
/// Number of concurrent connections. /// Number of concurrent connections
#[arg(long, default_value = "1")] #[arg(long, default_value = "1")]
pub concurrency: u32, pub concurrency: u32,
/// Output file for NDJSON records (stdout if not specified). /// Output file for NDJSON records (stdout if not specified)
#[arg(long)] #[arg(long)]
pub out: Option<PathBuf>, pub out: Option<PathBuf>,
/// Config file for matrix benchmarks (TOML). /// Config file for matrix benchmarks (TOML)
#[arg(long)] #[arg(long, short)]
pub config: Option<PathBuf>, pub config: Option<PathBuf>,
} }

View File

@@ -1,7 +1,5 @@
use common::{ use crate::config::BenchmarkConfig;
BenchRecord, KeyExchangeMode, use common::prelude::*;
protocol::{read_payload, write_request},
};
use futures::{StreamExt, stream::FuturesUnordered}; use futures::{StreamExt, stream::FuturesUnordered};
use miette::{Context, IntoDiagnostic}; use miette::{Context, IntoDiagnostic};
use rustls::pki_types::ServerName; use rustls::pki_types::ServerName;
@@ -10,12 +8,13 @@ use std::{
net::SocketAddr, net::SocketAddr,
time::Instant, time::Instant,
}; };
use tokio::net::TcpStream; use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tracing::info; use tracing::info;
use crate::config::BenchmarkConfig;
/// Result of a single benchmark iteration. /// Result of a single benchmark iteration.
struct IterationResult { struct IterationResult {
tcp: u128, tcp: u128,
@@ -23,9 +22,103 @@ struct IterationResult {
ttlb: u128, ttlb: u128,
} }
pub async fn run_benchmark(
config: &BenchmarkConfig,
tls_connector: &TlsConnector,
server_name: &ServerName<'static>,
) -> miette::Result<()> {
let server = config.server;
info!(
warmup = config.warmup,
iters = config.iters,
concurrency = config.concurrency,
"running benchmark iterations"
);
for _ in 0..config.warmup {
run_iteration(
server,
config.proto,
config.payload,
tls_connector,
server_name,
)
.await?;
}
info!("warmup complete");
let mut output = stdout();
run_and_write(config, tls_connector, server_name, &mut output).await?;
output
.flush()
.into_diagnostic()
.context("failed to flush output")?;
info!("benchmark complete");
Ok(())
}
async fn run_and_write<W: Write + Send>(
config: &BenchmarkConfig,
tls_connector: &TlsConnector,
server_name: &ServerName<'static>,
output: &mut W,
) -> miette::Result<()> {
let mut in_flight = FuturesUnordered::new();
let mut issued = 0;
loop {
while issued < config.iters && in_flight.len() < config.concurrency as usize {
in_flight.push(run_single_iteration(
issued,
config.payload,
config.proto,
config.mode,
config.server,
tls_connector.clone(),
server_name.clone(),
));
issued += 1;
}
match in_flight.next().await {
Some(record) => writeln!(output, "{}", record?)
.into_diagnostic()
.context("failed to write record")?,
None => break,
}
}
Ok(())
}
async fn run_single_iteration(
i: u32,
payload_bytes: u32,
proto: ProtocolMode,
mode: KeyExchangeMode,
server: SocketAddr,
tls_connector: TlsConnector,
server_name: ServerName<'static>,
) -> miette::Result<BenchRecord> {
let result = run_iteration(server, proto, payload_bytes, &tls_connector, &server_name).await?;
Ok(BenchRecord {
iteration: u64::from(i),
proto,
mode,
payload_bytes: u64::from(payload_bytes),
tcp_ns: result.tcp,
handshake_ns: result.handshake,
ttlb_ns: result.ttlb,
})
}
/// Run a single benchmark iteration over TLS. /// Run a single benchmark iteration over TLS.
async fn run_iteration( async fn run_iteration(
server: SocketAddr, server: SocketAddr,
proto: ProtocolMode,
payload_bytes: u32, payload_bytes: u32,
tls_connector: &TlsConnector, tls_connector: &TlsConnector,
server_name: &ServerName<'static>, server_name: &ServerName<'static>,
@@ -49,17 +142,9 @@ async fn run_iteration(
let handshake_ns = hs_start.elapsed().as_nanos(); let handshake_ns = hs_start.elapsed().as_nanos();
let ttlb_start = Instant::now(); let ttlb_start = Instant::now();
write_request(&mut tls_stream, u64::from(payload_bytes))
.await
.into_diagnostic()
.context("write request failed")?;
read_payload(&mut tls_stream, u64::from(payload_bytes))
.await
.into_diagnostic()
.context("read payload failed")?;
let ttlb_ns = tcp_ns + handshake_ns + ttlb_start.elapsed().as_nanos(); let ttlb_ns = tcp_ns + handshake_ns + ttlb_start.elapsed().as_nanos();
run_exchange(&mut tls_stream, proto, payload_bytes).await?;
Ok(IterationResult { Ok(IterationResult {
tcp: tcp_ns, tcp: tcp_ns,
@@ -68,86 +153,212 @@ async fn run_iteration(
}) })
} }
pub async fn run_benchmark( async fn run_exchange<S>(
config: &BenchmarkConfig, tls_stream: &mut S,
tls_connector: &TlsConnector, proto: ProtocolMode,
server_name: &ServerName<'static>,
) -> miette::Result<()> {
let server = config.server;
info!(
warmup = config.warmup,
iters = config.iters,
concurrency = config.concurrency,
"running benchmark iterations"
);
for _ in 0..config.warmup {
run_iteration(server, config.payload, tls_connector, server_name).await?;
}
info!("warmup complete");
#[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values
let mut output = stdout();
run_and_write(config, tls_connector, server_name, &mut output).await?;
output
.flush()
.into_diagnostic()
.context("failed to flush output")?;
info!("benchmark complete");
Ok(())
}
async fn run_single_iteration(
i: u32,
payload_bytes: u32, payload_bytes: u32,
mode: KeyExchangeMode, ) -> miette::Result<()>
server: SocketAddr, where
tls_connector: TlsConnector, S: AsyncRead + AsyncWrite + Unpin,
server_name: ServerName<'static>, {
) -> miette::Result<BenchRecord> { match proto {
let result = run_iteration(server, payload_bytes, &tls_connector, &server_name).await?; ProtocolMode::Raw => run_raw_exchange(tls_stream, payload_bytes).await,
ProtocolMode::Http1 => run_http1_exchange(tls_stream, payload_bytes).await,
Ok(BenchRecord { }
iteration: u64::from(i),
mode,
payload_bytes: u64::from(payload_bytes),
tcp_ns: result.tcp,
handshake_ns: result.handshake,
ttlb_ns: result.ttlb,
})
} }
async fn run_and_write<W: Write + Send>( async fn run_raw_exchange<S>(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()>
config: &BenchmarkConfig, where
tls_connector: &TlsConnector, S: AsyncRead + AsyncWrite + Unpin,
server_name: &ServerName<'static>, {
output: &mut W, write_request(tls_stream, u64::from(payload_bytes))
) -> miette::Result<()> { .await
let mut in_flight = FuturesUnordered::new();
let mut issued = 0;
loop {
while issued < config.iters && in_flight.len() < config.concurrency as usize {
in_flight.push(run_single_iteration(
issued,
config.payload,
config.mode,
config.server,
tls_connector.clone(),
server_name.clone(),
));
issued += 1;
}
match in_flight.next().await {
Some(record) => writeln!(output, "{}", record?)
.into_diagnostic() .into_diagnostic()
.context("failed to write record")?, .context("write request failed")?;
None => break,
read_payload(tls_stream, u64::from(payload_bytes))
.await
.into_diagnostic()
.context("read payload failed")?;
Ok(())
}
async fn run_http1_exchange<S>(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let request = build_http1_request(payload_bytes);
tls_stream
.write_all(&request)
.await
.into_diagnostic()
.context("write http1 request failed")?;
tls_stream
.flush()
.await
.into_diagnostic()
.context("flush http1 request failed")?;
let mut response_buf = Vec::with_capacity(1024);
let mut chunk = [0; 1024];
let (content_length, body_start) = loop {
let n = tls_stream
.read(&mut chunk)
.await
.into_diagnostic()
.context("read http1 response failed")?;
if n == 0 {
return Err(common::Error::protocol("unexpected EOF before http1 headers").into());
} }
response_buf.extend_from_slice(&chunk[..n]);
if let Some(pos) = find_headers_end(&response_buf) {
let headers = str::from_utf8(&response_buf[..pos])
.into_diagnostic()
.context("http1 headers are not valid UTF-8")?;
let content_length = parse_content_length(headers)?;
break (content_length, pos + 4);
}
};
let body_already_read = response_buf.len() - body_start;
if body_already_read > content_length {
return Err(common::Error::protocol("http1 body exceeded content-lenght").into());
}
let mut remaining = content_length - body_already_read;
let mut body_buf = vec![0; 64 * 1024];
while remaining > 0 {
let to_read = remaining.min(body_buf.len());
tls_stream
.read_exact(&mut body_buf[..to_read])
.await
.into_diagnostic()
.context("read http1 body failed")?;
remaining -= to_read;
} }
Ok(()) Ok(())
} }
fn build_http1_request(payload_bytes: u32) -> Vec<u8> {
format!("GET /bytes/{payload_bytes} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
.into_bytes()
}
fn parse_content_length(headers: &str) -> miette::Result<usize> {
let mut lines = headers.lines();
let status_line = lines
.next()
.ok_or_else(|| common::Error::protocol("missing http1 status line"))?;
let mut parts = status_line.split_whitespace();
let version = parts
.next()
.ok_or_else(|| common::Error::protocol("missing http1 version"))?;
let status = parts
.next()
.ok_or_else(|| common::Error::protocol("missing http1 status"))?;
if version != "HTTP/1.1" {
return Err(common::Error::protocol(format!("unsupported http version: {version}")).into());
}
if status != "200" {
return Err(common::Error::protocol(format!("unsupported http status: {status}")).into());
}
for line in lines {
if let Some((name, value)) = line.split_once(':')
&& name.trim().eq_ignore_ascii_case("content-length")
{
return value
.trim()
.parse::<usize>()
.into_diagnostic()
.context("invalid content-length header");
}
}
Err(common::Error::protocol("missing content-length header").into())
}
fn find_headers_end(buf: &[u8]) -> Option<usize> {
buf.windows(4).position(|window| window == b"\r\n\r\n")
}
#[cfg(test)]
mod tests {
use super::*;
use claims::{assert_err, assert_none, assert_ok, assert_some};
#[test]
fn build_http1_request_formats_get_requests() {
let request = build_http1_request(16);
let request_string = String::from_utf8(request).expect("valid string");
assert_eq!(
request_string,
"GET /bytes/16 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
);
}
#[test]
fn parse_content_length_accepts_200() {
let headers = "HTTP/1.1 200 OK\r\nContent-Length: 16\r\nConnection: close\r\n";
let len = assert_ok!(parse_content_length(headers));
assert_eq!(len, 16);
}
#[test]
fn parse_content_length_rejects_missing_header() {
let headers = "HTTP/1.1 200 OK\r\nConnection: close\r\n";
assert_err!(parse_content_length(headers));
}
#[test]
fn parse_content_length_accepts_mixed_case_header_name() {
let headers = "HTTP/1.1 200 OK\r\nContent-Length: 8\r\nConnection: close\r\n";
let len = assert_ok!(parse_content_length(headers));
assert_eq!(len, 8);
let headers = "HTTP/1.1 200 OK\r\ncontent-length: 9\r\nConnection: close\r\n";
let len = assert_ok!(parse_content_length(headers));
assert_eq!(len, 9);
}
#[test]
fn parse_content_length_rejects_non_200_status() {
let headers = "HTTP/1.1 404 Not Found\r\nContent-Length: 3\r\nConnection: close\r\n";
assert_err!(parse_content_length(headers));
}
#[test]
fn parse_content_length_rejects_unsupported_http_version() {
let headers = "HTTP/1.0 200 OK\r\nContent-Length: 3\r\nConnection: close\r\n";
assert_err!(parse_content_length(headers));
}
#[test]
fn parse_content_length_rejects_invalid_numeric_value() {
let headers = "HTTP/1.1 200 OK\r\nContent-Length: nope\r\nConnection: close\r\n";
assert_err!(parse_content_length(headers));
}
#[test]
fn find_headers_end_returns_none_when_separator_missing() {
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n";
assert_none!(find_headers_end(response));
}
#[test]
fn find_headers_end_returns_separator_offset() {
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nbody";
let pos = assert_some!(find_headers_end(response));
assert_eq!(pos, 34);
}
}

View File

@@ -5,13 +5,14 @@ use crate::{
config::utils::validate_config, config::utils::validate_config,
error::{self, ConfigError}, error::{self, ConfigError},
}; };
use common::{self, KeyExchangeMode}; use common::prelude::*;
use miette::{NamedSource, SourceSpan}; use miette::{NamedSource, SourceSpan};
use serde::Deserialize; use serde::Deserialize;
use std::{fs::read_to_string, net::SocketAddr, path::Path}; use std::{fs::read_to_string, net::SocketAddr, path::Path};
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct BenchmarkConfig { pub struct BenchmarkConfig {
pub proto: ProtocolMode,
pub mode: KeyExchangeMode, pub mode: KeyExchangeMode,
pub payload: u32, pub payload: u32,
pub iters: u32, pub iters: u32,
@@ -29,7 +30,10 @@ pub struct Config {
/// ///
/// # Errors /// # Errors
/// Returns an error if the file cannot be read or parsed. /// Returns an error if the file cannot be read or parsed.
pub fn load_from_file(path: &Path) -> error::Result<Config> { impl TryFrom<&Path> for Config {
type Error = error::Error;
fn try_from(path: &Path) -> Result<Self, Self::Error> {
let content = read_to_string(path).map_err(|source| ConfigError::ReadError { let content = read_to_string(path).map_err(|source| ConfigError::ReadError {
source, source,
path: path.to_owned(), path: path.to_owned(),
@@ -37,7 +41,7 @@ pub fn load_from_file(path: &Path) -> error::Result<Config> {
let src = NamedSource::new(path.display().to_string(), content.clone()); let src = NamedSource::new(path.display().to_string(), content.clone());
let config = toml::from_str::<Config>(&content).map_err(|source| { let config = toml::from_str::<Self>(&content).map_err(|source| {
let span = source let span = source
.span() .span()
.map(|s| SourceSpan::new(s.start.into(), s.end - s.start)); .map(|s| SourceSpan::new(s.start.into(), s.end - s.start));
@@ -52,15 +56,20 @@ pub fn load_from_file(path: &Path) -> error::Result<Config> {
validate_config(&config, &content, path)?; validate_config(&config, &content, path)?;
Ok(config) Ok(config)
}
} }
/// Create benchmark configuration from CLI arguments. /// Create benchmark configuration from CLI arguments.
/// ///
/// # Errors /// # Errors
/// Never returns an error, but returns Result for consistency. /// Returns an error if `--server` was not provided.
pub fn load_from_cli(args: &Args) -> error::Result<Config> { impl TryFrom<Args> for Config {
Ok(Config { type Error = error::Error;
fn try_from(args: Args) -> Result<Self, Self::Error> {
Ok(Self {
benchmarks: vec![BenchmarkConfig { benchmarks: vec![BenchmarkConfig {
proto: args.proto,
mode: args.mode, mode: args.mode,
payload: args.payload_bytes, payload: args.payload_bytes,
iters: args.iters, iters: args.iters,
@@ -68,18 +77,20 @@ pub fn load_from_cli(args: &Args) -> error::Result<Config> {
concurrency: args.concurrency, concurrency: args.concurrency,
server: args server: args
.server .server
.ok_or_else(|| common::Error::config("--server ir required"))?, .ok_or_else(|| common::Error::config("--server is required"))?,
}], }],
}) })
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use claims::{assert_err, assert_ok, assert_some}; use claims::{assert_err, assert_ok};
const VALID_CONFIG: &str = r#" const VALID_CONFIG: &str = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 1024 payload = 1024
iters = 100 iters = 100
@@ -88,6 +99,7 @@ concurrency = 1
server = "127.0.0.1:4433" server = "127.0.0.1:4433"
[[benchmarks]] [[benchmarks]]
proto = "http1"
mode = "x25519mlkem768" mode = "x25519mlkem768"
payload = 4096 payload = 4096
iters = 50 iters = 50
@@ -104,6 +116,7 @@ server = "127.0.0.1:4433"
fn valid_single_benchmark() { fn valid_single_benchmark() {
let toml = r#" let toml = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 1024 payload = 1024
iters = 100 iters = 100
@@ -121,14 +134,35 @@ server = "127.0.0.1:4433"
fn valid_multiple_benchmarks() { fn valid_multiple_benchmarks() {
let config = get_config_from_str(VALID_CONFIG); let config = get_config_from_str(VALID_CONFIG);
assert_eq!(config.benchmarks.len(), 2); assert_eq!(config.benchmarks.len(), 2);
assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519); let bench_0 = config.benchmarks[0].clone();
assert_eq!(config.benchmarks[1].mode, KeyExchangeMode::X25519Mlkem768); let bench_1 = config.benchmarks[1].clone();
assert_eq!(bench_0.mode, KeyExchangeMode::X25519);
assert_eq!(bench_0.proto, ProtocolMode::Raw);
assert_eq!(bench_1.mode, KeyExchangeMode::X25519Mlkem768);
assert_eq!(bench_1.proto, ProtocolMode::Http1);
}
#[test]
fn invalid_proto() {
let toml = r#"
[[benchmarks]]
proto = "invalid_proto"
mode = "x25519"
payload = 1024
iters = 100
warmup = 10
concurrency = 1
server = "127.0.0.1:4433"
"#;
assert_err!(toml::from_str::<Config>(toml));
} }
#[test] #[test]
fn invalid_mode() { fn invalid_mode() {
let toml = r#" let toml = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "invalid_mode" mode = "invalid_mode"
payload = 1024 payload = 1024
iters = 100 iters = 100
@@ -143,6 +177,7 @@ server = "127.0.0.1:4433"
fn payload_zero_validation() { fn payload_zero_validation() {
let toml = r#" let toml = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 0 payload = 0
iters = 100 iters = 100
@@ -158,6 +193,7 @@ server = "127.0.0.1:4433"
fn iters_zero_validation() { fn iters_zero_validation() {
let toml = r#" let toml = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 1024 payload = 1024
iters = 0 iters = 0
@@ -173,6 +209,7 @@ server = "127.0.0.1:4433"
fn concurrency_zero_validation() { fn concurrency_zero_validation() {
let toml = r#" let toml = r#"
[[benchmarks]] [[benchmarks]]
proto = "raw"
mode = "x25519" mode = "x25519"
payload = 1024 payload = 1024
iters = 100 iters = 100
@@ -190,36 +227,4 @@ server = "127.0.0.1:4433"
let config = get_config_from_str(toml); let config = get_config_from_str(toml);
assert!(config.benchmarks.is_empty()); assert!(config.benchmarks.is_empty());
} }
#[test]
fn server_mode_fallback() {
let toml = r#"
[[benchmarks]]
mode = "x25519"
payload = 1024
iters = 100
warmup = 10
concurrency = 1
server = "127.0.0.1:4433"
"#;
let config = get_config_from_str(toml);
let benchmark = assert_some!(config.benchmarks.first());
assert_eq!(benchmark.mode, KeyExchangeMode::X25519);
}
#[test]
fn server_mode_mlkem() {
let toml = r#"
[[benchmarks]]
mode = "x25519mlkem768"
payload = 1024
iters = 100
warmup = 10
concurrency = 1
server = "127.0.0.1:4433"
"#;
let config = get_config_from_str(toml);
let benchmark = assert_some!(config.benchmarks.first());
assert_eq!(benchmark.mode, KeyExchangeMode::X25519Mlkem768);
}
} }

View File

@@ -5,7 +5,7 @@ use miette::{Diagnostic, NamedSource, SourceSpan};
use std::path::PathBuf; use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
/// Result type using the runner's custom error type. /// Result type using the `runner`'s custom error type.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
/// Errors that can occur during benchmark execution. /// Errors that can occur during benchmark execution.

View File

@@ -12,6 +12,7 @@ mod config;
mod error; mod error;
mod tls; mod tls;
use crate::{args::Args, bench::run_benchmark, config::Config, tls::build_tls_config};
use clap::Parser; use clap::Parser;
use miette::{Context, IntoDiagnostic}; use miette::{Context, IntoDiagnostic};
use rustls::pki_types::ServerName; use rustls::pki_types::ServerName;
@@ -21,13 +22,6 @@ use tracing::info;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use uuid::Uuid; use uuid::Uuid;
use crate::{
args::Args,
bench::run_benchmark,
config::{load_from_cli, load_from_file},
tls::build_tls_config,
};
#[tokio::main] #[tokio::main]
async fn main() -> miette::Result<()> { async fn main() -> miette::Result<()> {
let run_id = Uuid::new_v4(); let run_id = Uuid::new_v4();
@@ -48,12 +42,12 @@ async fn main() -> miette::Result<()> {
let args = Args::parse(); let args = Args::parse();
let config = if let Some(config_path) = &args.config { let config: Config = if let Some(config_path) = &args.config {
info!(config_file = %config_path.display(), "loading config from file"); info!(config_file = %config_path.display(), "loading config from file");
load_from_file(config_path)? config_path.as_path().try_into()?
} else { } else {
info!("using CLI arguments"); info!("using CLI arguments");
load_from_cli(&args)? args.try_into()?
}; };
let server_name = ServerName::try_from("localhost".to_string()) let server_name = ServerName::try_from("localhost".to_string())
@@ -62,6 +56,7 @@ async fn main() -> miette::Result<()> {
for benchmark in &config.benchmarks { for benchmark in &config.benchmarks {
info!( info!(
proto = %benchmark.proto,
mode = %benchmark.mode, mode = %benchmark.mode,
payload = benchmark.payload, payload = benchmark.payload,
iters = benchmark.iters, iters = benchmark.iters,

View File

@@ -1,4 +1,4 @@
use common::KeyExchangeMode; use common::prelude::*;
use miette::{Context, IntoDiagnostic}; use miette::{Context, IntoDiagnostic};
use rustls::{ use rustls::{
ClientConfig, DigitallySignedStruct, SignatureScheme, ClientConfig, DigitallySignedStruct, SignatureScheme,

View File

@@ -6,8 +6,12 @@ edition.workspace = true
[dependencies] [dependencies]
base64.workspace = true base64.workspace = true
bytes.workspace = true
clap.workspace = true clap.workspace = true
common.workspace = true common.workspace = true
http-body-util.workspace = true
hyper-util = { workspace = true, features = ["server"] }
hyper = { workspace = true, features = ["server"] }
miette.workspace = true miette.workspace = true
rustls.workspace = true rustls.workspace = true
thiserror.workspace = true thiserror.workspace = true

View File

@@ -1,7 +1,7 @@
use miette::Diagnostic; use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
/// Result type using the servers's custom error type. /// Result type using the `servers`'s custom error type.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Error, Diagnostic)] #[derive(Debug, Error, Diagnostic)]

View File

@@ -1 +0,0 @@
pub mod error;

View File

@@ -5,146 +5,34 @@
//! - Responds with exactly N bytes (deterministic pattern) //! - Responds with exactly N bytes (deterministic pattern)
mod error; mod error;
mod server;
mod tls;
use crate::{server::run_server, tls::build_tls_config};
use base64::prelude::*; use base64::prelude::*;
use clap::Parser; use clap::Parser;
use common::{ use common::{cert::CaCertificate, prelude::*};
KeyExchangeMode, use std::{env, net::SocketAddr};
cert::{CaCertificate, ServerCertificate}, use tracing::{error, info};
protocol::{read_request, write_payload},
};
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
self,
kx_group::{X25519, X25519MLKEM768},
},
pki_types::{CertificateDer, PrivateKeyDer},
server::Acceptor,
version::TLS13,
};
use std::{env, io::ErrorKind, net::SocketAddr, sync::Arc};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
};
use tokio_rustls::LazyConfigAcceptor;
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
/// TLS benchmark server. /// TLS benchmark server.
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
#[command(name = "server", version, about)] #[command(name = "server", version, about)]
struct Args { struct Args {
/// Key exchange mode. /// Key exchange mode
#[arg(long, default_value = "x25519")] #[arg(long, default_value = "x25519")]
mode: KeyExchangeMode, mode: KeyExchangeMode,
/// Address to listen on. /// Protocol carrier mode
#[arg(long, default_value = "raw")]
proto: ProtocolMode,
/// Address to listen on
#[arg(long, default_value = "127.0.0.1:4433")] #[arg(long, default_value = "127.0.0.1:4433")]
listen: SocketAddr, listen: SocketAddr,
} }
/// Build TLS server config for the given key exchange mode.
fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
};
let certs = server_cert
.cert_chain_der
.iter()
.map(|der| CertificateDer::from(der.clone()))
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}
async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<ServerConfig>) {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
let start_handshake = match acceptor.await {
Ok(sh) => sh,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS accept error");
}
};
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
Ok(s) => s,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS handshake error");
}
};
let (_, conn) = tls_stream.get_ref();
info!(
cipher = ?conn.negotiated_cipher_suite(),
version = ?conn.protocol_version(),
"connection established"
);
loop {
let payload_size = match read_request(&mut tls_stream).await {
Ok(size) => size,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
debug!(peer = %peer, "client disconnected");
break;
}
Err(e) => {
warn!(peer = %peer, error = %e, "connection error");
break;
}
};
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
warn!(peer = %peer, error = %e, "write error");
break;
}
// Flush to ensure data is sent
if let Err(e) = tls_stream.flush().await {
warn!(peer = %peer, error = %e, "flush error");
break;
}
}
}
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(listen = %args.listen, mode = %args.mode, "listening");
loop {
let (stream, peer) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!(error = %e, "accept error");
continue;
}
};
let config = tls_config.clone();
tokio::spawn(handle_connection(stream, peer, config));
}
}
#[tokio::main] #[tokio::main]
async fn main() -> miette::Result<()> { async fn main() -> miette::Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
@@ -161,6 +49,7 @@ async fn main() -> miette::Result<()> {
command = env::args().collect::<Vec<_>>().join(" "), command = env::args().collect::<Vec<_>>().join(" "),
listen = %args.listen, listen = %args.listen,
mode = %args.mode, mode = %args.mode,
proto = %args.proto,
"server started" "server started"
); );
@@ -181,19 +70,20 @@ async fn main() -> miette::Result<()> {
"CA cert (truncated)" "CA cert (truncated)"
); );
Ok(run_server(args, tls_config).await?) Ok(run_server(&args, tls_config).await?)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use claims::assert_ok; use claims::assert_ok;
use common::cert::CaCertificate; use std::sync::Arc;
#[test] #[test]
fn default_args() { fn default_args() {
let args = Args::parse_from(["server"]); let args = Args::parse_from(["server"]);
assert_eq!(args.mode, KeyExchangeMode::X25519); assert_eq!(args.mode, KeyExchangeMode::X25519);
assert_eq!(args.proto, ProtocolMode::Raw);
assert_eq!(args.listen.to_string(), "127.0.0.1:4433"); assert_eq!(args.listen.to_string(), "127.0.0.1:4433");
} }

307
server/src/server/http1.rs Normal file
View File

@@ -0,0 +1,307 @@
use bytes::Bytes;
use common::prelude::*;
use http_body_util::Full;
use hyper::{
Method, Request, Response, StatusCode,
header::{ALLOW, CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue},
server::conn::http1::Builder,
service::service_fn,
};
use hyper_util::rt::TokioIo;
use rustls::{ServerConfig, server::Acceptor};
use std::{convert::Infallible, net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
use tokio_rustls::LazyConfigAcceptor;
use tracing::{info, warn};
type RespBody = Full<Bytes>;
pub async fn handle_http1_connection(
stream: TcpStream,
peer: SocketAddr,
tls_config: Arc<ServerConfig>,
) {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
let start_handshake = match acceptor.await {
Ok(sh) => sh,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS accept error");
}
};
let tls_stream = match start_handshake.into_stream(tls_config).await {
Ok(s) => s,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS handshake error");
}
};
let (_, conn) = tls_stream.get_ref();
info!(
cipher = ?conn.negotiated_cipher_suite(),
version = ?conn.protocol_version(),
"connection established"
);
let service = service_fn(move |req| async move { Ok::<_, Infallible>(handle_request(&req)) });
let io = TokioIo::new(tls_stream);
if let Err(e) = Builder::new()
.keep_alive(false)
.serve_connection(io, service)
.await
{
warn!(peer = %peer, error = %e, "http1 serve error");
}
}
fn handle_request<B>(req: &Request<B>) -> Response<RespBody> {
if req.method() != Method::GET {
let mut response = text_response(StatusCode::METHOD_NOT_ALLOWED, "method not allowed");
response
.headers_mut()
.insert(ALLOW, HeaderValue::from_static("GET"));
return response;
}
let n = match parse_bytes_path(req.uri().path()) {
Ok(n) => n,
Err(status) => {
let msg = match status {
StatusCode::NOT_FOUND => "not found",
StatusCode::PAYLOAD_TOO_LARGE => "payload too large",
_ => "bad request",
};
return text_response(status, msg);
}
};
let payload = generate_payload(n);
let mut response = Response::new(Full::new(Bytes::from(payload)));
*response.status_mut() = StatusCode::OK;
let headers = response.headers_mut();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
headers.insert(CONNECTION, HeaderValue::from_static("close"));
#[allow(clippy::option_if_let_else)]
match HeaderValue::from_str(&n.to_string()) {
Ok(v) => {
headers.insert(CONTENT_LENGTH, v);
response
}
Err(_) => text_response(StatusCode::INTERNAL_SERVER_ERROR, "internal server error"),
}
}
fn parse_bytes_path(path: &str) -> Result<u64, StatusCode> {
let Some(rest) = path.strip_prefix("/bytes/") else {
return Err(StatusCode::NOT_FOUND);
};
if rest.is_empty() || rest.contains('/') {
return Err(StatusCode::BAD_REQUEST);
}
let n = rest.parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)?;
if n > MAX_PAYLOAD_SIZE {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
Ok(n)
}
fn text_response(status: StatusCode, msg: &'static str) -> Response<RespBody> {
let mut response = Response::new(Full::new(Bytes::from_static(msg.as_bytes())));
*response.status_mut() = status;
response.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_static("text/plain; charset=utf-8"),
);
response
.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("close"));
response
}
#[cfg(test)]
mod tests {
use super::*;
use claims::{assert_err, assert_none, assert_ok, assert_some};
use http_body_util::BodyExt;
fn make_get_request(uri: &str) -> Request<()> {
assert_ok!(Request::builder().method(Method::GET).uri(uri).body(()))
}
#[test]
fn parse_bytes_path_accepts_valid_numeric_size() {
let min_n = assert_ok!(parse_bytes_path("/bytes/0"));
assert_eq!(min_n, 0);
let n = assert_ok!(parse_bytes_path("/bytes/1024"));
assert_eq!(n, 1024);
let max_n = assert_ok!(parse_bytes_path(&format!("/bytes/{MAX_PAYLOAD_SIZE}")));
assert_eq!(max_n, MAX_PAYLOAD_SIZE);
}
#[test]
fn parse_bytes_path_rejects_non_bytes_prefix() {
let status = assert_err!(parse_bytes_path("/foo/1024"));
assert_eq!(status, StatusCode::NOT_FOUND);
}
#[test]
fn parse_bytes_path_rejects_empty_size_segment() {
let status = assert_err!(parse_bytes_path("/bytes/"));
assert_eq!(status, StatusCode::BAD_REQUEST);
}
#[test]
fn parse_bytes_path_rejects_non_numeric_size() {
let status = assert_err!(parse_bytes_path("/bytes/foo"));
assert_eq!(status, StatusCode::BAD_REQUEST);
}
#[test]
fn parse_bytes_path_rejects_nested_segments() {
let status = assert_err!(parse_bytes_path("/bytes/16/extra"));
assert_eq!(status, StatusCode::BAD_REQUEST);
}
#[test]
fn parse_bytes_path_rejects_payload_above_max() {
let status = assert_err!(parse_bytes_path(&format!(
"/bytes/{}",
MAX_PAYLOAD_SIZE + 1
)));
assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
}
#[test]
fn handle_request_get_bytes_returns_200_with_expected_headers() {
let req = make_get_request("/bytes/16");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::OK);
let content_type = assert_some!(resp.headers().get("content-type"));
assert_eq!(content_type, "application/octet-stream");
let content_length = assert_some!(resp.headers().get("content-length"));
assert_eq!(content_length, "16");
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
assert_none!(resp.headers().get("allow"));
}
#[tokio::test]
async fn handle_request_get_bytes_returns_body_with_requested_length() {
let req = make_get_request("/bytes/32");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::OK);
let content_length = assert_some!(resp.headers().get("content-length"));
assert_eq!(content_length, "32");
let body = assert_ok!(resp.into_body().collect().await).to_bytes();
assert_eq!(body.len(), 32);
assert_eq!(body[0], 0x00);
assert_eq!(body[1], 0x01);
assert_eq!(body[31], 0x1F);
}
#[test]
fn handle_request_get_unknown_path_returns_404() {
let req = make_get_request("/unknown");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let content_type = assert_some!(resp.headers().get("content-type"));
assert_eq!(content_type, "text/plain; charset=utf-8");
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
#[test]
fn handle_request_post_bytes_returns_405_and_allow_get() {
let req = Request::builder()
.method(Method::POST)
.uri("/bytes/32")
.body(())
.expect("post request");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
let allow = assert_some!(resp.headers().get("allow"));
assert_eq!(allow, "GET");
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
#[test]
fn handle_request_get_bytes_without_size_segment_returns_404() {
let req = make_get_request("/bytes");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
assert_none!(resp.headers().get("content-length"));
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
#[test]
fn handle_request_get_bytes_with_non_numeric_size_returns_400() {
let req = make_get_request("/bytes/foo");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_none!(resp.headers().get("content-length"));
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
#[test]
fn handle_request_get_bytes_with_nested_path_returns_400() {
let req = make_get_request("/bytes/16/extra");
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_none!(resp.headers().get("content-length"));
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
#[test]
fn handle_request_get_bytes_exceeding_max_payload_returns_413() {
let req = make_get_request(&format!("/bytes/{}", MAX_PAYLOAD_SIZE + 1));
let resp = handle_request(&req);
assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
assert_none!(resp.headers().get("content-length"));
let connection = assert_some!(resp.headers().get("connection"));
assert_eq!(connection, "close");
}
}

45
server/src/server/mod.rs Normal file
View File

@@ -0,0 +1,45 @@
mod http1;
mod raw;
use crate::{
Args,
error::{Error as ServerError, Result as ServerResult},
server::{http1::handle_http1_connection, raw::handle_raw_connection},
};
use common::prelude::*;
use rustls::ServerConfig;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{error, info};
pub async fn run_server(args: &Args, tls_config: Arc<ServerConfig>) -> ServerResult<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| ServerError::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(
listen = %args.listen,
mode = %args.mode,
proto = %args.proto,
"listening"
);
loop {
let (stream, peer) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!(error = %e, "accept error");
continue;
}
};
let config = tls_config.clone();
let proto = args.proto;
tokio::spawn(async move {
match proto {
ProtocolMode::Raw => handle_raw_connection(stream, peer, config).await,
ProtocolMode::Http1 => handle_http1_connection(stream, peer, config).await,
}
});
}
}

59
server/src/server/raw.rs Normal file
View File

@@ -0,0 +1,59 @@
use common::prelude::*;
use rustls::{ServerConfig, server::Acceptor};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc};
use tokio::{io::AsyncWriteExt, net::TcpStream};
use tokio_rustls::LazyConfigAcceptor;
use tracing::{debug, info, warn};
pub async fn handle_raw_connection(
stream: TcpStream,
peer: SocketAddr,
tls_config: Arc<ServerConfig>,
) {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
let start_handshake = match acceptor.await {
Ok(sh) => sh,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS accept error");
}
};
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
Ok(s) => s,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS handshake error");
}
};
let (_, conn) = tls_stream.get_ref();
info!(
cipher = ?conn.negotiated_cipher_suite(),
version = ?conn.protocol_version(),
"connection established"
);
loop {
let payload_size = match read_request(&mut tls_stream).await {
Ok(size) => size,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
debug!(peer = %peer, "client disconnected");
break;
}
Err(e) => {
warn!(peer = %peer, error = %e, "connection error");
break;
}
};
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
warn!(peer = %peer, error = %e, "write error");
break;
}
// Flush to ensure data is sent
if let Err(e) = tls_stream.flush().await {
warn!(peer = %peer, error = %e, "flush error");
break;
}
}
}

42
server/src/tls.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::error;
use common::{cert::ServerCertificate, prelude::*};
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
self,
kx_group::{X25519, X25519MLKEM768},
},
pki_types::{CertificateDer, PrivateKeyDer},
version::TLS13,
};
use std::sync::Arc;
/// Build TLS server config for the given key exchange mode.
pub fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
};
let certs = server_cert
.cert_chain_der
.iter()
.map(|der| CertificateDer::from(der.clone()))
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}