refactor(server): separate into smaller modules

This commit is contained in:
2026-02-25 18:39:33 +02:00
parent 99f2e0bb72
commit 8f0d2a6efc
4 changed files with 129 additions and 123 deletions

View File

@@ -1 +0,0 @@
pub mod error;

View File

@@ -5,31 +5,15 @@
//! - Responds with exactly N bytes (deterministic pattern)
mod error;
mod server;
mod tls;
use crate::{server::run_server, tls::build_tls_config};
use base64::prelude::*;
use clap::Parser;
use common::{
KeyExchangeMode,
cert::{CaCertificate, ServerCertificate},
protocol::{read_request, write_payload},
};
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
self,
kx_group::{X25519, X25519MLKEM768},
},
pki_types::{CertificateDer, PrivateKeyDer},
server::Acceptor,
version::TLS13,
};
use std::{env, io::ErrorKind, net::SocketAddr, sync::Arc};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
};
use tokio_rustls::LazyConfigAcceptor;
use tracing::{debug, error, info, warn};
use common::{KeyExchangeMode, cert::CaCertificate};
use std::{env, net::SocketAddr};
use tracing::{error, info};
use tracing_subscriber::EnvFilter;
/// TLS benchmark server.
@@ -45,106 +29,6 @@ struct Args {
listen: SocketAddr,
}
/// Build TLS server config for the given key exchange mode.
fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
};
let certs = server_cert
.cert_chain_der
.iter()
.map(|der| CertificateDer::from(der.clone()))
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}
async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<ServerConfig>) {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
let start_handshake = match acceptor.await {
Ok(sh) => sh,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS accept error");
}
};
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
Ok(s) => s,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS handshake error");
}
};
let (_, conn) = tls_stream.get_ref();
info!(
cipher = ?conn.negotiated_cipher_suite(),
version = ?conn.protocol_version(),
"connection established"
);
loop {
let payload_size = match read_request(&mut tls_stream).await {
Ok(size) => size,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
debug!(peer = %peer, "client disconnected");
break;
}
Err(e) => {
warn!(peer = %peer, error = %e, "connection error");
break;
}
};
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
warn!(peer = %peer, error = %e, "write error");
break;
}
// Flush to ensure data is sent
if let Err(e) = tls_stream.flush().await {
warn!(peer = %peer, error = %e, "flush error");
break;
}
}
}
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(listen = %args.listen, mode = %args.mode, "listening");
loop {
let (stream, peer) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!(error = %e, "accept error");
continue;
}
};
let config = tls_config.clone();
tokio::spawn(handle_connection(stream, peer, config));
}
}
#[tokio::main]
async fn main() -> miette::Result<()> {
tracing_subscriber::fmt()
@@ -189,6 +73,7 @@ mod tests {
use super::*;
use claims::assert_ok;
use common::cert::CaCertificate;
use std::sync::Arc;
#[test]
fn default_args() {

80
server/src/server.rs Normal file
View File

@@ -0,0 +1,80 @@
use crate::{Args, error};
use common::protocol::{read_request, write_payload};
use rustls::{ServerConfig, server::Acceptor};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, TcpStream},
};
use tokio_rustls::LazyConfigAcceptor;
use tracing::{debug, info, warn};
pub async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<ServerConfig>) {
let acceptor = LazyConfigAcceptor::new(Acceptor::default(), stream);
let start_handshake = match acceptor.await {
Ok(sh) => sh,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS accept error");
}
};
let mut tls_stream = match start_handshake.into_stream(tls_config).await {
Ok(s) => s,
Err(e) => {
return warn!(peer = %peer, error = %e, "TLS handshake error");
}
};
let (_, conn) = tls_stream.get_ref();
info!(
cipher = ?conn.negotiated_cipher_suite(),
version = ?conn.protocol_version(),
"connection established"
);
loop {
let payload_size = match read_request(&mut tls_stream).await {
Ok(size) => size,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => {
debug!(peer = %peer, "client disconnected");
break;
}
Err(e) => {
warn!(peer = %peer, error = %e, "connection error");
break;
}
};
if let Err(e) = write_payload(&mut tls_stream, payload_size).await {
warn!(peer = %peer, error = %e, "write error");
break;
}
// Flush to ensure data is sent
if let Err(e) = tls_stream.flush().await {
warn!(peer = %peer, error = %e, "flush error");
break;
}
}
}
pub async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(listen = %args.listen, mode = %args.mode, "listening");
loop {
let (stream, peer) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!(error = %e, "accept error");
continue;
}
};
let config = tls_config.clone();
tokio::spawn(handle_connection(stream, peer, config));
}
}

42
server/src/tls.rs Normal file
View File

@@ -0,0 +1,42 @@
use crate::error;
use common::{KeyExchangeMode, cert::ServerCertificate};
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
self,
kx_group::{X25519, X25519MLKEM768},
},
pki_types::{CertificateDer, PrivateKeyDer},
version::TLS13,
};
use std::sync::Arc;
/// Build TLS server config for the given key exchange mode.
pub fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
KeyExchangeMode::X25519Mlkem768 => vec![X25519MLKEM768],
};
let certs = server_cert
.cert_chain_der
.iter()
.map(|der| CertificateDer::from(der.clone()))
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}