diff --git a/Cargo.lock b/Cargo.lock index 99532ed..09acb52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -237,6 +237,15 @@ dependencies = [ "windows-targets 0.52.4", ] +[[package]] +name = "claims" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6995bbe186456c36307f8ea36be3eefe42f49d106896414e18efc4fb2f846b5" +dependencies = [ + "autocfg", +] + [[package]] name = "config" version = "0.14.0" @@ -2710,6 +2719,7 @@ version = "0.1.0" dependencies = [ "axum", "chrono", + "claims", "config", "once_cell", "reqwest", @@ -2723,6 +2733,7 @@ dependencies = [ "tracing-bunyan-formatter", "tracing-log 0.2.0", "tracing-subscriber", + "unicode-segmentation", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 6cf111b..9547c49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ tracing-bunyan-formatter = "0.3" tracing-log = "0.2" secrecy = { version = "0.8", features = ["serde"] } serde-aux = "4" +unicode-segmentation = "1" +claims = "0.7" [dev-dependencies] reqwest = "0.12" diff --git a/src/domain.rs b/src/domain.rs new file mode 100644 index 0000000..221e8c6 --- /dev/null +++ b/src/domain.rs @@ -0,0 +1,74 @@ +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Debug)] +pub struct NewSubscriber { + pub email: String, + pub name: SubscriberName, +} + +#[derive(Debug)] +pub struct SubscriberName(String); + +impl SubscriberName { + pub fn parse(s: String) -> Result { + let is_empty_or_whitespace = s.trim().is_empty(); + let is_too_long = s.graphemes(true).count() > 256; + let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; + let contains_forbidden_characters = s.chars().any(|c| forbidden_characters.contains(&c)); + if is_empty_or_whitespace || is_too_long || contains_forbidden_characters { + return Err(format!("{} is not a valid subscriber name.", s)); + } + Ok(Self(s)) + } +} + +impl AsRef for SubscriberName { + fn as_ref(&self) -> &str { + &self.0 + } +} + +#[cfg(test)] +mod tests { + + use claims::{assert_err, assert_ok}; + + use super::*; + #[test] + fn a_256_grapheme_long_name_is_valid() { + let name = "ē".repeat(256); + assert_ok!(SubscriberName::parse(name)); + } + + #[test] + fn a_name_longer_than_256_graphemes_is_rejected() { + let name = "a".repeat(257); + assert_err!(SubscriberName::parse(name)); + } + + #[test] + fn whitespace_only_names_are_rejected() { + let name = " ".to_string(); + assert_err!(SubscriberName::parse(name)); + } + + #[test] + fn empty_string_is_rejected() { + let name = "".to_string(); + assert_err!(SubscriberName::parse(name)); + } + + #[test] + fn names_containing_an_invalid_character_are_rejected() { + for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] { + let name = name.to_string(); + assert_err!(SubscriberName::parse(name)); + } + } + + #[test] + fn a_valid_name_is_parsed_successfully() { + let name = "Ursula Le Guin".to_string(); + assert_ok!(SubscriberName::parse(name)); + } +} diff --git a/src/lib.rs b/src/lib.rs index 0276897..f24afc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod config; +pub mod domain; pub mod routes; pub mod telemetry; diff --git a/src/routes/subscibtions.rs b/src/routes/subscibtions.rs index 955aa2c..ad63037 100644 --- a/src/routes/subscibtions.rs +++ b/src/routes/subscibtions.rs @@ -3,8 +3,11 @@ use chrono::Utc; use serde::Deserialize; use sqlx::PgPool; use tracing::error; +use unicode_segmentation::UnicodeSegmentation; use uuid::Uuid; +use crate::domain::{NewSubscriber, SubscriberName}; + #[derive(Deserialize)] pub struct FormData { name: String, @@ -23,7 +26,15 @@ pub async fn subscribe( State(pool): State, Form(form): Form, ) -> impl IntoResponse { - match insert_subscriber(&pool, &form).await { + if !is_valid_name(&form.name) { + return StatusCode::BAD_REQUEST; + } + + let new_subscriber = NewSubscriber { + email: form.email, + name: SubscriberName::parse(form.name).expect("Name validation failed."), + }; + match insert_subscriber(&pool, &new_subscriber).await { Ok(_) => StatusCode::OK, Err(_) => StatusCode::INTERNAL_SERVER_ERROR, } @@ -31,17 +42,20 @@ pub async fn subscribe( #[tracing::instrument( name = "Saving new subscriber details in the database", - skip(form, pool) + skip(new_subscriber, pool) )] -pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sqlx::Error> { +pub async fn insert_subscriber( + pool: &PgPool, + new_subscriber: &NewSubscriber, +) -> Result<(), sqlx::Error> { sqlx::query!( r#" INSERT INTO subscriptions(id, email, name, subscribed_at) VALUES ($1, $2, $3, $4) "#, Uuid::new_v4(), - form.email, - form.name, + new_subscriber.email, + new_subscriber.name.as_ref(), Utc::now() ) .execute(pool) @@ -52,3 +66,13 @@ pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sql })?; Ok(()) } + +/// Returns `true` if the input satisfies all validation constraints +/// on subscriber names, `false` otherwise. +pub fn is_valid_name(s: &str) -> bool { + let is_empty_or_whitespace = s.trim().is_empty(); + let is_too_long = s.graphemes(true).count() > 256; + let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}']; + let contains_forbidden_characters = s.chars().any(|c| forbidden_characters.contains(&c)); + !(is_empty_or_whitespace || is_too_long || contains_forbidden_characters) +} diff --git a/tests/health_check.rs b/tests/health_check.rs index 278aa2d..fc110a3 100644 --- a/tests/health_check.rs +++ b/tests/health_check.rs @@ -27,13 +27,14 @@ async fn health_check() { #[tokio::test] async fn subscribe_returns_200_for_valid_form_data() { let app = spawn_app().await; + let body = "name=Kristofers%20Solo&email=dev%40kristofers.solo"; + let config = get_config().expect("Failed to read configuration."); let mut connection = PgConnection::connect_with(&config.database.with_db()) .await .expect("Failed to connect to Postgres."); let client = Client::new(); - let body = "name=Kristofers%20Solo&email=dev%40kristofers.solo"; let response = client .post(&format!("{}/subscriptions", &app.address)) .header("Content-Type", "application/x-www-form-urlencoded") @@ -53,8 +54,8 @@ async fn subscribe_returns_200_for_valid_form_data() { .await .expect("Failed to fetch saved subscription."); - assert_eq!(saved.email, "dev@kristofers.solo"); assert_eq!(saved.name, "Kristofers Solo"); + assert_eq!(saved.email, "dev@kristofers.solo"); } #[tokio::test] @@ -86,8 +87,35 @@ async fn subscribe_returns_400_when_data_is_missing() { } } +#[tokio::test] +async fn subscribe_returns_400_when_fields_are_present_but_invalid() { + let app = spawn_app().await; + let client = Client::new(); + let test_cases = vec![ + ("name=&email=dev%40kristofers.solo", "empty name"), + ("name=kristofers%20solo&email=", "empty email"), + ("name=solo&email=definetely-not-an-email", "invalid email"), + ]; + + for (body, description) in test_cases { + let response = client + .post(&format!("{}/subscriptions", &app.address)) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(body) + .send() + .await + .expect("Failed to execute request."); + assert_eq!( + 400, + response.status().as_u16(), + "The API did not return 400 Bad Request when the payload was {}.", + description + ); + } +} + static TRACING: Lazy<()> = Lazy::new(|| { - let default_filter_level = "info"; + let default_filter_level = "trace"; let subscriber_name = "test"; if std::env::var("TEST_LOG").is_ok() { let subscriber = get_subscriber(subscriber_name, default_filter_level, std::io::stdout); @@ -105,9 +133,11 @@ async fn spawn_app() -> TestApp { .expect("Failed to bind random port"); let port = listener.local_addr().unwrap().port(); let address = format!("http://127.0.0.1:{}", port); + let mut config = get_config().expect("Failed to read configuration."); config.database.database_name = Uuid::new_v4().to_string(); + let pool = configure_database(&config.database).await; let pool_clone = pool.clone(); let _ = tokio::spawn(async move {