feat: add subscriber name validation

This commit is contained in:
Kristofers Solo 2024-03-25 17:09:51 +02:00
parent af876c680b
commit 7657373f3b
6 changed files with 150 additions and 8 deletions

11
Cargo.lock generated
View File

@ -237,6 +237,15 @@ dependencies = [
"windows-targets 0.52.4", "windows-targets 0.52.4",
] ]
[[package]]
name = "claims"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6995bbe186456c36307f8ea36be3eefe42f49d106896414e18efc4fb2f846b5"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "config" name = "config"
version = "0.14.0" version = "0.14.0"
@ -2710,6 +2719,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"chrono", "chrono",
"claims",
"config", "config",
"once_cell", "once_cell",
"reqwest", "reqwest",
@ -2723,6 +2733,7 @@ dependencies = [
"tracing-bunyan-formatter", "tracing-bunyan-formatter",
"tracing-log 0.2.0", "tracing-log 0.2.0",
"tracing-subscriber", "tracing-subscriber",
"unicode-segmentation",
"uuid", "uuid",
] ]

View File

@ -36,6 +36,8 @@ tracing-bunyan-formatter = "0.3"
tracing-log = "0.2" tracing-log = "0.2"
secrecy = { version = "0.8", features = ["serde"] } secrecy = { version = "0.8", features = ["serde"] }
serde-aux = "4" serde-aux = "4"
unicode-segmentation = "1"
claims = "0.7"
[dev-dependencies] [dev-dependencies]
reqwest = "0.12" reqwest = "0.12"

74
src/domain.rs Normal file
View File

@ -0,0 +1,74 @@
use unicode_segmentation::UnicodeSegmentation;
#[derive(Debug)]
pub struct NewSubscriber {
pub email: String,
pub name: SubscriberName,
}
#[derive(Debug)]
pub struct SubscriberName(String);
impl SubscriberName {
pub fn parse(s: String) -> Result<Self, String> {
let is_empty_or_whitespace = s.trim().is_empty();
let is_too_long = s.graphemes(true).count() > 256;
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|c| forbidden_characters.contains(&c));
if is_empty_or_whitespace || is_too_long || contains_forbidden_characters {
return Err(format!("{} is not a valid subscriber name.", s));
}
Ok(Self(s))
}
}
impl AsRef<str> for SubscriberName {
fn as_ref(&self) -> &str {
&self.0
}
}
#[cfg(test)]
mod tests {
use claims::{assert_err, assert_ok};
use super::*;
#[test]
fn a_256_grapheme_long_name_is_valid() {
let name = "ē".repeat(256);
assert_ok!(SubscriberName::parse(name));
}
#[test]
fn a_name_longer_than_256_graphemes_is_rejected() {
let name = "a".repeat(257);
assert_err!(SubscriberName::parse(name));
}
#[test]
fn whitespace_only_names_are_rejected() {
let name = " ".to_string();
assert_err!(SubscriberName::parse(name));
}
#[test]
fn empty_string_is_rejected() {
let name = "".to_string();
assert_err!(SubscriberName::parse(name));
}
#[test]
fn names_containing_an_invalid_character_are_rejected() {
for name in &['/', '(', ')', '"', '<', '>', '\\', '{', '}'] {
let name = name.to_string();
assert_err!(SubscriberName::parse(name));
}
}
#[test]
fn a_valid_name_is_parsed_successfully() {
let name = "Ursula Le Guin".to_string();
assert_ok!(SubscriberName::parse(name));
}
}

View File

@ -1,3 +1,4 @@
pub mod config; pub mod config;
pub mod domain;
pub mod routes; pub mod routes;
pub mod telemetry; pub mod telemetry;

View File

@ -3,8 +3,11 @@ use chrono::Utc;
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use tracing::error; use tracing::error;
use unicode_segmentation::UnicodeSegmentation;
use uuid::Uuid; use uuid::Uuid;
use crate::domain::{NewSubscriber, SubscriberName};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct FormData { pub struct FormData {
name: String, name: String,
@ -23,7 +26,15 @@ pub async fn subscribe(
State(pool): State<PgPool>, State(pool): State<PgPool>,
Form(form): Form<FormData>, Form(form): Form<FormData>,
) -> impl IntoResponse { ) -> impl IntoResponse {
match insert_subscriber(&pool, &form).await { if !is_valid_name(&form.name) {
return StatusCode::BAD_REQUEST;
}
let new_subscriber = NewSubscriber {
email: form.email,
name: SubscriberName::parse(form.name).expect("Name validation failed."),
};
match insert_subscriber(&pool, &new_subscriber).await {
Ok(_) => StatusCode::OK, Ok(_) => StatusCode::OK,
Err(_) => StatusCode::INTERNAL_SERVER_ERROR, Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
} }
@ -31,17 +42,20 @@ pub async fn subscribe(
#[tracing::instrument( #[tracing::instrument(
name = "Saving new subscriber details in the database", name = "Saving new subscriber details in the database",
skip(form, pool) skip(new_subscriber, pool)
)] )]
pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sqlx::Error> { pub async fn insert_subscriber(
pool: &PgPool,
new_subscriber: &NewSubscriber,
) -> Result<(), sqlx::Error> {
sqlx::query!( sqlx::query!(
r#" r#"
INSERT INTO subscriptions(id, email, name, subscribed_at) INSERT INTO subscriptions(id, email, name, subscribed_at)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
"#, "#,
Uuid::new_v4(), Uuid::new_v4(),
form.email, new_subscriber.email,
form.name, new_subscriber.name.as_ref(),
Utc::now() Utc::now()
) )
.execute(pool) .execute(pool)
@ -52,3 +66,13 @@ pub async fn insert_subscriber(pool: &PgPool, form: &FormData) -> Result<(), sql
})?; })?;
Ok(()) Ok(())
} }
/// Returns `true` if the input satisfies all validation constraints
/// on subscriber names, `false` otherwise.
pub fn is_valid_name(s: &str) -> bool {
let is_empty_or_whitespace = s.trim().is_empty();
let is_too_long = s.graphemes(true).count() > 256;
let forbidden_characters = ['/', '(', ')', '"', '<', '>', '\\', '{', '}'];
let contains_forbidden_characters = s.chars().any(|c| forbidden_characters.contains(&c));
!(is_empty_or_whitespace || is_too_long || contains_forbidden_characters)
}

