mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-22 00:36:21 +00:00
feat(server): add structured logging with tracing
This commit is contained in:
@@ -21,12 +21,14 @@ use rustls::{
|
||||
server::Acceptor,
|
||||
version::TLS13,
|
||||
};
|
||||
use std::{fmt::Write, io::ErrorKind, net::SocketAddr, sync::Arc};
|
||||
use std::{env, fmt::Write, io::ErrorKind, net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
use tokio_rustls::LazyConfigAcceptor;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
/// TLS benchmark server.
|
||||
#[derive(Debug, Parser)]
|
||||
@@ -46,21 +48,18 @@ fn build_tls_config(
|
||||
mode: KeyExchangeMode,
|
||||
server_cert: &ServerCertificate,
|
||||
) -> miette::Result<Arc<ServerConfig>> {
|
||||
// Select crypto provider with appropriate key exchange groups
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
|
||||
};
|
||||
|
||||
// Convert certificate chain
|
||||
let certs: Vec<CertificateDer<'static>> = server_cert
|
||||
let certs = server_cert
|
||||
.cert_chain_der
|
||||
.iter()
|
||||
.map(|der| CertificateDer::from(der.clone()))
|
||||
.collect();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Convert private key
|
||||
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
|
||||
.map_err(|e| miette!("invalid private key: {e}"))?;
|
||||
|
||||
@@ -75,12 +74,11 @@ fn build_tls_config(
|
||||
}
|
||||
|
||||
async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<ServerConfig>) {
|
||||
// Perform TLS handshake
|
||||
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
|
||||
let start_handshake = match acceptor.await {
|
||||
Ok(sh) => sh,
|
||||
Err(e) => {
|
||||
eprintln!("[{peer}] TLS accept error: {e}");
|
||||
warn!(peer = %peer, error = %e, "TLS accept error");
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -88,32 +86,32 @@ async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<
|
||||
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
eprintln!("[{peer}] TLS handshake error: {e}");
|
||||
warn!(peer = %peer, error = %e, "TLS handshake error");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Handle protocol
|
||||
loop {
|
||||
let payload_size = match read_request(&mut tls_stream).await {
|
||||
Ok(size) => size,
|
||||
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
|
||||
debug!(peer = %peer, "client disconnected");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[{peer}] read error: {e}");
|
||||
warn!(peer = %peer, error = %e, "connection error");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
|
||||
eprintln!("[{peer}] write error: {e}");
|
||||
warn!(peer = %peer, error = %e, "write error");
|
||||
break;
|
||||
}
|
||||
|
||||
// Flush to ensure data is sent
|
||||
if let Err(e) = tls_stream.flush().await {
|
||||
eprintln!("[{peer}] flush error: {e}");
|
||||
warn!(peer = %peer, error = %e, "flush error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -124,16 +122,13 @@ async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> miette::Result
|
||||
.await
|
||||
.map_err(|e| miette!("failed to bind to {}: {e}", args.listen))?;
|
||||
|
||||
eprintln!(
|
||||
"Listening on {} (TLS 1.3, mode: {})",
|
||||
args.listen, args.mode
|
||||
);
|
||||
info!(listen = %args.listen, mode = %args.mode, "listening");
|
||||
|
||||
loop {
|
||||
let (stream, peer) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
eprintln!("accept error: {e}");
|
||||
error!(error = %e, "accept error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -145,27 +140,38 @@ async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> miette::Result
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> miette::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
eprintln!("server configuration:");
|
||||
eprintln!(" mode: {}", args.mode);
|
||||
eprintln!(" listen: {}", args.listen);
|
||||
eprintln!();
|
||||
info!(
|
||||
rust_version = env!("RUSTC_VERSION"),
|
||||
os = env::consts::OS,
|
||||
arch = env::consts::ARCH,
|
||||
command = env::args().collect::<Vec<_>>().join(" "),
|
||||
listen = %args.listen,
|
||||
mode = %args.mode,
|
||||
"server started"
|
||||
);
|
||||
|
||||
// Generate certificates
|
||||
eprintln!("Generating self-signed certificates...");
|
||||
info!("Generating self-signed certificates...");
|
||||
let ca = CaCertificate::generate().map_err(|e| miette!("failed to generate CA: {e}"))?;
|
||||
let server_cert = ca
|
||||
.sign_server_cert("localhost")
|
||||
.map_err(|e| miette!("failed to generate server cert: {e}"))?;
|
||||
|
||||
// Build TLS config
|
||||
let tls_config = build_tls_config(args.mode, &server_cert)?;
|
||||
|
||||
// Print CA certificate for client configuration
|
||||
eprintln!("CA certificate (base64 DER):");
|
||||
eprintln!("{}", base64_encode(&ca.cert_der));
|
||||
eprintln!();
|
||||
info!(
|
||||
ca_cert_base64 = base64_encode(&ca.cert_der)
|
||||
.lines()
|
||||
.take(3)
|
||||
.collect::<String>(),
|
||||
"CA cert (truncated)"
|
||||
);
|
||||
|
||||
run_server(args, tls_config).await
|
||||
}
|
||||
@@ -176,7 +182,7 @@ fn base64_encode(data: &[u8]) -> String {
|
||||
|
||||
let mut result = String::new();
|
||||
for chunk in data.chunks(3) {
|
||||
let mut n = 0u32;
|
||||
let mut n = 0;
|
||||
for (i, &byte) in chunk.iter().enumerate() {
|
||||
n |= u32::from(byte) << (16 - 8 * i);
|
||||
}
|
||||
@@ -191,7 +197,6 @@ fn base64_encode(data: &[u8]) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap at 76 characters
|
||||
let mut wrapped = String::new();
|
||||
for (i, c) in result.chars().enumerate() {
|
||||
if i > 0 && i % 76 == 0 {
|
||||
|
||||
Reference in New Issue
Block a user