refactor(server,common): introduce custom error types with thiserror and miette

This commit is contained in:
2026-02-11 16:29:59 +02:00
parent cda6024062
commit 818cfd5598
10 changed files with 134 additions and 61 deletions

5
Cargo.lock generated
View File

@@ -275,11 +275,15 @@ name = "common"
version = "0.1.0"
dependencies = [
"cargo-husky",
"miette",
"rcgen",
"rustls",
"serde",
"serde_json",
"strum",
"thiserror",
"tokio",
"toml",
]
[[package]]
@@ -915,6 +919,7 @@ dependencies = [
"common",
"miette",
"rustls",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",

View File

@@ -5,11 +5,15 @@ authors.workspace = true
edition.workspace = true
[dependencies]
miette.workspace = true
rcgen.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tokio.workspace = true
toml.workspace = true
[dev-dependencies]
cargo-husky.workspace = true

65
common/src/error.rs Normal file
View File

@@ -0,0 +1,65 @@
use miette::Diagnostic;
use thiserror::Error;
/// Result type using the common's custom error type.
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Error, Diagnostic)]
pub enum Error {
/// File or network I/O error.
#[error(transparent)]
#[diagnostic(code(common::io_error))]
Io(#[from] std::io::Error),
#[error(transparent)]
#[diagnostic(code(common::rustls_error))]
Tls(#[from] rustls::Error),
/// TOML configuration file parse error.
#[error(transparent)]
#[diagnostic(code(common::toml_error))]
Toml(#[from] toml::de::Error),
#[error(transparent)]
#[diagnostic(code(common::json_error))]
Json(#[from] serde_json::Error),
#[error(transparent)]
#[diagnostic(code(common::rcgen_error))]
RCGen(#[from] rcgen::Error),
/// Configuration validation or missing required fields.
#[error("Config error: {0}")]
#[diagnostic(code(common::config_error))]
Config(String),
/// Invalid key exchange mode string.
#[error("Invalid mode: {0}")]
#[diagnostic(code(common::invalid_mode))]
InvalidMode(String),
/// Protocol-level error (malformed requests, unexpected responses).
#[error("Protocol error: {0}")]
#[diagnostic(code(common::protocol_error))]
Protocol(String),
}
impl Error {
/// Create an invalid mode error.
#[inline]
pub fn invalid_mode(error: impl Into<String>) -> Self {
Self::InvalidMode(error.into())
}
/// Create a config error.
#[inline]
pub fn config(error: impl Into<String>) -> Self {
Self::Config(error.into())
}
/// Create a protocol error.
#[inline]
pub fn protocol(error: impl Into<String>) -> Self {
Self::Protocol(error.into())
}
}

View File

@@ -1,8 +1,10 @@
//! Common types and utilities for the TLS benchmark harness.
pub mod cert;
pub mod error;
pub mod protocol;
pub use error::Error;
use serde::{Deserialize, Serialize};
use std::fmt;
use strum::{Display, EnumString};

View File

@@ -1,5 +1,5 @@
use crate::{args::Args, error};
use common::KeyExchangeMode;
use common::{self, KeyExchangeMode};
use serde::Deserialize;
use std::{fs::read_to_string, net::SocketAddr, path::PathBuf};
@@ -23,8 +23,8 @@ pub struct Config {
/// # Errors
/// Returns an error if the file cannot be read or parsed.
pub fn load_from_file(path: &PathBuf) -> error::Result<Config> {
let content = read_to_string(path).map_err(error::Error::Io)?;
let config = toml::from_str::<Config>(&content).map_err(error::Error::Toml)?;
let content = read_to_string(path).map_err(common::Error::Io)?;
let config = toml::from_str::<Config>(&content).map_err(common::Error::Toml)?;
Ok(config)
}
@@ -42,7 +42,7 @@ pub fn load_from_cli(args: &Args) -> error::Result<Config> {
concurrency: args.concurrency,
server: args
.server
.ok_or_else(|| error::Error::config("--server ir required"))?,
.ok_or_else(|| common::Error::config("--server ir required"))?,
}],
})
}

View File

@@ -9,64 +9,20 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Errors that can occur during benchmark execution.
#[derive(Debug, Error, Diagnostic)]
pub enum Error {
/// TLS configuration or handshake failure.
#[error(transparent)]
#[diagnostic(code(runner::tls_error))]
TlsConfig(#[from] rustls::Error),
/// File or network I/O error.
#[error(transparent)]
#[diagnostic(code(runner::io_error))]
Io(#[from] std::io::Error),
/// TOML configuration file parse error.
#[error(transparent)]
#[diagnostic(code(runner::toml_error))]
Toml(#[from] toml::de::Error),
/// Invalid key exchange mode string.
#[error("Invalid mode: {0}")]
#[diagnostic(code(runner::invalid_mode))]
InvalidMode(String),
/// Configuration validation or missing required fields.
#[error("Config error: {0}")]
#[diagnostic(code(runner::config_error))]
Config(String),
#[diagnostic(code(runner::common_error))]
Common(#[from] common::Error),
/// Network connection failure.
#[error("Network error: {0}")]
#[diagnostic(code(runner::network_error))]
Network(String),
/// Protocol-level error (malformed requests, unexpected responses).
#[error("Protocol error: {0}")]
#[diagnostic(code(runner::protocol_error))]
Protocol(String),
}
impl Error {
/// Create an invalid mode error.
#[inline]
pub fn invalid_mode(error: impl Into<String>) -> Self {
Self::InvalidMode(error.into())
}
/// Create a config error.
#[inline]
pub fn config(error: impl Into<String>) -> Self {
Self::Config(error.into())
}
/// Create a network error.
#[inline]
pub fn network(error: impl Into<String>) -> Self {
Self::Network(error.into())
}
/// Create a protocol error.
#[inline]
pub fn protocol(error: impl Into<String>) -> Self {
Self::Protocol(error.into())
}
}

View File

@@ -10,10 +10,11 @@ clap.workspace = true
common.workspace = true
miette.workspace = true
rustls.workspace = true
thiserror.workspace = true
tokio-rustls.workspace = true
tokio.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
uuid.workspace = true
[lints]

38
server/src/error.rs Normal file
View File

@@ -0,0 +1,38 @@
use miette::Diagnostic;
use thiserror::Error;
/// Result type using the servers's custom error type.
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Error, Diagnostic)]
pub enum Error {
#[error(transparent)]
#[diagnostic(code(runner::common_error))]
Common(#[from] common::Error),
#[error(transparent)]
#[diagnostic(code(server::cert_validation_error))]
CertValidation(#[from] rustls::pki_types::InvalidDnsNameError),
/// Network connection failure.
#[error("Network error: {0}")]
#[diagnostic(code(server::network_error))]
Network(String),
#[error("Invalid certificate: {0}")]
#[diagnostic(code(server::invalid_cert))]
InvalidCert(String),
}
impl Error {
/// Create a network error.
#[inline]
pub fn network(error: impl Into<String>) -> Self {
Self::Network(error.into())
}
#[inline]
pub fn invalid_cert(error: impl Into<String>) -> Self {
Self::InvalidCert(error.into())
}
}

1
server/src/lib.rs Normal file
View File

@@ -0,0 +1 @@
pub mod error;

View File

@@ -4,6 +4,8 @@
//! - Reads 8-byte little-endian u64 (requested payload size N)
//! - Responds with exactly N bytes (deterministic pattern)
mod error;
use base64::prelude::*;
use clap::Parser;
use common::{
@@ -11,7 +13,6 @@ use common::{
cert::{CaCertificate, ServerCertificate},
protocol::{read_request, write_payload},
};
use miette::miette;
use rustls::{
ServerConfig,
crypto::aws_lc_rs::{
@@ -48,7 +49,7 @@ struct Args {
fn build_tls_config(
mode: KeyExchangeMode,
server_cert: &ServerCertificate,
) -> miette::Result<Arc<ServerConfig>> {
) -> error::Result<Arc<ServerConfig>> {
let mut provider = aws_lc_rs::default_provider();
provider.kx_groups = match mode {
KeyExchangeMode::X25519 => vec![X25519],
@@ -62,14 +63,14 @@ fn build_tls_config(
.collect::<Vec<_>>();
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
.map_err(|e| miette!("invalid private key: {e}"))?;
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
let config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&TLS13])
.map_err(|e| miette!("failed to set TLS versions: {e}"))?
.map_err(common::Error::Tls)?
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| miette!("failed to configure certificate: {e}"))?;
.map_err(common::Error::Tls)?;
Ok(Arc::new(config))
}
@@ -123,10 +124,10 @@ async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<
}
}
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> miette::Result<()> {
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
let listener = TcpListener::bind(args.listen)
.await
.map_err(|e| miette!("failed to bind to {}: {e}", args.listen))?;
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
info!(listen = %args.listen, mode = %args.mode, "listening");
@@ -164,10 +165,10 @@ async fn main() -> miette::Result<()> {
);
info!("Generating self-signed certificates...");
let ca = CaCertificate::generate().map_err(|e| miette!("failed to generate CA: {e}"))?;
let ca = CaCertificate::generate().map_err(common::Error::RCGen)?;
let server_cert = ca
.sign_server_cert("localhost")
.map_err(|e| miette!("failed to generate server cert: {e}"))?;
.map_err(common::Error::RCGen)?;
let tls_config = build_tls_config(args.mode, &server_cert)?;
@@ -180,5 +181,5 @@ async fn main() -> miette::Result<()> {
"CA cert (truncated)"
);
run_server(args, tls_config).await
Ok(run_server(args, tls_config).await?)
}