View File

@ -27,13 +27,14 @@ async fn health_check() {
#[tokio::test] #[tokio::test]
async fn subscribe_returns_200_for_valid_form_data() { async fn subscribe_returns_200_for_valid_form_data() {
let app = spawn_app().await; let app = spawn_app().await;
let body = "name=Kristofers%20Solo&email=dev%40kristofers.solo";
let config = get_config().expect("Failed to read configuration."); let config = get_config().expect("Failed to read configuration.");
let mut connection = PgConnection::connect_with(&config.database.with_db()) let mut connection = PgConnection::connect_with(&config.database.with_db())
.await .await
.expect("Failed to connect to Postgres."); .expect("Failed to connect to Postgres.");
let client = Client::new(); let client = Client::new();
let body = "name=Kristofers%20Solo&email=dev%40kristofers.solo";
let response = client let response = client
.post(&format!("{}/subscriptions", &app.address)) .post(&format!("{}/subscriptions", &app.address))
.header("Content-Type", "application/x-www-form-urlencoded") .header("Content-Type", "application/x-www-form-urlencoded")
@ -53,8 +54,8 @@ async fn subscribe_returns_200_for_valid_form_data() {
.await .await
.expect("Failed to fetch saved subscription."); .expect("Failed to fetch saved subscription.");
assert_eq!(saved.email, "dev@kristofers.solo");
assert_eq!(saved.name, "Kristofers Solo"); assert_eq!(saved.name, "Kristofers Solo");
assert_eq!(saved.email, "dev@kristofers.solo");
} }
#[tokio::test] #[tokio::test]
@ -86,8 +87,35 @@ async fn subscribe_returns_400_when_data_is_missing() {
} }
} }
#[tokio::test]
async fn subscribe_returns_400_when_fields_are_present_but_invalid() {
let app = spawn_app().await;
let client = Client::new();
let test_cases = vec![
("name=&email=dev%40kristofers.solo", "empty name"),
("name=kristofers%20solo&email=", "empty email"),
("name=solo&email=definetely-not-an-email", "invalid email"),
];
for (body, description) in test_cases {
let response = client
.post(&format!("{}/subscriptions", &app.address))
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body)
.send()
.await
.expect("Failed to execute request.");
assert_eq!(
400,
response.status().as_u16(),
"The API did not return 400 Bad Request when the payload was {}.",
description
);
}
}
static TRACING: Lazy<()> = Lazy::new(|| { static TRACING: Lazy<()> = Lazy::new(|| {
let default_filter_level = "info"; let default_filter_level = "trace";
let subscriber_name = "test"; let subscriber_name = "test";
if std::env::var("TEST_LOG").is_ok() { if std::env::var("TEST_LOG").is_ok() {
let subscriber = get_subscriber(subscriber_name, default_filter_level, std::io::stdout); let subscriber = get_subscriber(subscriber_name, default_filter_level, std::io::stdout);
@ -105,9 +133,11 @@ async fn spawn_app() -> TestApp {
.expect("Failed to bind random port"); .expect("Failed to bind random port");
let port = listener.local_addr().unwrap().port(); let port = listener.local_addr().unwrap().port();
let address = format!("http://127.0.0.1:{}", port); let address = format!("http://127.0.0.1:{}", port);
let mut config = get_config().expect("Failed to read configuration."); let mut config = get_config().expect("Failed to read configuration.");
config.database.database_name = Uuid::new_v4().to_string(); config.database.database_name = Uuid::new_v4().to_string();
let pool = configure_database(&config.database).await; let pool = configure_database(&config.database).await;
let pool_clone = pool.clone(); let pool_clone = pool.clone();
let _ = tokio::spawn(async move { let _ = tokio::spawn(async move {