mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-22 00:36:21 +00:00
Compare commits
14 Commits
07cae6df55
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 589e5217f5 | |||
| 7a3a5bc3ae | |||
| 2cfe85ad82 | |||
| 4b2324d131 | |||
| 6198e3ab2e | |||
| 0ea39e7663 | |||
| 8f0d2a6efc | |||
| 99f2e0bb72 | |||
| 31f5fc5b44 | |||
| b519c4f059 | |||
| 082add6be0 | |||
| c0886a454d | |||
| a5e166e0b0 | |||
| ea2a07d5aa |
192
Cargo.lock
generated
192
Cargo.lock
generated
@@ -115,6 +115,12 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
@@ -282,6 +288,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"cargo-husky",
|
||||
"claims",
|
||||
"clap",
|
||||
"miette",
|
||||
"rcgen",
|
||||
"rustls",
|
||||
@@ -367,6 +374,94 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.17"
|
||||
@@ -408,6 +503,86 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "httpdate"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.13.0"
|
||||
@@ -688,6 +863,12 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
@@ -779,6 +960,7 @@ dependencies = [
|
||||
"claims",
|
||||
"clap",
|
||||
"common",
|
||||
"futures",
|
||||
"miette",
|
||||
"rustls",
|
||||
"serde",
|
||||
@@ -923,9 +1105,13 @@ name = "server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"claims",
|
||||
"clap",
|
||||
"common",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"miette",
|
||||
"rustls",
|
||||
"thiserror",
|
||||
@@ -961,6 +1147,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.1"
|
||||
|
||||
@@ -13,7 +13,12 @@ common = { path = "common" }
|
||||
|
||||
aws-lc-rs = "1"
|
||||
base64 = "0.22"
|
||||
bytes = "1.11"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
futures = "0.3"
|
||||
http-body-util = "0.1"
|
||||
hyper = { version = "1.8", features = ["http1"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio"] }
|
||||
miette = { version = "7", features = ["fancy"] }
|
||||
rcgen = "0.14"
|
||||
rustls = { version = "0.23", default-features = false, features = [
|
||||
|
||||
@@ -43,22 +43,23 @@ cargo build --release
|
||||
Terminal 1 - Start server:
|
||||
|
||||
```bash
|
||||
./target/release/server --mode x25519 --listen 127.0.0.1:4433
|
||||
./target/release/server --mode x25519 --proto raw --listen 127.0.0.1:4433
|
||||
```
|
||||
|
||||
Terminal 2 - Run benchmark:
|
||||
|
||||
```bash
|
||||
./target/release/runner --server 127.0.0.1:4433 --mode x25519 --iters 100 --warmup 10
|
||||
./target/release/runner --server 127.0.0.1:4433 --proto raw --mode x25519 --iters 100 --warmup 10
|
||||
```
|
||||
|
||||
### Run Matrix Benchmarks
|
||||
|
||||
Create a config file (`matrix.toml`):
|
||||
Create a config file (`benchmarks.toml`):
|
||||
|
||||
```toml
|
||||
[[benchmarks]]
|
||||
server = "127.0.0.1:4433"
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
@@ -67,6 +68,7 @@ concurrency = 1
|
||||
|
||||
[[benchmarks]]
|
||||
server = "127.0.0.1:4433"
|
||||
proto = "http1"
|
||||
mode = "x25519mlkem768"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
|
||||
@@ -5,6 +5,7 @@ authors.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
clap.workspace = true
|
||||
miette.workspace = true
|
||||
rcgen.workspace = true
|
||||
rustls.workspace = true
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
//! Self-signed certificate generation for local testing.
|
||||
//!
|
||||
//! Generates a CA certificate and server certificate for TLS benchmarking.
|
||||
//! These certificates are NOT suitable for production use.
|
||||
|
||||
use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, Issuer, KeyPair, SanType};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type using the common's custom error type.
|
||||
/// Result type using the `common`'s custom error type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
|
||||
@@ -1,38 +1,60 @@
|
||||
//! Common types and utilities for the TLS benchmark harness
|
||||
|
||||
pub mod cert;
|
||||
pub mod error;
|
||||
pub mod prelude;
|
||||
pub mod protocol;
|
||||
|
||||
use clap::ValueEnum;
|
||||
pub use error::Error;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use strum::{Display, EnumString};
|
||||
use strum::Display;
|
||||
|
||||
/// TLS key exchange mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, EnumString, Display)]
|
||||
/// TLS 1.3 key exchange mode used for benchmark runs
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Display, ValueEnum)]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KeyExchangeMode {
|
||||
/// Classical X25519 ECDH.
|
||||
X25519,
|
||||
#[value(name = "x25519mlkem768")]
|
||||
/// Hybrid post-quantum: X25519 + ML-KEM-768.
|
||||
X25519Mlkem768,
|
||||
}
|
||||
|
||||
/// Application protocol carried over TLS in benchmark runs.
|
||||
///
|
||||
/// `Raw` is a minimal custom framing protocol (8-byte LE length request, then N payload bytes)
|
||||
/// used for low-overhead microbenchmarks.
|
||||
///
|
||||
/// `Http1` is an HTTP/1.1 request/response mode (`GET /bytes/{n}`) used for realism-oriented
|
||||
/// comparisons where HTTP parsing and headers are part of measured overhead.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Display, ValueEnum)]
|
||||
#[strum(serialize_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ProtocolMode {
|
||||
/// Minimal custom framing protocol for primary microbenchmarks.
|
||||
Raw,
|
||||
/// HTTP/1.1 mode for realism-oriented comparisons.
|
||||
Http1,
|
||||
}
|
||||
|
||||
/// A single benchmark measurement record, output as NDJSON
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchRecord {
|
||||
/// Iteration number (0-indexed, excludes warmup)
|
||||
pub iteration: u64,
|
||||
/// Protocol carrier mode
|
||||
pub proto: ProtocolMode,
|
||||
/// Key exchange mode used
|
||||
pub mode: KeyExchangeMode,
|
||||
/// Payload size in bytes
|
||||
pub payload_bytes: u64,
|
||||
/// TCP connection latency in nanoseconds
|
||||
pub tcp_ns: u128,
|
||||
/// Handshake latency in nanoseconds
|
||||
pub handshake_ns: u64,
|
||||
pub handshake_ns: u128,
|
||||
/// Time-to-last-byte in nanoseconds (from connection start)
|
||||
pub ttlb_ns: u64,
|
||||
pub ttlb_ns: u128,
|
||||
}
|
||||
|
||||
impl BenchRecord {
|
||||
@@ -56,23 +78,24 @@ impl fmt::Display for BenchRecord {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use claims::{assert_err, assert_ok};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::*;
|
||||
use std::str::FromStr;
|
||||
|
||||
#[test]
|
||||
fn bench_record_serializes_to_ndjson() {
|
||||
let record = BenchRecord {
|
||||
iteration: 0,
|
||||
proto: ProtocolMode::Raw,
|
||||
mode: KeyExchangeMode::X25519,
|
||||
payload_bytes: 1024,
|
||||
tcp_ns: 500_000,
|
||||
handshake_ns: 1_000_000,
|
||||
ttlb_ns: 2_000_000,
|
||||
};
|
||||
let json = assert_ok!(record.to_ndjson());
|
||||
assert!(json.contains(r#""iteration":0"#));
|
||||
assert!(json.contains(r#""proto":"raw""#));
|
||||
assert!(json.contains(r#""mode":"x25519""#));
|
||||
assert!(json.contains(r#""payload_bytes":1024"#));
|
||||
}
|
||||
@@ -81,8 +104,10 @@ mod tests {
|
||||
fn bench_record_roundtrip() {
|
||||
let original = BenchRecord {
|
||||
iteration: 42,
|
||||
proto: ProtocolMode::Http1,
|
||||
mode: KeyExchangeMode::X25519Mlkem768,
|
||||
payload_bytes: 4096,
|
||||
tcp_ns: 1_000_000,
|
||||
handshake_ns: 5_000_000,
|
||||
ttlb_ns: 10_000_000,
|
||||
};
|
||||
@@ -90,26 +115,18 @@ mod tests {
|
||||
let deserialized = assert_ok!(serde_json::from_str::<BenchRecord>(&json));
|
||||
|
||||
assert_eq!(original.iteration, deserialized.iteration);
|
||||
assert_eq!(original.proto, deserialized.proto);
|
||||
assert_eq!(original.mode, deserialized.mode);
|
||||
assert_eq!(original.payload_bytes, deserialized.payload_bytes);
|
||||
assert_eq!(original.handshake_ns, deserialized.handshake_ns);
|
||||
assert_eq!(original.ttlb_ns, deserialized.ttlb_ns);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_exchange_mode_from_str() {
|
||||
let mode = assert_ok!(KeyExchangeMode::from_str("x25519"));
|
||||
assert_eq!(mode, KeyExchangeMode::X25519);
|
||||
|
||||
let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768"));
|
||||
assert_eq!(mode, KeyExchangeMode::X25519Mlkem768);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_exchange_mode_parse_error() {
|
||||
assert_err!(KeyExchangeMode::from_str("invalid"));
|
||||
assert_err!(KeyExchangeMode::from_str("x25519invalid"));
|
||||
assert_err!(KeyExchangeMode::from_str(""));
|
||||
assert_err!(KeyExchangeMode::from_str("invalid", true));
|
||||
assert_err!(KeyExchangeMode::from_str("x25519invalid", true));
|
||||
assert_err!(KeyExchangeMode::from_str("", true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -132,4 +149,38 @@ mod tests {
|
||||
));
|
||||
assert_eq!(mode_mlkem_lower, KeyExchangeMode::X25519Mlkem768);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_protocol_mod_from_str() {
|
||||
let proto = assert_ok!(ProtocolMode::from_str("raw", true));
|
||||
assert_eq!(proto, ProtocolMode::Raw);
|
||||
|
||||
let proto = assert_ok!(ProtocolMode::from_str("http1", true));
|
||||
assert_eq!(proto, ProtocolMode::Http1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_protocol_mode_parse_error() {
|
||||
assert_err!(ProtocolMode::from_str("invalid", true));
|
||||
assert_err!(ProtocolMode::from_str("", true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_exchange_mode_from_str() {
|
||||
let mode = assert_ok!(KeyExchangeMode::from_str("x25519", true));
|
||||
assert_eq!(mode, KeyExchangeMode::X25519);
|
||||
|
||||
let mode = assert_ok!(KeyExchangeMode::from_str("x25519mlkem768", true));
|
||||
assert_eq!(mode, KeyExchangeMode::X25519Mlkem768);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_protocol_mode_serde() {
|
||||
let json = r#"{"proto":"http1"}"#;
|
||||
let value = assert_ok!(serde_json::from_str::<Value>(json));
|
||||
let proto = assert_ok!(serde_json::from_value::<ProtocolMode>(
|
||||
value["proto"].clone()
|
||||
));
|
||||
assert_eq!(proto, ProtocolMode::Http1);
|
||||
}
|
||||
}
|
||||
|
||||
7
common/src/prelude.rs
Normal file
7
common/src/prelude.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub use crate::{
|
||||
BenchRecord, KeyExchangeMode, ProtocolMode,
|
||||
protocol::{
|
||||
MAX_PAYLOAD_SIZE, generate_payload, read_payload, read_request, write_payload,
|
||||
write_request,
|
||||
},
|
||||
};
|
||||
@@ -1,11 +1,3 @@
|
||||
//! Benchmark protocol implementation.
|
||||
//!
|
||||
//! Protocol specification:
|
||||
//! 1. Client sends 8-byte little-endian u64: requested payload size N
|
||||
//! 2. Server responds with exactly N bytes (deterministic pattern)
|
||||
//!
|
||||
//! The deterministic pattern is a repeating sequence of bytes 0x00..0xFF.
|
||||
|
||||
// Casts are intentional: MAX_PAYLOAD_SIZE (16 MiB) fits in usize on 64-bit,
|
||||
// and byte patterns are explicitly masked to 0xFF before casting.
|
||||
#![allow(clippy::cast_possible_truncation)]
|
||||
@@ -50,6 +42,7 @@ pub async fn write_request<W: AsyncWriteExt + Unpin>(writer: &mut W, size: u64)
|
||||
/// Generate deterministic payload of the given size.
|
||||
///
|
||||
/// The pattern is a repeating sequence: 0x00, 0x01, ..., 0xFF, 0x00, ...
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn generate_payload(size: u64) -> Vec<u8> {
|
||||
(0..size).map(|i| (i & 0xFF) as u8).collect()
|
||||
|
||||
5
justfile
5
justfile
@@ -12,6 +12,7 @@ alias c := check
|
||||
alias d := docs
|
||||
alias f := fmt
|
||||
alias t := test
|
||||
alias bench := benchmark
|
||||
|
||||
# Run all checks (fmt, clippy, docs, test)
|
||||
[group("dev")]
|
||||
@@ -60,8 +61,8 @@ setup:
|
||||
|
||||
# Run server (default: x25519 on localhost:4433)
|
||||
[group("run")]
|
||||
server mode="x25519" listen="127.0.0.1:4433":
|
||||
cargo run --release --bin server -- --mode {{mode}} --listen {{listen}}
|
||||
server mode="x25519" proto="raw" listen="127.0.0.1:4433":
|
||||
cargo run --release --bin server -- --mode {{mode}} --proto {{proto}} --listen {{listen}}
|
||||
|
||||
# Run benchmark runner
|
||||
[group("run")]
|
||||
|
||||
@@ -7,6 +7,7 @@ edition.workspace = true
|
||||
[dependencies]
|
||||
clap.workspace = true
|
||||
common.workspace = true
|
||||
futures.workspace = true
|
||||
miette.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -1,40 +1,44 @@
|
||||
use clap::Parser;
|
||||
use common::KeyExchangeMode;
|
||||
use common::prelude::*;
|
||||
use std::{net::SocketAddr, path::PathBuf};
|
||||
|
||||
/// TLS benchmark runner.
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(name = "runner", version, about)]
|
||||
pub struct Args {
|
||||
/// Key exchange mode.
|
||||
/// Protocol carrier mode
|
||||
#[arg(long, default_value = "raw")]
|
||||
pub proto: ProtocolMode,
|
||||
|
||||
/// Key exchange mode
|
||||
#[arg(long, default_value = "x25519")]
|
||||
pub mode: KeyExchangeMode,
|
||||
|
||||
/// Server address to connect to.
|
||||
/// Server address to connect to
|
||||
#[arg(long, required_unless_present = "config")]
|
||||
pub server: Option<SocketAddr>,
|
||||
|
||||
/// Payload size in bytes to request from server.
|
||||
/// Payload size in bytes to request from server
|
||||
#[arg(long, default_value = "1024")]
|
||||
pub payload_bytes: u32,
|
||||
|
||||
/// Number of benchmark iterations (excluding warmup).
|
||||
/// Number of benchmark iterations (excluding warmup)
|
||||
#[arg(long, default_value = "100")]
|
||||
pub iters: u32,
|
||||
|
||||
/// Number of warmup iterations (not recorded).
|
||||
/// Number of warmup iterations (not recorded)
|
||||
#[arg(long, default_value = "10")]
|
||||
pub warmup: u32,
|
||||
|
||||
/// Number of concurrent connections.
|
||||
/// Number of concurrent connections
|
||||
#[arg(long, default_value = "1")]
|
||||
pub concurrency: u32,
|
||||
|
||||
/// Output file for NDJSON records (stdout if not specified).
|
||||
/// Output file for NDJSON records (stdout if not specified)
|
||||
#[arg(long)]
|
||||
pub out: Option<PathBuf>,
|
||||
|
||||
/// Config file for matrix benchmarks (TOML).
|
||||
#[arg(long)]
|
||||
/// Config file for matrix benchmarks (TOML)
|
||||
#[arg(long, short)]
|
||||
pub config: Option<PathBuf>,
|
||||
}
|
||||
|
||||
364
runner/src/bench.rs
Normal file
364
runner/src/bench.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
use crate::config::BenchmarkConfig;
|
||||
use common::prelude::*;
|
||||
use futures::{StreamExt, stream::FuturesUnordered};
|
||||
use miette::{Context, IntoDiagnostic};
|
||||
use rustls::pki_types::ServerName;
|
||||
use std::{
|
||||
io::{Write, stdout},
|
||||
net::SocketAddr,
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::info;
|
||||
|
||||
/// Result of a single benchmark iteration.
|
||||
struct IterationResult {
|
||||
tcp: u128,
|
||||
handshake: u128,
|
||||
ttlb: u128,
|
||||
}
|
||||
|
||||
pub async fn run_benchmark(
|
||||
config: &BenchmarkConfig,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<()> {
|
||||
let server = config.server;
|
||||
|
||||
info!(
|
||||
warmup = config.warmup,
|
||||
iters = config.iters,
|
||||
concurrency = config.concurrency,
|
||||
"running benchmark iterations"
|
||||
);
|
||||
|
||||
for _ in 0..config.warmup {
|
||||
run_iteration(
|
||||
server,
|
||||
config.proto,
|
||||
config.payload,
|
||||
tls_connector,
|
||||
server_name,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
info!("warmup complete");
|
||||
|
||||
let mut output = stdout();
|
||||
run_and_write(config, tls_connector, server_name, &mut output).await?;
|
||||
output
|
||||
.flush()
|
||||
.into_diagnostic()
|
||||
.context("failed to flush output")?;
|
||||
|
||||
info!("benchmark complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_and_write<W: Write + Send>(
|
||||
config: &BenchmarkConfig,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
output: &mut W,
|
||||
) -> miette::Result<()> {
|
||||
let mut in_flight = FuturesUnordered::new();
|
||||
let mut issued = 0;
|
||||
|
||||
loop {
|
||||
while issued < config.iters && in_flight.len() < config.concurrency as usize {
|
||||
in_flight.push(run_single_iteration(
|
||||
issued,
|
||||
config.payload,
|
||||
config.proto,
|
||||
config.mode,
|
||||
config.server,
|
||||
tls_connector.clone(),
|
||||
server_name.clone(),
|
||||
));
|
||||
issued += 1;
|
||||
}
|
||||
|
||||
match in_flight.next().await {
|
||||
Some(record) => writeln!(output, "{}", record?)
|
||||
.into_diagnostic()
|
||||
.context("failed to write record")?,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_single_iteration(
|
||||
i: u32,
|
||||
payload_bytes: u32,
|
||||
proto: ProtocolMode,
|
||||
mode: KeyExchangeMode,
|
||||
server: SocketAddr,
|
||||
tls_connector: TlsConnector,
|
||||
server_name: ServerName<'static>,
|
||||
) -> miette::Result<BenchRecord> {
|
||||
let result = run_iteration(server, proto, payload_bytes, &tls_connector, &server_name).await?;
|
||||
|
||||
Ok(BenchRecord {
|
||||
iteration: u64::from(i),
|
||||
proto,
|
||||
mode,
|
||||
payload_bytes: u64::from(payload_bytes),
|
||||
tcp_ns: result.tcp,
|
||||
handshake_ns: result.handshake,
|
||||
ttlb_ns: result.ttlb,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run a single benchmark iteration over TLS.
|
||||
async fn run_iteration(
|
||||
server: SocketAddr,
|
||||
proto: ProtocolMode,
|
||||
payload_bytes: u32,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<IterationResult> {
|
||||
let tcp_start = Instant::now();
|
||||
|
||||
let stream = TcpStream::connect(server)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("TCP connection failed")?;
|
||||
|
||||
let tcp_ns = tcp_start.elapsed().as_nanos();
|
||||
|
||||
let hs_start = Instant::now();
|
||||
let mut tls_stream = tls_connector
|
||||
.connect(server_name.clone(), stream)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("TLS handshake failed")?;
|
||||
|
||||
let handshake_ns = hs_start.elapsed().as_nanos();
|
||||
|
||||
let ttlb_start = Instant::now();
|
||||
|
||||
let ttlb_ns = tcp_ns + handshake_ns + ttlb_start.elapsed().as_nanos();
|
||||
run_exchange(&mut tls_stream, proto, payload_bytes).await?;
|
||||
|
||||
Ok(IterationResult {
|
||||
tcp: tcp_ns,
|
||||
handshake: handshake_ns,
|
||||
ttlb: ttlb_ns,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_exchange<S>(
|
||||
tls_stream: &mut S,
|
||||
proto: ProtocolMode,
|
||||
payload_bytes: u32,
|
||||
) -> miette::Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
match proto {
|
||||
ProtocolMode::Raw => run_raw_exchange(tls_stream, payload_bytes).await,
|
||||
ProtocolMode::Http1 => run_http1_exchange(tls_stream, payload_bytes).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_raw_exchange<S>(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
write_request(tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("write request failed")?;
|
||||
|
||||
read_payload(tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("read payload failed")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_http1_exchange<S>(tls_stream: &mut S, payload_bytes: u32) -> miette::Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let request = build_http1_request(payload_bytes);
|
||||
|
||||
tls_stream
|
||||
.write_all(&request)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("write http1 request failed")?;
|
||||
|
||||
tls_stream
|
||||
.flush()
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("flush http1 request failed")?;
|
||||
|
||||
let mut response_buf = Vec::with_capacity(1024);
|
||||
let mut chunk = [0; 1024];
|
||||
|
||||
let (content_length, body_start) = loop {
|
||||
let n = tls_stream
|
||||
.read(&mut chunk)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("read http1 response failed")?;
|
||||
|
||||
if n == 0 {
|
||||
return Err(common::Error::protocol("unexpected EOF before http1 headers").into());
|
||||
}
|
||||
|
||||
response_buf.extend_from_slice(&chunk[..n]);
|
||||
|
||||
if let Some(pos) = find_headers_end(&response_buf) {
|
||||
let headers = str::from_utf8(&response_buf[..pos])
|
||||
.into_diagnostic()
|
||||
.context("http1 headers are not valid UTF-8")?;
|
||||
let content_length = parse_content_length(headers)?;
|
||||
break (content_length, pos + 4);
|
||||
}
|
||||
};
|
||||
|
||||
let body_already_read = response_buf.len() - body_start;
|
||||
if body_already_read > content_length {
|
||||
return Err(common::Error::protocol("http1 body exceeded content-lenght").into());
|
||||
}
|
||||
|
||||
let mut remaining = content_length - body_already_read;
|
||||
let mut body_buf = vec![0; 64 * 1024];
|
||||
|
||||
while remaining > 0 {
|
||||
let to_read = remaining.min(body_buf.len());
|
||||
tls_stream
|
||||
.read_exact(&mut body_buf[..to_read])
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("read http1 body failed")?;
|
||||
remaining -= to_read;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_http1_request(payload_bytes: u32) -> Vec<u8> {
|
||||
format!("GET /bytes/{payload_bytes} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
|
||||
.into_bytes()
|
||||
}
|
||||
|
||||
fn parse_content_length(headers: &str) -> miette::Result<usize> {
|
||||
let mut lines = headers.lines();
|
||||
|
||||
let status_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| common::Error::protocol("missing http1 status line"))?;
|
||||
|
||||
let mut parts = status_line.split_whitespace();
|
||||
let version = parts
|
||||
.next()
|
||||
.ok_or_else(|| common::Error::protocol("missing http1 version"))?;
|
||||
let status = parts
|
||||
.next()
|
||||
.ok_or_else(|| common::Error::protocol("missing http1 status"))?;
|
||||
|
||||
if version != "HTTP/1.1" {
|
||||
return Err(common::Error::protocol(format!("unsupported http version: {version}")).into());
|
||||
}
|
||||
if status != "200" {
|
||||
return Err(common::Error::protocol(format!("unsupported http status: {status}")).into());
|
||||
}
|
||||
|
||||
for line in lines {
|
||||
if let Some((name, value)) = line.split_once(':')
|
||||
&& name.trim().eq_ignore_ascii_case("content-length")
|
||||
{
|
||||
return value
|
||||
.trim()
|
||||
.parse::<usize>()
|
||||
.into_diagnostic()
|
||||
.context("invalid content-length header");
|
||||
}
|
||||
}
|
||||
Err(common::Error::protocol("missing content-length header").into())
|
||||
}
|
||||
|
||||
fn find_headers_end(buf: &[u8]) -> Option<usize> {
|
||||
buf.windows(4).position(|window| window == b"\r\n\r\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use claims::{assert_err, assert_none, assert_ok, assert_some};
|
||||
|
||||
#[test]
|
||||
fn build_http1_request_formats_get_requests() {
|
||||
let request = build_http1_request(16);
|
||||
let request_string = String::from_utf8(request).expect("valid string");
|
||||
assert_eq!(
|
||||
request_string,
|
||||
"GET /bytes/16 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_accepts_200() {
|
||||
let headers = "HTTP/1.1 200 OK\r\nContent-Length: 16\r\nConnection: close\r\n";
|
||||
let len = assert_ok!(parse_content_length(headers));
|
||||
assert_eq!(len, 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_rejects_missing_header() {
|
||||
let headers = "HTTP/1.1 200 OK\r\nConnection: close\r\n";
|
||||
assert_err!(parse_content_length(headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_accepts_mixed_case_header_name() {
|
||||
let headers = "HTTP/1.1 200 OK\r\nContent-Length: 8\r\nConnection: close\r\n";
|
||||
let len = assert_ok!(parse_content_length(headers));
|
||||
assert_eq!(len, 8);
|
||||
|
||||
let headers = "HTTP/1.1 200 OK\r\ncontent-length: 9\r\nConnection: close\r\n";
|
||||
let len = assert_ok!(parse_content_length(headers));
|
||||
assert_eq!(len, 9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_rejects_non_200_status() {
|
||||
let headers = "HTTP/1.1 404 Not Found\r\nContent-Length: 3\r\nConnection: close\r\n";
|
||||
assert_err!(parse_content_length(headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_rejects_unsupported_http_version() {
|
||||
let headers = "HTTP/1.0 200 OK\r\nContent-Length: 3\r\nConnection: close\r\n";
|
||||
assert_err!(parse_content_length(headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_content_length_rejects_invalid_numeric_value() {
|
||||
let headers = "HTTP/1.1 200 OK\r\nContent-Length: nope\r\nConnection: close\r\n";
|
||||
assert_err!(parse_content_length(headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_headers_end_returns_none_when_separator_missing() {
|
||||
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n";
|
||||
assert_none!(find_headers_end(response));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_headers_end_returns_separator_offset() {
|
||||
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nbody";
|
||||
let pos = assert_some!(find_headers_end(response));
|
||||
assert_eq!(pos, 34);
|
||||
}
|
||||
}
|
||||
@@ -5,14 +5,15 @@ use crate::{
|
||||
config::utils::validate_config,
|
||||
error::{self, ConfigError},
|
||||
};
|
||||
use common::{self, KeyExchangeMode};
|
||||
use common::prelude::*;
|
||||
use miette::{NamedSource, SourceSpan};
|
||||
use serde::Deserialize;
|
||||
use std::{fs::read_to_string, net::SocketAddr, path::Path};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct BenchmarkConfig {
|
||||
pub mode: String,
|
||||
pub proto: ProtocolMode,
|
||||
pub mode: KeyExchangeMode,
|
||||
pub payload: u32,
|
||||
pub iters: u32,
|
||||
pub warmup: u32,
|
||||
@@ -29,68 +30,67 @@ pub struct Config {
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the file cannot be read or parsed.
|
||||
pub fn load_from_file(path: &Path) -> error::Result<Config> {
|
||||
let content = read_to_string(path).map_err(|source| ConfigError::ReadError {
|
||||
source,
|
||||
path: path.to_owned(),
|
||||
})?;
|
||||
impl TryFrom<&Path> for Config {
|
||||
type Error = error::Error;
|
||||
|
||||
let src = NamedSource::new(path.display().to_string(), content.clone());
|
||||
|
||||
let config = toml::from_str::<Config>(&content).map_err(|source| {
|
||||
let span = source
|
||||
.span()
|
||||
.map(|s| SourceSpan::new(s.start.into(), s.end - s.start));
|
||||
|
||||
ConfigError::TomlParseError {
|
||||
src: src.clone(),
|
||||
span,
|
||||
fn try_from(path: &Path) -> Result<Self, Self::Error> {
|
||||
let content = read_to_string(path).map_err(|source| ConfigError::ReadError {
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
path: path.to_owned(),
|
||||
})?;
|
||||
|
||||
validate_config(&config, &content, path)?;
|
||||
let src = NamedSource::new(path.display().to_string(), content.clone());
|
||||
|
||||
Ok(config)
|
||||
let config = toml::from_str::<Self>(&content).map_err(|source| {
|
||||
let span = source
|
||||
.span()
|
||||
.map(|s| SourceSpan::new(s.start.into(), s.end - s.start));
|
||||
|
||||
ConfigError::TomlParseError {
|
||||
src: src.clone(),
|
||||
span,
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
validate_config(&config, &content, path)?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Create benchmark configuration from CLI arguments.
|
||||
///
|
||||
/// # Errors
|
||||
/// Never returns an error, but returns Result for consistency.
|
||||
pub fn load_from_cli(args: &Args) -> error::Result<Config> {
|
||||
Ok(Config {
|
||||
benchmarks: vec![BenchmarkConfig {
|
||||
mode: args.mode.to_string(),
|
||||
payload: args.payload_bytes,
|
||||
iters: args.iters,
|
||||
warmup: args.warmup,
|
||||
concurrency: args.concurrency,
|
||||
server: args
|
||||
.server
|
||||
.ok_or_else(|| common::Error::config("--server ir required"))?,
|
||||
}],
|
||||
})
|
||||
}
|
||||
/// Returns an error if `--server` was not provided.
|
||||
impl TryFrom<Args> for Config {
|
||||
type Error = error::Error;
|
||||
|
||||
impl Config {
|
||||
/// Get the key exchange mode from the first benchmark configuration.
|
||||
#[must_use]
|
||||
pub fn server_mode(&self) -> KeyExchangeMode {
|
||||
self.benchmarks
|
||||
.first()
|
||||
.and_then(|b| b.mode.parse().ok())
|
||||
.unwrap_or(KeyExchangeMode::X25519)
|
||||
fn try_from(args: Args) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
benchmarks: vec![BenchmarkConfig {
|
||||
proto: args.proto,
|
||||
mode: args.mode,
|
||||
payload: args.payload_bytes,
|
||||
iters: args.iters,
|
||||
warmup: args.warmup,
|
||||
concurrency: args.concurrency,
|
||||
server: args
|
||||
.server
|
||||
.ok_or_else(|| common::Error::config("--server is required"))?,
|
||||
}],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use claims::assert_ok;
|
||||
use claims::{assert_err, assert_ok};
|
||||
|
||||
const VALID_CONFIG: &str = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
@@ -99,6 +99,7 @@ concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
|
||||
[[benchmarks]]
|
||||
proto = "http1"
|
||||
mode = "x25519mlkem768"
|
||||
payload = 4096
|
||||
iters = 50
|
||||
@@ -115,6 +116,7 @@ server = "127.0.0.1:4433"
|
||||
fn valid_single_benchmark() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
@@ -124,7 +126,7 @@ server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
assert_eq!(config.benchmarks.len(), 1);
|
||||
assert_eq!(config.benchmarks[0].mode, "x25519");
|
||||
assert_eq!(config.benchmarks[0].mode, KeyExchangeMode::X25519);
|
||||
assert_eq!(config.benchmarks[0].payload, 1024);
|
||||
}
|
||||
|
||||
@@ -132,14 +134,35 @@ server = "127.0.0.1:4433"
|
||||
fn valid_multiple_benchmarks() {
|
||||
let config = get_config_from_str(VALID_CONFIG);
|
||||
assert_eq!(config.benchmarks.len(), 2);
|
||||
assert_eq!(config.benchmarks[0].mode, "x25519");
|
||||
assert_eq!(config.benchmarks[1].mode, "x25519mlkem768");
|
||||
let bench_0 = config.benchmarks[0].clone();
|
||||
let bench_1 = config.benchmarks[1].clone();
|
||||
|
||||
assert_eq!(bench_0.mode, KeyExchangeMode::X25519);
|
||||
assert_eq!(bench_0.proto, ProtocolMode::Raw);
|
||||
assert_eq!(bench_1.mode, KeyExchangeMode::X25519Mlkem768);
|
||||
assert_eq!(bench_1.proto, ProtocolMode::Http1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_proto() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "invalid_proto"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
warmup = 10
|
||||
concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
assert_err!(toml::from_str::<Config>(toml));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_mode() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "invalid_mode"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
@@ -147,14 +170,14 @@ warmup = 10
|
||||
concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
assert!(config.server_mode() == KeyExchangeMode::X25519); // fallback
|
||||
assert_err!(toml::from_str::<Config>(toml));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn payload_zero_validation() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 0
|
||||
iters = 100
|
||||
@@ -163,14 +186,14 @@ concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
||||
assert!(result.is_err());
|
||||
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iters_zero_validation() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 0
|
||||
@@ -179,14 +202,14 @@ concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
||||
assert!(result.is_err());
|
||||
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrency_zero_validation() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
proto = "raw"
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
@@ -195,8 +218,7 @@ concurrency = 0
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
let result = validate_config(&config, toml, std::path::Path::new("test.toml"));
|
||||
assert!(result.is_err());
|
||||
assert_err!(validate_config(&config, toml, Path::new("test.toml")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -205,34 +227,4 @@ server = "127.0.0.1:4433"
|
||||
let config = get_config_from_str(toml);
|
||||
assert!(config.benchmarks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_mode_fallback() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
mode = "x25519"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
warmup = 10
|
||||
concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
assert_eq!(config.server_mode(), KeyExchangeMode::X25519);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_mode_mlkem() {
|
||||
let toml = r#"
|
||||
[[benchmarks]]
|
||||
mode = "x25519mlkem768"
|
||||
payload = 1024
|
||||
iters = 100
|
||||
warmup = 10
|
||||
concurrency = 1
|
||||
server = "127.0.0.1:4433"
|
||||
"#;
|
||||
let config = get_config_from_str(toml);
|
||||
assert_eq!(config.server_mode(), KeyExchangeMode::X25519Mlkem768);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ use crate::{
|
||||
config::{BenchmarkConfig, Config},
|
||||
error::{self, ConfigError},
|
||||
};
|
||||
use common::{self, KeyExchangeMode};
|
||||
use miette::{NamedSource, SourceSpan};
|
||||
use std::path::Path;
|
||||
|
||||
@@ -31,21 +30,6 @@ fn validate_benchmark(
|
||||
) -> error::Result<()> {
|
||||
let src = NamedSource::new(path.display().to_string(), content.to_string());
|
||||
|
||||
// Validate mode
|
||||
if benchmark.mode.parse::<KeyExchangeMode>().is_err() {
|
||||
return Err(ConfigError::ValidationError {
|
||||
src,
|
||||
span: find_field_span(content, idx, "mode"),
|
||||
field: "mode".into(),
|
||||
idx,
|
||||
message: format!(
|
||||
"Invalid key exchange mode '{}'. Valid values: 'x25519', 'x25519mlkem768'",
|
||||
benchmark.mode
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
validate_positive_field(src.clone(), content, idx, "payload", benchmark.payload)?;
|
||||
validate_positive_field(src.clone(), content, idx, "iters", benchmark.iters)?;
|
||||
validate_positive_field(src, content, idx, "concurrency", benchmark.concurrency)?;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#![allow(unused_assignments)]
|
||||
#![allow(unused)]
|
||||
//! Error types for the benchmark runner.
|
||||
|
||||
use miette::{Diagnostic, NamedSource, SourceSpan};
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type using the runner's custom error type.
|
||||
/// Result type using the `runner`'s custom error type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Errors that can occur during benchmark execution.
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod args;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
@@ -6,286 +6,22 @@
|
||||
//!
|
||||
//! Outputs NDJSON records to stdout or a file.
|
||||
|
||||
mod args;
|
||||
mod bench;
|
||||
mod config;
|
||||
mod error;
|
||||
mod tls;
|
||||
|
||||
use crate::{args::Args, bench::run_benchmark, config::Config, tls::build_tls_config};
|
||||
use clap::Parser;
|
||||
use common::{
|
||||
BenchRecord, KeyExchangeMode,
|
||||
protocol::{read_payload, write_request},
|
||||
};
|
||||
use miette::{Context, IntoDiagnostic};
|
||||
use runner::{
|
||||
args::Args,
|
||||
config::{load_from_cli, load_from_file},
|
||||
};
|
||||
use rustls::{
|
||||
ClientConfig, DigitallySignedStruct, SignatureScheme,
|
||||
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
|
||||
crypto::aws_lc_rs::{
|
||||
self,
|
||||
kx_group::{X25519, X25519MLKEM768},
|
||||
},
|
||||
pki_types::{CertificateDer, ServerName, UnixTime},
|
||||
version::TLS13,
|
||||
};
|
||||
use std::{
|
||||
env,
|
||||
io::{Write, stdout},
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::{net::TcpStream, sync::Semaphore, task::JoinHandle};
|
||||
use rustls::pki_types::ServerName;
|
||||
use std::{env, sync::Arc};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Result of a single benchmark iteration.
|
||||
struct IterationResult {
|
||||
handshake_ns: u64,
|
||||
ttlb_ns: u64,
|
||||
}
|
||||
|
||||
/// Certificate verifier that accepts any certificate.
|
||||
/// Used for benchmarking where we don't need to verify the server's identity.
|
||||
#[derive(Debug)]
|
||||
struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
|
||||
vec![
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::ED25519,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::RSA_PKCS1_SHA256,
|
||||
SignatureScheme::RSA_PKCS1_SHA384,
|
||||
SignatureScheme::RSA_PKCS1_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Build TLS client config for the given key exchange mode.
|
||||
fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<Arc<ClientConfig>> {
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
|
||||
};
|
||||
|
||||
let config = ClientConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&TLS13])
|
||||
.into_diagnostic()
|
||||
.context("failed to set TLS versions")?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
|
||||
/// Run a single benchmark iteration over TLS.
|
||||
#[allow(clippy::cast_possible_truncation)] // nanoseconds won't overflow u64 for reasonable durations
|
||||
async fn run_iteration(
|
||||
server: SocketAddr,
|
||||
payload_bytes: u32,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<IterationResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
let stream = TcpStream::connect(server)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("TCP connection failed")?;
|
||||
|
||||
let mut tls_stream = tls_connector
|
||||
.connect(server_name.clone(), stream)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("TLS handshake failed")?;
|
||||
|
||||
let handshake_ns = start.elapsed().as_nanos() as u64;
|
||||
|
||||
write_request(&mut tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("write request failed")?;
|
||||
|
||||
read_payload(&mut tls_stream, u64::from(payload_bytes))
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("read payload failed")?;
|
||||
|
||||
let ttlb_ns = start.elapsed().as_nanos() as u64;
|
||||
|
||||
Ok(IterationResult {
|
||||
handshake_ns,
|
||||
ttlb_ns,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::future_not_send)] // References held across await points
|
||||
async fn run_benchmark(
|
||||
config: &runner::config::BenchmarkConfig,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> miette::Result<()> {
|
||||
let server = config.server;
|
||||
|
||||
info!(
|
||||
warmup = config.warmup,
|
||||
iters = config.iters,
|
||||
concurrency = config.concurrency,
|
||||
"running benchmark iterations"
|
||||
);
|
||||
|
||||
for _ in 0..config.warmup {
|
||||
run_iteration(server, config.payload, tls_connector, server_name).await?;
|
||||
}
|
||||
info!("warmup complete");
|
||||
|
||||
let test_conn = tls_connector
|
||||
.connect(
|
||||
server_name.clone(),
|
||||
TcpStream::connect(server)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context(format!("failed to connect to server {server}"))?,
|
||||
)
|
||||
.await
|
||||
.into_diagnostic()
|
||||
.context("TLS handshake failed")?;
|
||||
|
||||
let cipher = test_conn.get_ref().1.negotiated_cipher_suite();
|
||||
info!(cipher = ?cipher, "TLS handshake complete");
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)] // concurrency is limited to reasonable values
|
||||
let semaphore = Arc::new(Semaphore::new(config.concurrency as usize));
|
||||
let tasks = spawn_benchmark_tasks(config, &semaphore, tls_connector, server_name);
|
||||
|
||||
// Output to stdout for now
|
||||
{
|
||||
let mut output = stdout();
|
||||
write_results(&mut output, tasks).await?;
|
||||
output
|
||||
.flush()
|
||||
.into_diagnostic()
|
||||
.context("failed to flush output")?;
|
||||
}
|
||||
|
||||
info!("benchmark complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type ReturnHandle = JoinHandle<(IterationResult, Option<BenchRecord>)>;
|
||||
|
||||
fn spawn_benchmark_tasks(
|
||||
config: &runner::config::BenchmarkConfig,
|
||||
semaphore: &Arc<Semaphore>,
|
||||
tls_connector: &TlsConnector,
|
||||
server_name: &ServerName<'static>,
|
||||
) -> Vec<ReturnHandle> {
|
||||
let server = config.server;
|
||||
let payload_bytes = config.payload;
|
||||
let mode = config
|
||||
.mode
|
||||
.parse::<KeyExchangeMode>()
|
||||
.expect("mode should be valid");
|
||||
|
||||
(0..config.iters)
|
||||
.map(|i| {
|
||||
spawn_single_iteration(
|
||||
i,
|
||||
payload_bytes,
|
||||
mode,
|
||||
server,
|
||||
semaphore.clone(),
|
||||
tls_connector.clone(),
|
||||
server_name.clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn spawn_single_iteration(
|
||||
i: u32,
|
||||
payload_bytes: u32,
|
||||
mode: KeyExchangeMode,
|
||||
server: SocketAddr,
|
||||
semaphore: Arc<Semaphore>,
|
||||
tls_connector: TlsConnector,
|
||||
server_name: ServerName<'static>,
|
||||
) -> ReturnHandle {
|
||||
tokio::spawn(async move {
|
||||
let _permit = semaphore
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore should not be closed");
|
||||
|
||||
let result = run_iteration(server, payload_bytes, &tls_connector, &server_name)
|
||||
.await
|
||||
.expect("iteration should not fail");
|
||||
|
||||
let record = BenchRecord {
|
||||
iteration: u64::from(i),
|
||||
mode,
|
||||
payload_bytes: u64::from(payload_bytes),
|
||||
handshake_ns: result.handshake_ns,
|
||||
ttlb_ns: result.ttlb_ns,
|
||||
};
|
||||
|
||||
(result, Some(record))
|
||||
})
|
||||
}
|
||||
|
||||
// #[allow(clippy::future_not_send)] // dyn Write is not Send
|
||||
async fn write_results<W: Write + Send>(
|
||||
output: &mut W,
|
||||
tasks: Vec<ReturnHandle>,
|
||||
) -> miette::Result<()> {
|
||||
for task in tasks {
|
||||
let (_result, record) = task.await.expect("task should not panic");
|
||||
if let Some(record) = record {
|
||||
writeln!(output, "{record}")
|
||||
.into_diagnostic()
|
||||
.context("failed to write record")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> miette::Result<()> {
|
||||
let run_id = Uuid::new_v4();
|
||||
@@ -306,23 +42,21 @@ async fn main() -> miette::Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let config = if let Some(config_path) = &args.config {
|
||||
let config: Config = if let Some(config_path) = &args.config {
|
||||
info!(config_file = %config_path.display(), "loading config from file");
|
||||
load_from_file(config_path)?
|
||||
config_path.as_path().try_into()?
|
||||
} else {
|
||||
info!("using CLI arguments");
|
||||
load_from_cli(&args)?
|
||||
args.try_into()?
|
||||
};
|
||||
|
||||
let tls_config = build_tls_config(config.server_mode())?;
|
||||
let tls_connector = TlsConnector::from(tls_config);
|
||||
|
||||
let server_name = ServerName::try_from("localhost".to_string())
|
||||
.into_diagnostic()
|
||||
.context("invalid server name")?;
|
||||
|
||||
for benchmark in &config.benchmarks {
|
||||
info!(
|
||||
proto = %benchmark.proto,
|
||||
mode = %benchmark.mode,
|
||||
payload = benchmark.payload,
|
||||
iters = benchmark.iters,
|
||||
@@ -331,6 +65,9 @@ async fn main() -> miette::Result<()> {
|
||||
"running benchmark"
|
||||
);
|
||||
|
||||
let tls_config = build_tls_config(benchmark.mode)?;
|
||||
let tls_connector = TlsConnector::from(Arc::new(tls_config));
|
||||
|
||||
run_benchmark(benchmark, &tls_connector, &server_name).await?;
|
||||
}
|
||||
|
||||
|
||||
86
runner/src/tls.rs
Normal file
86
runner/src/tls.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use common::prelude::*;
|
||||
use miette::{Context, IntoDiagnostic};
|
||||
use rustls::{
|
||||
ClientConfig, DigitallySignedStruct, SignatureScheme,
|
||||
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
|
||||
compress::CompressionCache,
|
||||
crypto::aws_lc_rs::{
|
||||
self,
|
||||
kx_group::{X25519, X25519MLKEM768},
|
||||
},
|
||||
pki_types::{CertificateDer, ServerName, UnixTime},
|
||||
version::TLS13,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Certificate verifier that accepts any certificate.
|
||||
/// Used for benchmarking where we don't need to verify the server's identity.
|
||||
#[derive(Debug)]
|
||||
pub struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
|
||||
vec![
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::ED25519,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::RSA_PKCS1_SHA256,
|
||||
SignatureScheme::RSA_PKCS1_SHA384,
|
||||
SignatureScheme::RSA_PKCS1_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Build TLS client config for the given key exchange mode.
|
||||
pub fn build_tls_config(mode: KeyExchangeMode) -> miette::Result<ClientConfig> {
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
|
||||
};
|
||||
|
||||
let mut config = ClientConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&TLS13])
|
||||
.into_diagnostic()
|
||||
.context("failed to set TLS versions")?
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
|
||||
config.cert_compression_cache = Arc::new(CompressionCache::Disabled);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
@@ -6,8 +6,12 @@ edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
base64.workspace = true
|
||||
bytes.workspace = true
|
||||
clap.workspace = true
|
||||
common.workspace = true
|
||||
http-body-util.workspace = true
|
||||
hyper-util = { workspace = true, features = ["server"] }
|
||||
hyper = { workspace = true, features = ["server"] }
|
||||
miette.workspace = true
|
||||
rustls.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type using the servers's custom error type.
|
||||
/// Result type using the `servers`'s custom error type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pub mod error;
|
||||
@@ -5,146 +5,34 @@
|
||||
//! - Responds with exactly N bytes (deterministic pattern)
|
||||
|
||||
mod error;
|
||||
mod server;
|
||||
mod tls;
|
||||
|
||||
use crate::{server::run_server, tls::build_tls_config};
|
||||
use base64::prelude::*;
|
||||
use clap::Parser;
|
||||
use common::{
|
||||
KeyExchangeMode,
|
||||
cert::{CaCertificate, ServerCertificate},
|
||||
protocol::{read_request, write_payload},
|
||||
};
|
||||
use rustls::{
|
||||
ServerConfig,
|
||||
crypto::aws_lc_rs::{
|
||||
self,
|
||||
kx_group::{X25519, X25519MLKEM768},
|
||||
},
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
server::Acceptor,
|
||||
version::TLS13,
|
||||
};
|
||||
use std::{env, io::ErrorKind, net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
use tokio_rustls::LazyConfigAcceptor;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use common::{cert::CaCertificate, prelude::*};
|
||||
use std::{env, net::SocketAddr};
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
/// TLS benchmark server.
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(name = "server", version, about)]
|
||||
struct Args {
|
||||
/// Key exchange mode.
|
||||
/// Key exchange mode
|
||||
#[arg(long, default_value = "x25519")]
|
||||
mode: KeyExchangeMode,
|
||||
|
||||
/// Address to listen on.
|
||||
/// Protocol carrier mode
|
||||
#[arg(long, default_value = "raw")]
|
||||
proto: ProtocolMode,
|
||||
|
||||
/// Address to listen on
|
||||
#[arg(long, default_value = "127.0.0.1:4433")]
|
||||
listen: SocketAddr,
|
||||
}
|
||||
|
||||
/// Build TLS server config for the given key exchange mode.
|
||||
fn build_tls_config(
|
||||
mode: KeyExchangeMode,
|
||||
server_cert: &ServerCertificate,
|
||||
) -> error::Result<Arc<ServerConfig>> {
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
|
||||
};
|
||||
|
||||
let certs = server_cert
|
||||
.cert_chain_der
|
||||
.iter()
|
||||
.map(|der| CertificateDer::from(der.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
|
||||
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
|
||||
|
||||
let config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&TLS13])
|
||||
.map_err(common::Error::Tls)?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.map_err(common::Error::Tls)?;
|
||||
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
|
||||
async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<ServerConfig>) {
|
||||
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
|
||||
let start_handshake = match acceptor.await {
|
||||
Ok(sh) => sh,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS accept error");
|
||||
}
|
||||
};
|
||||
|
||||
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS handshake error");
|
||||
}
|
||||
};
|
||||
|
||||
let (_, conn) = tls_stream.get_ref();
|
||||
info!(
|
||||
cipher = ?conn.negotiated_cipher_suite(),
|
||||
version = ?conn.protocol_version(),
|
||||
"connection established"
|
||||
);
|
||||
|
||||
loop {
|
||||
let payload_size = match read_request(&mut tls_stream).await {
|
||||
Ok(size) => size,
|
||||
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
|
||||
debug!(peer = %peer, "client disconnected");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(peer = %peer, error = %e, "connection error");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
|
||||
warn!(peer = %peer, error = %e, "write error");
|
||||
break;
|
||||
}
|
||||
|
||||
// Flush to ensure data is sent
|
||||
if let Err(e) = tls_stream.flush().await {
|
||||
warn!(peer = %peer, error = %e, "flush error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
|
||||
let listener = TcpListener::bind(args.listen)
|
||||
.await
|
||||
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
|
||||
|
||||
info!(listen = %args.listen, mode = %args.mode, "listening");
|
||||
|
||||
loop {
|
||||
let (stream, peer) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
error!(error = %e, "accept error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let config = tls_config.clone();
|
||||
tokio::spawn(handle_connection(stream, peer, config));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> miette::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
@@ -161,6 +49,7 @@ async fn main() -> miette::Result<()> {
|
||||
command = env::args().collect::<Vec<_>>().join(" "),
|
||||
listen = %args.listen,
|
||||
mode = %args.mode,
|
||||
proto = %args.proto,
|
||||
"server started"
|
||||
);
|
||||
|
||||
@@ -181,19 +70,20 @@ async fn main() -> miette::Result<()> {
|
||||
"CA cert (truncated)"
|
||||
);
|
||||
|
||||
Ok(run_server(args, tls_config).await?)
|
||||
Ok(run_server(&args, tls_config).await?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use claims::assert_ok;
|
||||
use common::cert::CaCertificate;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn default_args() {
|
||||
let args = Args::parse_from(["server"]);
|
||||
assert_eq!(args.mode, KeyExchangeMode::X25519);
|
||||
assert_eq!(args.proto, ProtocolMode::Raw);
|
||||
assert_eq!(args.listen.to_string(), "127.0.0.1:4433");
|
||||
}
|
||||
|
||||
|
||||
307
server/src/server/http1.rs
Normal file
307
server/src/server/http1.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use bytes::Bytes;
|
||||
use common::prelude::*;
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
Method, Request, Response, StatusCode,
|
||||
header::{ALLOW, CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue},
|
||||
server::conn::http1::Builder,
|
||||
service::service_fn,
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rustls::{ServerConfig, server::Acceptor};
|
||||
use std::{convert::Infallible, net::SocketAddr, sync::Arc};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::LazyConfigAcceptor;
|
||||
use tracing::{info, warn};
|
||||
|
||||
type RespBody = Full<Bytes>;
|
||||
|
||||
pub async fn handle_http1_connection(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
tls_config: Arc<ServerConfig>,
|
||||
) {
|
||||
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
|
||||
let start_handshake = match acceptor.await {
|
||||
Ok(sh) => sh,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS accept error");
|
||||
}
|
||||
};
|
||||
|
||||
let tls_stream = match start_handshake.into_stream(tls_config).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS handshake error");
|
||||
}
|
||||
};
|
||||
|
||||
let (_, conn) = tls_stream.get_ref();
|
||||
info!(
|
||||
cipher = ?conn.negotiated_cipher_suite(),
|
||||
version = ?conn.protocol_version(),
|
||||
"connection established"
|
||||
);
|
||||
|
||||
let service = service_fn(move |req| async move { Ok::<_, Infallible>(handle_request(&req)) });
|
||||
|
||||
let io = TokioIo::new(tls_stream);
|
||||
|
||||
if let Err(e) = Builder::new()
|
||||
.keep_alive(false)
|
||||
.serve_connection(io, service)
|
||||
.await
|
||||
{
|
||||
warn!(peer = %peer, error = %e, "http1 serve error");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request<B>(req: &Request<B>) -> Response<RespBody> {
|
||||
if req.method() != Method::GET {
|
||||
let mut response = text_response(StatusCode::METHOD_NOT_ALLOWED, "method not allowed");
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(ALLOW, HeaderValue::from_static("GET"));
|
||||
return response;
|
||||
}
|
||||
|
||||
let n = match parse_bytes_path(req.uri().path()) {
|
||||
Ok(n) => n,
|
||||
Err(status) => {
|
||||
let msg = match status {
|
||||
StatusCode::NOT_FOUND => "not found",
|
||||
StatusCode::PAYLOAD_TOO_LARGE => "payload too large",
|
||||
_ => "bad request",
|
||||
};
|
||||
return text_response(status, msg);
|
||||
}
|
||||
};
|
||||
|
||||
let payload = generate_payload(n);
|
||||
let mut response = Response::new(Full::new(Bytes::from(payload)));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
headers.insert(CONNECTION, HeaderValue::from_static("close"));
|
||||
|
||||
#[allow(clippy::option_if_let_else)]
|
||||
match HeaderValue::from_str(&n.to_string()) {
|
||||
Ok(v) => {
|
||||
headers.insert(CONTENT_LENGTH, v);
|
||||
response
|
||||
}
|
||||
Err(_) => text_response(StatusCode::INTERNAL_SERVER_ERROR, "internal server error"),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bytes_path(path: &str) -> Result<u64, StatusCode> {
|
||||
let Some(rest) = path.strip_prefix("/bytes/") else {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
};
|
||||
|
||||
if rest.is_empty() || rest.contains('/') {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let n = rest.parse::<u64>().map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
if n > MAX_PAYLOAD_SIZE {
|
||||
return Err(StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
fn text_response(status: StatusCode, msg: &'static str) -> Response<RespBody> {
|
||||
let mut response = Response::new(Full::new(Bytes::from_static(msg.as_bytes())));
|
||||
*response.status_mut() = status;
|
||||
response.headers_mut().insert(
|
||||
CONTENT_TYPE,
|
||||
HeaderValue::from_static("text/plain; charset=utf-8"),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONNECTION, HeaderValue::from_static("close"));
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use claims::{assert_err, assert_none, assert_ok, assert_some};
|
||||
use http_body_util::BodyExt;
|
||||
|
||||
fn make_get_request(uri: &str) -> Request<()> {
|
||||
assert_ok!(Request::builder().method(Method::GET).uri(uri).body(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_accepts_valid_numeric_size() {
|
||||
let min_n = assert_ok!(parse_bytes_path("/bytes/0"));
|
||||
assert_eq!(min_n, 0);
|
||||
|
||||
let n = assert_ok!(parse_bytes_path("/bytes/1024"));
|
||||
assert_eq!(n, 1024);
|
||||
|
||||
let max_n = assert_ok!(parse_bytes_path(&format!("/bytes/{MAX_PAYLOAD_SIZE}")));
|
||||
assert_eq!(max_n, MAX_PAYLOAD_SIZE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_rejects_non_bytes_prefix() {
|
||||
let status = assert_err!(parse_bytes_path("/foo/1024"));
|
||||
assert_eq!(status, StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_rejects_empty_size_segment() {
|
||||
let status = assert_err!(parse_bytes_path("/bytes/"));
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_rejects_non_numeric_size() {
|
||||
let status = assert_err!(parse_bytes_path("/bytes/foo"));
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_rejects_nested_segments() {
|
||||
let status = assert_err!(parse_bytes_path("/bytes/16/extra"));
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bytes_path_rejects_payload_above_max() {
|
||||
let status = assert_err!(parse_bytes_path(&format!(
|
||||
"/bytes/{}",
|
||||
MAX_PAYLOAD_SIZE + 1
|
||||
)));
|
||||
assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_bytes_returns_200_with_expected_headers() {
|
||||
let req = make_get_request("/bytes/16");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let content_type = assert_some!(resp.headers().get("content-type"));
|
||||
assert_eq!(content_type, "application/octet-stream");
|
||||
|
||||
let content_length = assert_some!(resp.headers().get("content-length"));
|
||||
assert_eq!(content_length, "16");
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
|
||||
assert_none!(resp.headers().get("allow"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_request_get_bytes_returns_body_with_requested_length() {
|
||||
let req = make_get_request("/bytes/32");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let content_length = assert_some!(resp.headers().get("content-length"));
|
||||
assert_eq!(content_length, "32");
|
||||
|
||||
let body = assert_ok!(resp.into_body().collect().await).to_bytes();
|
||||
assert_eq!(body.len(), 32);
|
||||
|
||||
assert_eq!(body[0], 0x00);
|
||||
assert_eq!(body[1], 0x01);
|
||||
assert_eq!(body[31], 0x1F);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_unknown_path_returns_404() {
|
||||
let req = make_get_request("/unknown");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let content_type = assert_some!(resp.headers().get("content-type"));
|
||||
assert_eq!(content_type, "text/plain; charset=utf-8");
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_post_bytes_returns_405_and_allow_get() {
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri("/bytes/32")
|
||||
.body(())
|
||||
.expect("post request");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
|
||||
|
||||
let allow = assert_some!(resp.headers().get("allow"));
|
||||
assert_eq!(allow, "GET");
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_bytes_without_size_segment_returns_404() {
|
||||
let req = make_get_request("/bytes");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
assert_none!(resp.headers().get("content-length"));
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_bytes_with_non_numeric_size_returns_400() {
|
||||
let req = make_get_request("/bytes/foo");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
assert_none!(resp.headers().get("content-length"));
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_bytes_with_nested_path_returns_400() {
|
||||
let req = make_get_request("/bytes/16/extra");
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
assert_none!(resp.headers().get("content-length"));
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_request_get_bytes_exceeding_max_payload_returns_413() {
|
||||
let req = make_get_request(&format!("/bytes/{}", MAX_PAYLOAD_SIZE + 1));
|
||||
|
||||
let resp = handle_request(&req);
|
||||
assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
|
||||
assert_none!(resp.headers().get("content-length"));
|
||||
|
||||
let connection = assert_some!(resp.headers().get("connection"));
|
||||
assert_eq!(connection, "close");
|
||||
}
|
||||
}
|
||||
45
server/src/server/mod.rs
Normal file
45
server/src/server/mod.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
mod http1;
|
||||
mod raw;
|
||||
|
||||
use crate::{
|
||||
Args,
|
||||
error::{Error as ServerError, Result as ServerResult},
|
||||
server::{http1::handle_http1_connection, raw::handle_raw_connection},
|
||||
};
|
||||
use common::prelude::*;
|
||||
use rustls::ServerConfig;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub async fn run_server(args: &Args, tls_config: Arc<ServerConfig>) -> ServerResult<()> {
|
||||
let listener = TcpListener::bind(args.listen)
|
||||
.await
|
||||
.map_err(|e| ServerError::network(format!("failed to bind to {}: {e}", args.listen)))?;
|
||||
|
||||
info!(
|
||||
listen = %args.listen,
|
||||
mode = %args.mode,
|
||||
proto = %args.proto,
|
||||
"listening"
|
||||
);
|
||||
|
||||
loop {
|
||||
let (stream, peer) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
error!(error = %e, "accept error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let config = tls_config.clone();
|
||||
let proto = args.proto;
|
||||
tokio::spawn(async move {
|
||||
match proto {
|
||||
ProtocolMode::Raw => handle_raw_connection(stream, peer, config).await,
|
||||
ProtocolMode::Http1 => handle_http1_connection(stream, peer, config).await,
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
59
server/src/server/raw.rs
Normal file
59
server/src/server/raw.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use common::prelude::*;
|
||||
use rustls::{ServerConfig, server::Acceptor};
|
||||
use std::{io::ErrorKind, net::SocketAddr, sync::Arc};
|
||||
use tokio::{io::AsyncWriteExt, net::TcpStream};
|
||||
use tokio_rustls::LazyConfigAcceptor;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
pub async fn handle_raw_connection(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
tls_config: Arc<ServerConfig>,
|
||||
) {
|
||||
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
|
||||
let start_handshake = match acceptor.await {
|
||||
Ok(sh) => sh,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS accept error");
|
||||
}
|
||||
};
|
||||
|
||||
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return warn!(peer = %peer, error = %e, "TLS handshake error");
|
||||
}
|
||||
};
|
||||
|
||||
let (_, conn) = tls_stream.get_ref();
|
||||
info!(
|
||||
cipher = ?conn.negotiated_cipher_suite(),
|
||||
version = ?conn.protocol_version(),
|
||||
"connection established"
|
||||
);
|
||||
|
||||
loop {
|
||||
let payload_size = match read_request(&mut tls_stream).await {
|
||||
Ok(size) => size,
|
||||
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
|
||||
debug!(peer = %peer, "client disconnected");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(peer = %peer, error = %e, "connection error");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
|
||||
warn!(peer = %peer, error = %e, "write error");
|
||||
break;
|
||||
}
|
||||
|
||||
// Flush to ensure data is sent
|
||||
if let Err(e) = tls_stream.flush().await {
|
||||
warn!(peer = %peer, error = %e, "flush error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
42
server/src/tls.rs
Normal file
42
server/src/tls.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use crate::error;
|
||||
use common::{cert::ServerCertificate, prelude::*};
|
||||
use rustls::{
|
||||
ServerConfig,
|
||||
crypto::aws_lc_rs::{
|
||||
self,
|
||||
kx_group::{X25519, X25519MLKEM768},
|
||||
},
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
version::TLS13,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build TLS server config for the given key exchange mode.
|
||||
pub fn build_tls_config(
|
||||
mode: KeyExchangeMode,
|
||||
server_cert: &ServerCertificate,
|
||||
) -> error::Result<Arc<ServerConfig>> {
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
|
||||
};
|
||||
|
||||
let certs = server_cert
|
||||
.cert_chain_der
|
||||
.iter()
|
||||
.map(|der| CertificateDer::from(der.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
|
||||
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
|
||||
|
||||
let config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&TLS13])
|
||||
.map_err(common::Error::Tls)?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.map_err(common::Error::Tls)?;
|
||||
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
Reference in New Issue
Block a user