refactor(server,common): introduce custom error types with thiserror and miette

This commit is contained in:
2026-02-11 16:29:59 +02:00
parent cda6024062
commit 818cfd5598
10 changed files with 134 additions and 61 deletions

View File

@@ -10,10 +10,11 @@ clap.workspace = true
common.workspace = true
miette.workspace = true
rustls.workspace = true
thiserror.workspace = true
tokio-rustls.workspace = true
tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
uuid.workspace = true
[lints]

38
server/src/error.rs Normal file
View File

@@ -0,0 +1,38 @@
use miette::Diagnostic;
use thiserror::Error;
/// Result type using the servers's custom error type.
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Error, Diagnostic)]
pub enum Error {
#[error(transparent)]
#[diagnostic(code(runner::common_error))]
Common(#[from] common::Error),
#[error(transparent)]
#[diagnostic(code(server::cert_validation_error))]
CertValidation(#[from] rustls::pki_types::InvalidDnsNameError),
/// Network connection failure.
#[error("Network error: {0}")]
#[diagnostic(code(server::network_error))]
Network(String),
#[error("Invalid certificate: {0}")]
#[diagnostic(code(server::invalid_cert))]
InvalidCert(String),
}
impl Error {
/// Create a network error.
#[inline]
pub fn network(error: impl Into<String>) -> Self {
Self::Network(error.into())
}
#[inline]
pub fn invalid_cert(error: impl Into<String>) -> Self {
Self::InvalidCert(error.into())
}
}

1
server/src/lib.rs Normal file
View File

@@ -0,0 +1 @@
pub mod error;

View File

@@ -4,6 +4,8 @@
//! - Reads 8-byte little-endian u64 (requested payload size N)
//! - Responds with exactly N bytes (deterministic pattern)
mod error;
use base64::prelude::*;
use clap::Parser;
use common::{
@@ -11,7 +13,6 @@ use common::{
cert::{CaCertificate, ServerCertificate},
protocol::{read_request, write_payload},
};
use miette::miette;
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
@@ -48,7 +49,7 @@ struct Args {
fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> miette::Result<Arc<ServerConfig>> {
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
@@ -62,14 +63,14 @@ fn build_tls_config(
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| miette!("invalid private key: {e}"))?;
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(|e| miette!("failed to set TLS versions: {e}"))?
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| miette!("failed to configure certificate: {e}"))?;
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}
@@ -123,10 +124,10 @@ async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<
}
}
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> miette::Result<()> {
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| miette!("failed to bind to {}: {e}", args.listen))?;
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(listen = %args.listen, mode = %args.mode, "listening");
@@ -164,10 +165,10 @@ async fn main() -> miette::Result<()> {
);
info!("Generating self-signed certificates...");
let ca = CaCertificate::generate().map_err(|e| miette!("failed to generate CA: {e}"))?;
let ca = CaCertificate::generate().map_err(common::Error::RCGen)?;
let server_cert = ca
.sign_server_cert("localhost")
.map_err(|e| miette!("failed to generate server cert: {e}"))?;
.map_err(common::Error::RCGen)?;
let tls_config = build_tls_config(args.mode, &server_cert)?;
@@ -180,5 +181,5 @@ async fn main() -> miette::Result<()> {
"CA cert (truncated)"
);
run_server(args, tls_config).await
Ok(run_server(args, tls_config).await?)
}