mirror of
https://github.com/kristoferssolo/tls-pq-bench.git
synced 2026-03-21 16:26:22 +00:00
refactor(server,common): introduce custom error types with thiserror and miette
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -275,11 +275,15 @@ name = "common"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cargo-husky",
|
||||
"miette",
|
||||
"rcgen",
|
||||
"rustls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -915,6 +919,7 @@ dependencies = [
|
||||
"common",
|
||||
"miette",
|
||||
"rustls",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tracing",
|
||||
|
||||
@@ -5,11 +5,15 @@ authors.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
miette.workspace = true
|
||||
rcgen.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
toml.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
cargo-husky.workspace = true
|
||||
|
||||
65
common/src/error.rs
Normal file
65
common/src/error.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type using the common's custom error type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
pub enum Error {
|
||||
/// File or network I/O error.
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(common::io_error))]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(common::rustls_error))]
|
||||
Tls(#[from] rustls::Error),
|
||||
|
||||
/// TOML configuration file parse error.
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(common::toml_error))]
|
||||
Toml(#[from] toml::de::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(common::json_error))]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(common::rcgen_error))]
|
||||
RCGen(#[from] rcgen::Error),
|
||||
|
||||
/// Configuration validation or missing required fields.
|
||||
#[error("Config error: {0}")]
|
||||
#[diagnostic(code(common::config_error))]
|
||||
Config(String),
|
||||
|
||||
/// Invalid key exchange mode string.
|
||||
#[error("Invalid mode: {0}")]
|
||||
#[diagnostic(code(common::invalid_mode))]
|
||||
InvalidMode(String),
|
||||
|
||||
/// Protocol-level error (malformed requests, unexpected responses).
|
||||
#[error("Protocol error: {0}")]
|
||||
#[diagnostic(code(common::protocol_error))]
|
||||
Protocol(String),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Create an invalid mode error.
|
||||
#[inline]
|
||||
pub fn invalid_mode(error: impl Into<String>) -> Self {
|
||||
Self::InvalidMode(error.into())
|
||||
}
|
||||
|
||||
/// Create a config error.
|
||||
#[inline]
|
||||
pub fn config(error: impl Into<String>) -> Self {
|
||||
Self::Config(error.into())
|
||||
}
|
||||
|
||||
/// Create a protocol error.
|
||||
#[inline]
|
||||
pub fn protocol(error: impl Into<String>) -> Self {
|
||||
Self::Protocol(error.into())
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
//! Common types and utilities for the TLS benchmark harness.
|
||||
|
||||
pub mod cert;
|
||||
pub mod error;
|
||||
pub mod protocol;
|
||||
|
||||
pub use error::Error;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use strum::{Display, EnumString};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{args::Args, error};
|
||||
use common::KeyExchangeMode;
|
||||
use common::{self, KeyExchangeMode};
|
||||
use serde::Deserialize;
|
||||
use std::{fs::read_to_string, net::SocketAddr, path::PathBuf};
|
||||
|
||||
@@ -23,8 +23,8 @@ pub struct Config {
|
||||
/// # Errors
|
||||
/// Returns an error if the file cannot be read or parsed.
|
||||
pub fn load_from_file(path: &PathBuf) -> error::Result<Config> {
|
||||
let content = read_to_string(path).map_err(error::Error::Io)?;
|
||||
let config = toml::from_str::<Config>(&content).map_err(error::Error::Toml)?;
|
||||
let content = read_to_string(path).map_err(common::Error::Io)?;
|
||||
let config = toml::from_str::<Config>(&content).map_err(common::Error::Toml)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ pub fn load_from_cli(args: &Args) -> error::Result<Config> {
|
||||
concurrency: args.concurrency,
|
||||
server: args
|
||||
.server
|
||||
.ok_or_else(|| error::Error::config("--server ir required"))?,
|
||||
.ok_or_else(|| common::Error::config("--server ir required"))?,
|
||||
}],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,64 +9,20 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
/// Errors that can occur during benchmark execution.
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
pub enum Error {
|
||||
/// TLS configuration or handshake failure.
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(runner::tls_error))]
|
||||
TlsConfig(#[from] rustls::Error),
|
||||
|
||||
/// File or network I/O error.
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(runner::io_error))]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// TOML configuration file parse error.
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(runner::toml_error))]
|
||||
Toml(#[from] toml::de::Error),
|
||||
|
||||
/// Invalid key exchange mode string.
|
||||
#[error("Invalid mode: {0}")]
|
||||
#[diagnostic(code(runner::invalid_mode))]
|
||||
InvalidMode(String),
|
||||
|
||||
/// Configuration validation or missing required fields.
|
||||
#[error("Config error: {0}")]
|
||||
#[diagnostic(code(runner::config_error))]
|
||||
Config(String),
|
||||
#[diagnostic(code(runner::common_error))]
|
||||
Common(#[from] common::Error),
|
||||
|
||||
/// Network connection failure.
|
||||
#[error("Network error: {0}")]
|
||||
#[diagnostic(code(runner::network_error))]
|
||||
Network(String),
|
||||
|
||||
/// Protocol-level error (malformed requests, unexpected responses).
|
||||
#[error("Protocol error: {0}")]
|
||||
#[diagnostic(code(runner::protocol_error))]
|
||||
Protocol(String),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Create an invalid mode error.
|
||||
#[inline]
|
||||
pub fn invalid_mode(error: impl Into<String>) -> Self {
|
||||
Self::InvalidMode(error.into())
|
||||
}
|
||||
|
||||
/// Create a config error.
|
||||
#[inline]
|
||||
pub fn config(error: impl Into<String>) -> Self {
|
||||
Self::Config(error.into())
|
||||
}
|
||||
|
||||
/// Create a network error.
|
||||
#[inline]
|
||||
pub fn network(error: impl Into<String>) -> Self {
|
||||
Self::Network(error.into())
|
||||
}
|
||||
|
||||
/// Create a protocol error.
|
||||
#[inline]
|
||||
pub fn protocol(error: impl Into<String>) -> Self {
|
||||
Self::Protocol(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,11 @@ clap.workspace = true
|
||||
common.workspace = true
|
||||
miette.workspace = true
|
||||
rustls.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
[lints]
|
||||
|
||||
38
server/src/error.rs
Normal file
38
server/src/error.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type using the servers's custom error type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(runner::common_error))]
|
||||
Common(#[from] common::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
#[diagnostic(code(server::cert_validation_error))]
|
||||
CertValidation(#[from] rustls::pki_types::InvalidDnsNameError),
|
||||
|
||||
/// Network connection failure.
|
||||
#[error("Network error: {0}")]
|
||||
#[diagnostic(code(server::network_error))]
|
||||
Network(String),
|
||||
|
||||
#[error("Invalid certificate: {0}")]
|
||||
#[diagnostic(code(server::invalid_cert))]
|
||||
InvalidCert(String),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Create a network error.
|
||||
#[inline]
|
||||
pub fn network(error: impl Into<String>) -> Self {
|
||||
Self::Network(error.into())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn invalid_cert(error: impl Into<String>) -> Self {
|
||||
Self::InvalidCert(error.into())
|
||||
}
|
||||
}
|
||||
1
server/src/lib.rs
Normal file
1
server/src/lib.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod error;
|
||||
@@ -4,6 +4,8 @@
|
||||
//! - Reads 8-byte little-endian u64 (requested payload size N)
|
||||
//! - Responds with exactly N bytes (deterministic pattern)
|
||||
|
||||
mod error;
|
||||
|
||||
use base64::prelude::*;
|
||||
use clap::Parser;
|
||||
use common::{
|
||||
@@ -11,7 +13,6 @@ use common::{
|
||||
cert::{CaCertificate, ServerCertificate},
|
||||
protocol::{read_request, write_payload},
|
||||
};
|
||||
use miette::miette;
|
||||
use rustls::{
|
||||
ServerConfig,
|
||||
crypto::aws_lc_rs::{
|
||||
@@ -48,7 +49,7 @@ struct Args {
|
||||
fn build_tls_config(
|
||||
mode: KeyExchangeMode,
|
||||
server_cert: &ServerCertificate,
|
||||
) -> miette::Result<Arc<ServerConfig>> {
|
||||
) -> error::Result<Arc<ServerConfig>> {
|
||||
let mut provider = aws_lc_rs::default_provider();
|
||||
provider.kx_groups = match mode {
|
||||
KeyExchangeMode::X25519 => vec![X25519],
|
||||
@@ -62,14 +63,14 @@ fn build_tls_config(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let key = PrivateKeyDer::try_from(server_cert.private_key_der.clone())
|
||||
.map_err(|e| miette!("invalid private key: {e}"))?;
|
||||
.map_err(|e| error::Error::invalid_cert(format!("invalid private_key: {e}")))?;
|
||||
|
||||
let config = ServerConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&TLS13])
|
||||
.map_err(|e| miette!("failed to set TLS versions: {e}"))?
|
||||
.map_err(common::Error::Tls)?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)
|
||||
.map_err(|e| miette!("failed to configure certificate: {e}"))?;
|
||||
.map_err(common::Error::Tls)?;
|
||||
|
||||
Ok(Arc::new(config))
|
||||
}
|
||||
@@ -123,10 +124,10 @@ async fn handle_connection(stream: TcpStream, peer: SocketAddr, tls_config: Arc<
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> miette::Result<()> {
|
||||
async fn run_server(args: Args, tls_config: Arc<ServerConfig>) -> error::Result<()> {
|
||||
let listener = TcpListener::bind(args.listen)
|
||||
.await
|
||||
.map_err(|e| miette!("failed to bind to {}: {e}", args.listen))?;
|
||||
.map_err(|e| error::Error::network(format!("failed to bind to {}: {e}", args.listen)))?;
|
||||
|
||||
info!(listen = %args.listen, mode = %args.mode, "listening");
|
||||
|
||||
@@ -164,10 +165,10 @@ async fn main() -> miette::Result<()> {
|
||||
);
|
||||
|
||||
info!("Generating self-signed certificates...");
|
||||
let ca = CaCertificate::generate().map_err(|e| miette!("failed to generate CA: {e}"))?;
|
||||
let ca = CaCertificate::generate().map_err(common::Error::RCGen)?;
|
||||
let server_cert = ca
|
||||
.sign_server_cert("localhost")
|
||||
.map_err(|e| miette!("failed to generate server cert: {e}"))?;
|
||||
.map_err(common::Error::RCGen)?;
|
||||
|
||||
let tls_config = build_tls_config(args.mode, &server_cert)?;
|
||||
|
||||
@@ -180,5 +181,5 @@ async fn main() -> miette::Result<()> {
|
||||
"CA cert (truncated)"
|
||||
);
|
||||
|
||||
run_server(args, tls_config).await
|
||||
Ok(run_server(args, tls_config).await?)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user