From 818cfd559849f39fda22541aa768ba1ab78731f7 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Wed, 11 Feb 2026 16:29:59 +0200 Subject: [PATCH] refactor(server,common): introduce custom error types with thiserror and miette --- Cargo.lock | 5 ++++ common/Cargo.toml | 4 +++ common/src/error.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++ common/src/lib.rs | 2 ++ runner/src/config.rs | 8 +++--- runner/src/error.rs | 48 ++------------------------------ server/Cargo.toml | 3 +- server/src/error.rs | 38 ++++++++++++++++++++++++++ server/src/lib.rs | 1 + server/src/main.rs | 21 +++++++------- 10 files changed, 134 insertions(+), 61 deletions(-) create mode 100644 common/src/error.rs create mode 100644 server/src/error.rs create mode 100644 server/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 26d18bf..90ff8ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -275,11 +275,15 @@ name = "common" version = "0.1.0" dependencies = [ "cargo-husky", + "miette", "rcgen", + "rustls", "serde", "serde_json", "strum", + "thiserror", "tokio", + "toml", ] [[package]] @@ -915,6 +919,7 @@ dependencies = [ "common", "miette", "rustls", + "thiserror", "tokio", "tokio-rustls", "tracing", diff --git a/common/Cargo.toml b/common/Cargo.toml index 87af3e5..c7ac8b0 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -5,11 +5,15 @@ authors.workspace = true edition.workspace = true [dependencies] +miette.workspace = true rcgen.workspace = true +rustls.workspace = true serde.workspace = true serde_json.workspace = true strum.workspace = true +thiserror.workspace = true tokio.workspace = true +toml.workspace = true [dev-dependencies] cargo-husky.workspace = true diff --git a/common/src/error.rs b/common/src/error.rs new file mode 100644 index 0000000..6b30841 --- /dev/null +++ b/common/src/error.rs @@ -0,0 +1,65 @@ +use miette::Diagnostic; +use thiserror::Error; + +/// Result type using the common's custom error type. +pub type Result = std::result::Result; + +#[derive(Debug, Error, Diagnostic)] +pub enum Error { + /// File or network I/O error. + #[error(transparent)] + #[diagnostic(code(common::io_error))] + Io(#[from] std::io::Error), + + #[error(transparent)] + #[diagnostic(code(common::rustls_error))] + Tls(#[from] rustls::Error), + + /// TOML configuration file parse error. + #[error(transparent)] + #[diagnostic(code(common::toml_error))] + Toml(#[from] toml::de::Error), + + #[error(transparent)] + #[diagnostic(code(common::json_error))] + Json(#[from] serde_json::Error), + + #[error(transparent)] + #[diagnostic(code(common::rcgen_error))] + RCGen(#[from] rcgen::Error), + + /// Configuration validation or missing required fields. + #[error("Config error: {0}")] + #[diagnostic(code(common::config_error))] + Config(String), + + /// Invalid key exchange mode string. + #[error("Invalid mode: {0}")] + #[diagnostic(code(common::invalid_mode))] + InvalidMode(String), + + /// Protocol-level error (malformed requests, unexpected responses). + #[error("Protocol error: {0}")] + #[diagnostic(code(common::protocol_error))] + Protocol(String), +} + +impl Error { + /// Create an invalid mode error. + #[inline] + pub fn invalid_mode(error: impl Into) -> Self { + Self::InvalidMode(error.into()) + } + + /// Create a config error. + #[inline] + pub fn config(error: impl Into) -> Self { + Self::Config(error.into()) + } + + /// Create a protocol error. + #[inline] + pub fn protocol(error: impl Into) -> Self { + Self::Protocol(error.into()) + } +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 50e627f..d5faf15 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,8 +1,10 @@ //! Common types and utilities for the TLS benchmark harness. pub mod cert; +pub mod error; pub mod protocol; +pub use error::Error; use serde::{Deserialize, Serialize}; use std::fmt; use strum::{Display, EnumString}; diff --git a/runner/src/config.rs b/runner/src/config.rs index 174f4bd..9c8d0f8 100644 --- a/runner/src/config.rs +++ b/runner/src/config.rs @@ -1,5 +1,5 @@ use crate::{args::Args, error}; -use common::KeyExchangeMode; +use common::{self, KeyExchangeMode}; use serde::Deserialize; use std::{fs::read_to_string, net::SocketAddr, path::PathBuf}; @@ -23,8 +23,8 @@ pub struct Config { /// # Errors /// Returns an error if the file cannot be read or parsed. pub fn load_from_file(path: &PathBuf) -> error::Result { - let content = read_to_string(path).map_err(error::Error::Io)?; - let config = toml::from_str::(&content).map_err(error::Error::Toml)?; + let content = read_to_string(path).map_err(common::Error::Io)?; + let config = toml::from_str::(&content).map_err(common::Error::Toml)?; Ok(config) } @@ -42,7 +42,7 @@ pub fn load_from_cli(args: &Args) -> error::Result { concurrency: args.concurrency, server: args .server - .ok_or_else(|| error::Error::config("--server ir required"))?, + .ok_or_else(|| common::Error::config("--server ir required"))?, }], }) } diff --git a/runner/src/error.rs b/runner/src/error.rs index 486513a..72ce337 100644 --- a/runner/src/error.rs +++ b/runner/src/error.rs @@ -9,64 +9,20 @@ pub type Result = std::result::Result; /// Errors that can occur during benchmark execution. #[derive(Debug, Error, Diagnostic)] pub enum Error { - /// TLS configuration or handshake failure. #[error(transparent)] - #[diagnostic(code(runner::tls_error))] - TlsConfig(#[from] rustls::Error), - - /// File or network I/O error. - #[error(transparent)] - #[diagnostic(code(runner::io_error))] - Io(#[from] std::io::Error), - - /// TOML configuration file parse error. - #[error(transparent)] - #[diagnostic(code(runner::toml_error))] - Toml(#[from] toml::de::Error), - - /// Invalid key exchange mode string. - #[error("Invalid mode: {0}")] - #[diagnostic(code(runner::invalid_mode))] - InvalidMode(String), - - /// Configuration validation or missing required fields. - #[error("Config error: {0}")] - #[diagnostic(code(runner::config_error))] - Config(String), + #[diagnostic(code(runner::common_error))] + Common(#[from] common::Error), /// Network connection failure. #[error("Network error: {0}")] #[diagnostic(code(runner::network_error))] Network(String), - - /// Protocol-level error (malformed requests, unexpected responses). - #[error("Protocol error: {0}")] - #[diagnostic(code(runner::protocol_error))] - Protocol(String), } impl Error { - /// Create an invalid mode error. - #[inline] - pub fn invalid_mode(error: impl Into) -> Self { - Self::InvalidMode(error.into()) - } - - /// Create a config error. - #[inline] - pub fn config(error: impl Into) -> Self { - Self::Config(error.into()) - } - /// Create a network error. #[inline] pub fn network(error: impl Into) -> Self { Self::Network(error.into()) } - - /// Create a protocol error. - #[inline] - pub fn protocol(error: impl Into) -> Self { - Self::Protocol(error.into()) - } } diff --git a/server/Cargo.toml b/server/Cargo.toml index a02630a..fd4e851 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,10 +10,11 @@ clap.workspace = true common.workspace = true miette.workspace = true rustls.workspace = true +thiserror.workspace = true tokio-rustls.workspace = true tokio.workspace = true -tracing.workspace = true tracing-subscriber.workspace = true +tracing.workspace = true uuid.workspace = true [lints] diff --git a/server/src/error.rs b/server/src/error.rs new file mode 100644 index 0000000..217940d --- /dev/null +++ b/server/src/error.rs @@ -0,0 +1,38 @@ +use miette::Diagnostic; +use thiserror::Error; + +/// Result type using the servers's custom error type. +pub type Result = std::result::Result; + +#[derive(Debug, Error, Diagnostic)] +pub enum Error { + #[error(transparent)] + #[diagnostic(code(runner::common_error))] + Common(#[from] common::Error), + + #[error(transparent)] + #[diagnostic(code(server::cert_validation_error))] + CertValidation(#[from] rustls::pki_types::InvalidDnsNameError), + + /// Network connection failure. + #[error("Network error: {0}")] + #[diagnostic(code(server::network_error))] + Network(String), + + #[error("Invalid certificate: {0}")] + #[diagnostic(code(server::invalid_cert))] + InvalidCert(String), +} + +impl Error { + /// Create a network error. + #[inline] + pub fn network(error: impl Into) -> Self { + Self::Network(error.into()) + } + + #[inline] + pub fn invalid_cert(error: impl Into) -> Self { + Self::InvalidCert(error.into()) + } +} diff --git a/server/src/lib.rs b/server/src/lib.rs new file mode 100644 index 0000000..a91e735 --- /dev/null +++ b/server/src/lib.rs @@ -0,0 +1 @@ +pub mod error; diff --git a/server/src/main.rs b/server/src/main.rs index c28e854..675b84e 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,6 +4,8 @@ //! - Reads 8-byte little-endian u64 (requested payload size N) //! - Responds with exactly N bytes (deterministic pattern) +mod error; + use base64::prelude::*; use clap::Parser; use common::{ @@ -11,7 +13,6 @@ use common::{ cert::{CaCertificate, ServerCertificate}, protocol::{read_request, write_payload}, }; -use miette::miette; use rustls::{ ServerConfig, crypto::aws_lc_rs::{ @@ -48,7 +49,7 @@ struct Args { fn build_tls_config( mode: KeyExchangeMode, server_cert: &ServerCertificate, -) -> miette::Result> { +) -> error::Result> { let mut provider = aws_lc_rs::default_provider(); provider.kx_groups = match mode { KeyExchangeMode::X25519 => vec![X25519], @@ -62,14 +63,14 @@ fn build_tls_config( .collect::>(); let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone()) - .map_err(|e| miette!("invalid private key: {e}"))?; + .map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?; let config = ServerConfig::builder_with_provider(Arc::new(provider)) .with_protocol_versions(&[&TLS13]) - .map_err(|e| miette!("failed to set TLS versions: {e}"))? + .map_err(common::Error::Tls)? .with_no_client_auth() .with_single_cert(certs, key) - .map_err(|e| miette!("failed to configure certificate: {e}"))?; + .map_err(common::Error::Tls)?; Ok(Arc::new(config)) } @@ -123,10 +124,10 @@ async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc< } } -async fn run_server(args: Args, tls_config: Arc) -> miette::Result<()> { +async fn run_server(args: Args, tls_config: Arc) -> error::Result<()> { let listener = TcpListener::bind(args.listen) .await - .map_err(|e| miette!("failed to bind to {}: {e}", args.listen))?; + .map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?; info!(listen = %args.listen, mode = %args.mode, "listening"); @@ -164,10 +165,10 @@ async fn main() -> miette::Result<()> { ); info!("Generating self-signed certificates..."); - let ca = CaCertificate::generate().map_err(|e| miette!("failed to generate CA: {e}"))?; + let ca = CaCertificate::generate().map_err(common::Error::RCGen)?; let server_cert = ca .sign_server_cert("localhost") - .map_err(|e| miette!("failed to generate server cert: {e}"))?; + .map_err(common::Error::RCGen)?; let tls_config = build_tls_config(args.mode, &server_cert)?; @@ -180,5 +181,5 @@ async fn main() -> miette::Result<()> { "CA cert (truncated)" ); - run_server(args, tls_config).await + Ok(run_server(args, tls_config).await?) }