From 8f0d2a6efcbed309a8989d96021f5c4115b06dea Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Wed, 25 Feb 2026 18:39:33 +0200 Subject: [PATCH] refactor(server): separate into smaller modules --- server/src/lib.rs | 1 - server/src/main.rs | 129 +++---------------------------------------- server/src/server.rs | 80 +++++++++++++++++++++++++++ server/src/tls.rs | 42 ++++++++++++++ 4 files changed, 129 insertions(+), 123 deletions(-) delete mode 100644 server/src/lib.rs create mode 100644 server/src/server.rs create mode 100644 server/src/tls.rs diff --git a/server/src/lib.rs b/server/src/lib.rs deleted file mode 100644 index a91e735..0000000 --- a/server/src/lib.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod error; diff --git a/server/src/main.rs b/server/src/main.rs index e4c18b1..806cf13 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -5,31 +5,15 @@ //! - Responds with exactly N bytes (deterministic pattern) mod error; +mod server; +mod tls; +use crate::{server::run_server, tls::build_tls_config}; use base64::prelude::*; use clap::Parser; -use common::{ - KeyExchangeMode, - cert::{CaCertificate, ServerCertificate}, - protocol::{read_request, write_payload}, -}; -use rustls::{ - ServerConfig, - crypto::aws_lc_rs::{ - self, - kx_group::{X25519, X25519MLKEM768}, - }, - pki_types::{CertificateDer, PrivateKeyDer}, - server::Acceptor, - version::TLS13, -}; -use std::{env, io::ErrorKind, net::SocketAddr, sync::Arc}; -use tokio::{ - io::AsyncWriteExt, - net::{TcpListener, TcpStream}, -}; -use tokio_rustls::LazyConfigAcceptor; -use tracing::{debug, error, info, warn}; +use common::{KeyExchangeMode, cert::CaCertificate}; +use std::{env, net::SocketAddr}; +use tracing::{error, info}; use tracing_subscriber::EnvFilter; /// TLS benchmark server. @@ -45,106 +29,6 @@ struct Args { listen: SocketAddr, } -/// Build TLS server config for the given key exchange mode. -fn build_tls_config( - mode: KeyExchangeMode, - server_cert: &ServerCertificate, -) -> error::Result> { - let mut provider = aws_lc_rs::default_provider(); - provider.kx_groups = match mode { - KeyExchangeMode::X25519 => vec![X25519], - KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768], - }; - - let certs = server_cert - .cert_chain_der - .iter() - .map(|der| CertificateDer::from(der.clone())) - .collect::>(); - - let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone()) - .map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?; - - let config = ServerConfig::builder_with_provider(Arc::new(provider)) - .with_protocol_versions(&[&TLS13]) - .map_err(common::Error::Tls)? - .with_no_client_auth() - .with_single_cert(certs, key) - .map_err(common::Error::Tls)?; - - Ok(Arc::new(config)) -} - -async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc) { - let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream); - let start_handshake = match acceptor.await { - Ok(sh) => sh, - Err(e) => { - return warn!(peer = %peer, error = %e, "TLS accept error"); - } - }; - - let mut tls_stream = match start_handshake.into_stream(tls_config).await { - Ok(s) => s, - Err(e) => { - return warn!(peer = %peer, error = %e, "TLS handshake error"); - } - }; - - let (_, conn) = tls_stream.get_ref(); - info!( - cipher = ?conn.negotiated_cipher_suite(), - version = ?conn.protocol_version(), - "connection established" - ); - - loop { - let payload_size = match read_request(&mut tls_stream).await { - Ok(size) => size, - Err(e) if e.kind() == ErrorKind::UnexpectedEof => { - debug!(peer = %peer, "client disconnected"); - break; - } - Err(e) => { - warn!(peer = %peer, error = %e, "connection error"); - break; - } - }; - - if let Err(e) = write_payload(&mut tls_stream, payload_size).await { - warn!(peer = %peer, error = %e, "write error"); - break; - } - - // Flush to ensure data is sent - if let Err(e) = tls_stream.flush().await { - warn!(peer = %peer, error = %e, "flush error"); - break; - } - } -} - -async fn run_server(args: Args, tls_config: Arc) -> error::Result<()> { - let listener = TcpListener::bind(args.listen) - .await - .map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?; - - info!(listen = %args.listen, mode = %args.mode, "listening"); - - loop { - let (stream, peer) = match listener.accept().await { - Ok(conn) => conn, - Err(e) => { - error!(error = %e, "accept error"); - continue; - } - }; - - let config = tls_config.clone(); - tokio::spawn(handle_connection(stream, peer, config)); - } -} - #[tokio::main] async fn main() -> miette::Result<()> { tracing_subscriber::fmt() @@ -189,6 +73,7 @@ mod tests { use super::*; use claims::assert_ok; use common::cert::CaCertificate; + use std::sync::Arc; #[test] fn default_args() { diff --git a/server/src/server.rs b/server/src/server.rs new file mode 100644 index 0000000..8d3631b --- /dev/null +++ b/server/src/server.rs @@ -0,0 +1,80 @@ +use crate::{Args, error}; +use common::protocol::{read_request, write_payload}; +use rustls::{ServerConfig, server::Acceptor}; +use std::{io::ErrorKind, net::SocketAddr, sync::Arc}; +use tokio::{ + io::AsyncWriteExt, + net::{TcpListener, TcpStream}, +}; +use tokio_rustls::LazyConfigAcceptor; +use tracing::{debug, info, warn}; + +pub async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc) { + let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream); + let start_handshake = match acceptor.await { + Ok(sh) => sh, + Err(e) => { + return warn!(peer = %peer, error = %e, "TLS accept error"); + } + }; + + let mut tls_stream = match start_handshake.into_stream(tls_config).await { + Ok(s) => s, + Err(e) => { + return warn!(peer = %peer, error = %e, "TLS handshake error"); + } + }; + + let (_, conn) = tls_stream.get_ref(); + info!( + cipher = ?conn.negotiated_cipher_suite(), + version = ?conn.protocol_version(), + "connection established" + ); + + loop { + let payload_size = match read_request(&mut tls_stream).await { + Ok(size) => size, + Err(e) if e.kind() == ErrorKind::UnexpectedEof => { + debug!(peer = %peer, "client disconnected"); + break; + } + Err(e) => { + warn!(peer = %peer, error = %e, "connection error"); + break; + } + }; + + if let Err(e) = write_payload(&mut tls_stream, payload_size).await { + warn!(peer = %peer, error = %e, "write error"); + break; + } + + // Flush to ensure data is sent + if let Err(e) = tls_stream.flush().await { + warn!(peer = %peer, error = %e, "flush error"); + break; + } + } +} + +pub async fn run_server(args: Args, tls_config: Arc) -> error::Result<()> { + let listener = TcpListener::bind(args.listen) + .await + .map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?; + + info!(listen = %args.listen, mode = %args.mode, "listening"); + + loop { + let (stream, peer) = match listener.accept().await { + Ok(conn) => conn, + Err(e) => { + error!(error = %e, "accept error"); + continue; + } + }; + + let config = tls_config.clone(); + tokio::spawn(handle_connection(stream, peer, config)); + } +} diff --git a/server/src/tls.rs b/server/src/tls.rs new file mode 100644 index 0000000..5f552aa --- /dev/null +++ b/server/src/tls.rs @@ -0,0 +1,42 @@ +use crate::error; +use common::{KeyExchangeMode, cert::ServerCertificate}; +use rustls::{ + ServerConfig, + crypto::aws_lc_rs::{ + self, + kx_group::{X25519, X25519MLKEM768}, + }, + pki_types::{CertificateDer, PrivateKeyDer}, + version::TLS13, +}; +use std::sync::Arc; + +/// Build TLS server config for the given key exchange mode. +pub fn build_tls_config( + mode: KeyExchangeMode, + server_cert: &ServerCertificate, +) -> error::Result> { + let mut provider = aws_lc_rs::default_provider(); + provider.kx_groups = match mode { + KeyExchangeMode::X25519 => vec![X25519], + KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768], + }; + + let certs = server_cert + .cert_chain_der + .iter() + .map(|der| CertificateDer::from(der.clone())) + .collect::>(); + + let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone()) + .map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?; + + let config = ServerConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&TLS13]) + .map_err(common::Error::Tls)? + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(common::Error::Tls)?; + + Ok(Arc::new(config)) +}