commit e20f21bc8b35cf6c3411565548a95c45ae4807e3 Author: Rik Heijmann Date: Thu Jan 30 22:43:30 2025 +0100 first commit diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..e181198 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,35 @@ + + # Include any files or directories that you don't want to be copied to your + # container here (e.g., local build artifacts, temporary files, etc.). + # + # For more help, visit the .dockerignore file reference guide at + # https://docs.docker.com/engine/reference/builder/#dockerignore-file + + **/.DS_Store + **/.classpath + **/.dockerignore + **/.env + **/.git + **/.gitignore + **/.project + **/.settings + **/.toolstarget + **/.vs + **/.vscode + **/*.*proj.user + **/*.dbmdl + **/*.jfm + **/charts + **/docker-compose* + **/compose* + **/Dockerfile* + **/node_modules + **/npm-debug.log + **/secrets.dev.yaml + **/values.dev.yaml + /bin + /target + LICENSE + README.md + + \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e315459 --- /dev/null +++ b/.env.example @@ -0,0 +1,68 @@ +# ============================== +# 📌 DATABASE CONFIGURATION +# ============================== + +# PostgreSQL connection URL (format: postgres://user:password@host/database) +DATABASE_URL="postgres://postgres:1234@localhost/database_name" + +# Maximum number of connections in the database pool +DATABASE_MAX_CONNECTIONS=20 + +# Minimum number of connections in the database pool +DATABASE_MIN_CONNECTIONS=5 + +# ============================== +# 🌍 SERVER CONFIGURATION +# ============================== + +# IP address the server will bind to (0.0.0.0 allows all network interfaces) +SERVER_IP="0.0.0.0" + +# Port the server will listen on +SERVER_PORT="3000" + +# Enable tracing for debugging/logging (true/false) +SERVER_TRACE_ENABLED=true + +# ============================== +# 🔒 HTTPS CONFIGURATION +# ============================== + +# Enable HTTPS (true/false) +SERVER_HTTPS_ENABLED=false + +# Enable HTTP/2 when using HTTPS (true/false) +SERVER_HTTPS_HTTP2_ENABLED=true + +# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_CERT_FILE_PATH=cert.pem + +# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_KEY_FILE_PATH=key.pem + +# ============================== +# 🚦 RATE LIMIT CONFIGURATION +# ============================== + +# Maximum number of requests allowed per period +SERVER_RATE_LIMIT=5 + +# Time period (in seconds) for rate limiting +SERVER_RATE_LIMIT_PERIOD=1 + +# ============================== +# 📦 COMPRESSION CONFIGURATION +# ============================== + +# Enable Brotli compression (true/false) +SERVER_COMPRESSION_ENABLED=true + +# Compression level (valid range: 0-11, where 11 is the highest compression) +SERVER_COMPRESSION_LEVEL=6 + +# ============================== +# 🔑 AUTHENTICATION CONFIGURATION +# ============================== + +# Argon2 salt for password hashing (must be kept secret!) +AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi" \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..800b6dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.env +/target +cert.pem +key.pem +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..86bc5fd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "rustapi" +version = "0.1.0" +edition = "2021" + +[dependencies] +# Web framework and server +axum = { version = "0.8.1", features = ["json"] } +# hyper = { version = "1.5.2", features = ["full"] } + +# Database interaction +sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate"] } + +# Serialization and deserialization +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.137" + +# Authentication and security +jsonwebtoken = "9.3.0" +argon2 = "0.5.3" + +# Asynchronous runtime and traits +tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] } + +# Configuration and environment +dotenvy = "0.15.7" + +# Middleware and server utilities +tower = { version = "0.5.2", features = ["limit"] } +tower-http = { version = "0.6.2", features = ["trace", "cors", "compression-br"] } + +# Logging and monitoring +tracing = "0.1.41" +tracing-subscriber = "0.3.19" + +# System information +sysinfo = "0.33.1" + +# Date and time handling +chrono = "0.4.39" + +# SSL / TLS +rustls = "0.23.21" +tokio-rustls = "0.26.1" +rustls-pemfile = "2.2.0" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d7266d5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,73 @@ + + # syntax=docker/dockerfile:1 + + # Comments are provided throughout this file to help you get started. + # If you need more help, visit the Dockerfile reference guide at + # https://docs.docker.com/engine/reference/builder/ + + ################################################################################ + # Create a stage for building the application. + + ARG RUST_VERSION=1.78.0 + ARG APP_NAME=backend + FROM rust:${RUST_VERSION}-slim-bullseye AS build + ARG APP_NAME + WORKDIR /app + + # Build the application. + # Leverage a cache mount to /usr/local/cargo/registry/ + # for downloaded dependencies and a cache mount to /app/target/ for + # compiled dependencies which will speed up subsequent builds. + # Leverage a bind mount to the src directory to avoid having to copy the + # source code into the container. Once built, copy the executable to an + # output directory before the cache mounted /app/target is unmounted. + RUN --mount=type=bind,source=src,target=src \ + # --mount=type=bind,source=configuration.yaml,target=configuration.yaml \ + --mount=type=bind,source=Cargo.toml,target=Cargo.toml \ + --mount=type=bind,source=Cargo.lock,target=Cargo.lock \ + --mount=type=cache,target=/app/target/ \ + --mount=type=cache,target=/usr/local/cargo/registry/ \ + <, // Injected user + Json(todo): Json + ) -> impl IntoResponse { + if todo.user_id != user.id { + return Err((StatusCode::FORBIDDEN, Json(json!({ + "error": "Cannot create todos for others" + })))); + } + ``` +- **Modern protocols ** - HTTP/2 with secure TLS defaults +- **Observability** - Integrated tracing +- **Optimized for performance** - Brotli compression +- **Easy configuration** - `.env` and environment variables +- **Documented codebase** - Extensive inline comments for easy modification and readability +- **Latest dependencies** - Regularly updated Rust ecosystem crates + +## 📦 Installation & Usage +```bash +# Clone and setup +git clone https://github.com/Riktastic/rustapi.git +cd rustapi && cp .env.example .env + +# Database setup +sqlx database create && sqlx migrate run + +# Start server +cargo run --release +``` + +### 🔐 Default Accounts + +**Warning:** These accounts should only be used for initial testing. Always change or disable them in production environments. + +| Email | Password | Role | +|---------------------|----------|----------------| +| `user@test.com` | `test` | User | +| `admin@test.com` | `test` | Administrator | + +⚠️ **Security Recommendations:** +1. Rotate passwords immediately after initial setup +2. Disable default accounts before deploying to production +3. Implement proper user management endpoints + + +## ⚙️ Configuration +```env +# ============================== +# 📌 DATABASE CONFIGURATION +# ============================== + +# PostgreSQL connection URL (format: postgres://user:password@host/database) +DATABASE_URL="postgres://postgres:1234@localhost/database_name" + +# Maximum number of connections in the database pool +DATABASE_MAX_CONNECTIONS=20 + +# Minimum number of connections in the database pool +DATABASE_MIN_CONNECTIONS=5 + +# ============================== +# 🌍 SERVER CONFIGURATION +# ============================== + +# IP address the server will bind to (0.0.0.0 allows all network interfaces) +SERVER_IP="0.0.0.0" + +# Port the server will listen on +SERVER_PORT="3000" + +# Enable tracing for debugging/logging (true/false) +SERVER_TRACE_ENABLED=true + +# ============================== +# 🔒 HTTPS CONFIGURATION +# ============================== + +# Enable HTTPS (true/false) +SERVER_HTTPS_ENABLED=false + +# Enable HTTP/2 when using HTTPS (true/false) +SERVER_HTTPS_HTTP2_ENABLED=true + +# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_CERT_FILE_PATH=cert.pem + +# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_KEY_FILE_PATH=key.pem + +# ============================== +# 🚦 RATE LIMIT CONFIGURATION +# ============================== + +# Maximum number of requests allowed per period +SERVER_RATE_LIMIT=5 + +# Time period (in seconds) for rate limiting +SERVER_RATE_LIMIT_PERIOD=1 + +# ============================== +# 📦 COMPRESSION CONFIGURATION +# ============================== + +# Enable Brotli compression (true/false) +SERVER_COMPRESSION_ENABLED=true + +# Compression level (valid range: 0-11, where 11 is the highest compression) +SERVER_COMPRESSION_LEVEL=6 + +# ============================== +# 🔑 AUTHENTICATION CONFIGURATION +# ============================== + +# Argon2 salt for password hashing (must be kept secret!) +AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi" +``` + +## 📂 Project Structure +``` +rustapi/ +├── migrations/ # SQL schema versions +├── src/ +│ ├── core/ # Config, TLS, server setup +│ ├── database/ # Query handling +│ ├── middlewares/ # Auth system +│ ├── models/ # Data structures +│ └── routes/ # API endpoints +└── Dockerfile # Containerization +``` + +## 🛠️ Technology Stack +| Category | Key Technologies | +|-----------------------|---------------------------------| +| Web Framework | Axum 0.8 + Tower | +| Database | PostgreSQL + SQLx 0.8 | +| Security | JWT + Argon2 + Rustls | +| Monitoring | Tracing + Sysinfo | diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..a46bd80 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,32 @@ + +services: + server: + build: + context: . + target: final + ports: + - 80:80 + depends_on: + - db_image + networks: + - common-net + + db_image: + image: postgres:latest + environment: + POSTGRES_PORT: 3306 + POSTGRES_DATABASE: database_name + POSTGRES_USER: user + POSTGRES_PASSWORD: database_password + POSTGRES_ROOT_PASSWORD: strong_database_password + expose: + - 3306 + ports: + - "3307:3306" + networks: + - common-net + +networks: + common-net: {} + + \ No newline at end of file diff --git a/generate_ssl_key.bat b/generate_ssl_key.bat new file mode 100644 index 0000000..a40bb7e --- /dev/null +++ b/generate_ssl_key.bat @@ -0,0 +1 @@ +openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" -addext "subjectAltName = DNS:localhost, IP:127.0.0.1" \ No newline at end of file diff --git a/migrations/20250128160043_create_roles_table.sql b/migrations/20250128160043_create_roles_table.sql new file mode 100644 index 0000000..c1546e5 --- /dev/null +++ b/migrations/20250128160043_create_roles_table.sql @@ -0,0 +1,19 @@ +-- Create the roles table +CREATE TABLE IF NOT EXISTS roles ( + id SERIAL PRIMARY KEY, + level INT NOT NULL, + role VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + description VARCHAR(255), + CONSTRAINT unique_role UNIQUE (role) -- Add a unique constraint to the 'role' column +); + +-- Insert a role into the roles table (this assumes role_id=1 is for 'user') +INSERT INTO roles (level, role, name, description) +VALUES (1, 'user', 'User', 'A regular user with basic access.') +ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists + +-- Insert a role into the roles table (this assumes role_id=2 is for 'admin') +INSERT INTO roles (level, role, name, description) +VALUES (2, 'admin', 'Administrator', 'An administrator.') +ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists diff --git a/migrations/20250128160119_create_users_table.sql b/migrations/20250128160119_create_users_table.sql new file mode 100644 index 0000000..ffc6f12 --- /dev/null +++ b/migrations/20250128160119_create_users_table.sql @@ -0,0 +1,21 @@ +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + email VARCHAR(255) NOT NULL UNIQUE, + password_hash VARCHAR(255) NOT NULL, + totp_secret VARCHAR(255), + role_id INT NOT NULL DEFAULT 1 REFERENCES roles(id), -- Default role_id is set to 1 + CONSTRAINT unique_username UNIQUE (username) -- Ensure that username is unique +); + +-- Insert the example 'user' into the users table with a conflict check for username +INSERT INTO users (username, email, password_hash, role_id) +VALUES + ('user', 'user@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 1) +ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists + +-- Insert the example 'admin' into the users table with a conflict check for username +INSERT INTO users (username, email, password_hash, role_id) +VALUES + ('admin', 'admin@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 2) +ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists \ No newline at end of file diff --git a/migrations/20250128160204_create_todos_table.sql b/migrations/20250128160204_create_todos_table.sql new file mode 100644 index 0000000..4d9d06e --- /dev/null +++ b/migrations/20250128160204_create_todos_table.sql @@ -0,0 +1,6 @@ +CREATE TABLE todos ( + id SERIAL PRIMARY KEY, -- Auto-incrementing primary key + task TEXT NOT NULL, -- Task description, cannot be null + description TEXT, -- Optional detailed description + user_id INT NOT NULL REFERENCES users(id) -- Foreign key to link to users table +); \ No newline at end of file diff --git a/src/core/config.rs b/src/core/config.rs new file mode 100644 index 0000000..c7b4a76 --- /dev/null +++ b/src/core/config.rs @@ -0,0 +1,55 @@ +// Import the standard library's environment module +use std::env; + +/// Retrieves the value of an environment variable as a `String`. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// +/// # Returns +/// * The value of the environment variable if it exists. +/// * Panics if the environment variable is missing. +pub fn get_env(key: &str) -> String { + env::var(key).unwrap_or_else(|_| panic!("Missing required environment variable: {}", key)) +} + +/// Retrieves the value of an environment variable as a `String`, with a default value if not found. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found. +/// +/// # Returns +/// * The value of the environment variable if it exists. +/// * The `default` value if the environment variable is missing. +pub fn get_env_with_default(key: &str, default: &str) -> String { + env::var(key).unwrap_or_else(|_| default.to_string()) +} + +/// Retrieves the value of an environment variable as a `bool`, with a default value if not found. +/// +/// The environment variable is considered `true` if its value is "true" (case-insensitive), otherwise `false`. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found. +/// +/// # Returns +/// * `true` if the environment variable is "true" (case-insensitive). +/// * `false` otherwise, or if the variable is missing, the `default` value is returned. +pub fn get_env_bool(key: &str, default: bool) -> bool { + env::var(key).map(|v| v.to_lowercase() == "true").unwrap_or(default) +} + +/// Retrieves the value of an environment variable as a `u16`, with a default value if not found. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found or cannot be parsed. +/// +/// # Returns +/// * The parsed `u16` value of the environment variable if it exists and is valid. +/// * The `default` value if the variable is missing or invalid. +pub fn get_env_u16(key: &str, default: u16) -> u16 { + env::var(key).unwrap_or_else(|_| default.to_string()).parse().unwrap_or(default) +} diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..b8dcadd --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,4 @@ +pub mod config; +pub mod server; +pub mod tls; + diff --git a/src/core/server.rs b/src/core/server.rs new file mode 100644 index 0000000..3c9da33 --- /dev/null +++ b/src/core/server.rs @@ -0,0 +1,38 @@ +// Axum for web server and routing +use axum::Router; + +// Middleware layers from tower_http +use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression +use tower_http::trace::TraceLayer; // For HTTP request/response tracing + +// Local crate imports for database connection and configuration +use crate::database::connect::connect_to_database; // Function to connect to the database +use crate::config; // Environment configuration helper + +/// Function to create and configure the Axum server. +pub async fn create_server() -> Router { + // Establish a connection to the database + let db = connect_to_database().await.expect("Failed to connect to database."); + + // Initialize the routes for the server + let mut app = crate::routes::create_routes(db); + + // Enable tracing middleware if configured + if config::get_env_bool("SERVER_TRACE_ENABLED", true) { + app = app.layer(TraceLayer::new_for_http()); + println!("✔️ Trace hads been enabled."); + } + + // Enable compression middleware if configured + if config::get_env_bool("SERVER_COMPRESSION_ENABLED", true) { + // Parse compression level from environment or default to level 6 + let level = config::get_env("SERVER_COMPRESSION_LEVEL").parse().unwrap_or(6); + // Apply compression layer with Brotli (br) enabled and the specified compression level + app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level))); + println!("✔️ Brotli compression enabled with compression quality level {}.", level); + + } + + // Return the fully configured application + app +} diff --git a/src/core/tls.rs b/src/core/tls.rs new file mode 100644 index 0000000..a52a6d7 --- /dev/null +++ b/src/core/tls.rs @@ -0,0 +1,145 @@ +// Standard library imports +use std::{ + future::Future, + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + io::BufReader, + fs::File, + iter, +}; + +// External crate imports +use axum::serve::Listener; +use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}}; +use rustls_pemfile::{Item, read_one, certs}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing; + +// Local crate imports +use crate::config; // Import env config helper + +// Function to load TLS configuration from files +pub fn load_tls_config() -> ServerConfig { + // Get certificate and key file paths from the environment + let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH"); + let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH"); + + // Open the certificate and key files + let cert_file = File::open(cert_path).expect("❌ Failed to open certificate file."); + let key_file = File::open(key_path).expect("❌ Failed to open private key file."); + + // Read the certificate chain and private key from the files + let mut cert_reader = BufReader::new(cert_file); + let mut key_reader = BufReader::new(key_file); + + // Read and parse the certificate chain + let cert_chain: Vec = certs(&mut cert_reader) + .map(|cert| cert.expect("❌ Failed to read certificate.")) + .map(CertificateDer::from) + .collect(); + + // Ensure certificates are found + if cert_chain.is_empty() { + panic!("❌ No valid certificates found."); + } + + // Read the private key from the file + let key = iter::from_fn(|| read_one(&mut key_reader).transpose()) + .find_map(|item| match item.unwrap() { + Item::Pkcs1Key(key) => Some(PrivateKeyDer::from(key)), + Item::Pkcs8Key(key) => Some(PrivateKeyDer::from(key)), + Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)), + _ => None, + }) + .expect("Failed to read a valid private key."); + + // Build and return the TLS server configuration + ServerConfig::builder() + .with_no_client_auth() // No client authentication + .with_single_cert(cert_chain, key) // Use the provided cert and key + .expect("Failed to create TLS configuration") +} + +// Custom listener that implements axum::serve::Listener +#[derive(Clone)] +pub struct TlsListener { + pub inner: Arc, // Inner TCP listener + pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor for handling TLS handshakes +} + +impl Listener for TlsListener { + type Io = TlsStreamWrapper; // Type of I/O stream + type Addr = SocketAddr; // Type of address (Socket address) + + // Method to accept incoming connections and establish a TLS handshake + fn accept(&mut self) -> impl Future + Send { + let acceptor = self.acceptor.clone(); // Clone the acceptor for async use + + async move { + loop { + // Accept a TCP connection + let (stream, addr) = match self.inner.accept().await { + Ok((stream, addr)) => (stream, addr), + Err(e) => { + tracing::error!("❌ Error accepting TCP connection: {}", e); + continue; // Retry on error + } + }; + + // Perform TLS handshake + match acceptor.accept(stream).await { + Ok(tls_stream) => { + tracing::info!("Successful TLS handshake with {}", addr); + return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address + }, + Err(e) => { + tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e); + continue; // Retry on error + } + } + } + } + } + + // Method to retrieve the local address of the listener + fn local_addr(&self) -> std::io::Result { + self.inner.local_addr() + } +} + +// Wrapper for a TLS stream, implementing AsyncRead and AsyncWrite +#[derive(Debug)] +pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream); + +impl AsyncRead for TlsStreamWrapper { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream + } +} + +impl AsyncWrite for TlsStreamWrapper { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream + } +} + +// Allow the TLS stream wrapper to be used in non-blocking contexts (needed for async operations) +impl Unpin for TlsStreamWrapper {} diff --git a/src/database/connect.rs b/src/database/connect.rs new file mode 100644 index 0000000..484d333 --- /dev/null +++ b/src/database/connect.rs @@ -0,0 +1,51 @@ +use dotenvy::dotenv; +use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions}; +use std::fs; +use std::env; +use std::path::Path; + +/// Connects to the database using the DATABASE_URL environment variable. +pub async fn connect_to_database() -> Result { + dotenv().ok(); + let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount."); + + // Read max and min connection values from environment variables, with defaults + let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS") + .unwrap_or_else(|_| "10".to_string()) // Default to 10 + .parse() + .expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number."); + + let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS") + .unwrap_or_else(|_| "2".to_string()) // Default to 2 + .parse() + .expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number."); + + // Create and configure the connection pool + let pool = PgPoolOptions::new() + .max_connections(max_connections) + .min_connections(min_connections) + .connect(&database_url) + .await?; + + Ok(pool) +} + +/// Run database migrations +pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> { + // Define the path to the migrations folder + let migrations_path = Path::new("./migrations"); + + // Check if the migrations folder exists, and if not, create it + if !migrations_path.exists() { + fs::create_dir_all(migrations_path).expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions."); + println!("✔️ Created migrations directory: {:?}", migrations_path); + } + + // Create a migrator instance that looks for migrations in the `./migrations` folder + let migrator = Migrator::new(migrations_path).await?; + + // Run all pending migrations + migrator.run(pool).await?; + + Ok(()) +} \ No newline at end of file diff --git a/src/database/get_users.rs b/src/database/get_users.rs new file mode 100644 index 0000000..7493cdd --- /dev/null +++ b/src/database/get_users.rs @@ -0,0 +1,25 @@ +use sqlx::postgres::PgPool; +use crate::models::user::*; // Import the User struct + +// Get all users +pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result { + // Use a string literal directly in the macro + let user = sqlx::query_as!( + User, // Struct type to map the query result + r#" + SELECT id, username, email, password_hash, totp_secret, role_id + FROM users + WHERE email = $1 + "#, + email // Bind the `email` parameter + ) + .fetch_optional(pool) + .await + .map_err(|e| format!("Database error: {}", e))?; // Handle database errors + + // Handle optional result + match user { + Some(user) => Ok(user), + None => Err(format!("User with email '{}' not found.", email)), + } +} \ No newline at end of file diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..4066dcc --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,3 @@ +// Module declarations +pub mod connect; +pub mod get_users; diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs new file mode 100644 index 0000000..67bad8d --- /dev/null +++ b/src/handlers/mod.rs @@ -0,0 +1,6 @@ +// Module declarations +// pub mod auth; + + +// Re-exporting modules +// pub use auth::*; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..36dde6a --- /dev/null +++ b/src/main.rs @@ -0,0 +1,91 @@ +#[allow(dead_code)] + +// Core modules for the configuration, TLS setup, and server creation +mod core; +use core::{config, tls, server}; +use core::tls::TlsListener; + +// Other modules for database, routes, models, and middlewares +mod database; +mod routes; +mod models; +mod middlewares; + +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use axum::serve; + +#[tokio::main] +async fn main() { + dotenvy::dotenv().ok(); // Load environment variables from a .env file + tracing_subscriber::fmt::init(); // Initialize the logging system + + // Print a cool startup message with ASCII art and emojis + println!("{}", r#" + ##### ## ## + ## ## ## + ## ## ## ## ##### ####### ##### #### #### + ##### ## ## ## ## ## ## ## ## ## + #### ## ## ## ## ## ## ## ## ## +## ## ## ### ## ## ## ### ## ## ## +## ## ### ## ##### ### ### ## ##### ###### + ## + + Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL + GitHub: https://github.com/Riktastic/rustapi + "#); + + println!("🚀 Starting Rustapi..."); + + // Retrieve server IP and port from the environment, default to 0.0.0.0:3000 + let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0") + .parse() + .expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1."); + let port: u16 = config::get_env_u16("SERVER_PORT", 3000); + let socket_addr = SocketAddr::new(ip, port); + + // Create the Axum app instance using the server configuration + let app = server::create_server().await; + + // Check if HTTPS is enabled in the environment configuration + if config::get_env_bool("SERVER_HTTPS_ENABLED", false) { + // If HTTPS is enabled, start the server with secure HTTPS. + + // Bind TCP listener for incoming connections + let tcp_listener = TcpListener::bind(socket_addr) + .await + .expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling + + // Load the TLS configuration for secure HTTPS connections + let tls_config = tls::load_tls_config(); + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); // Create a TLS acceptor + let listener = TlsListener { + inner: Arc::new(tcp_listener), // Wrap TCP listener in TlsListener + acceptor: acceptor, + }; + + println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port); + + // Serve the app using the TLS listener (HTTPS) + serve(listener, app.into_make_service()) + .await + .expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?"); + + } else { + // If HTTPS is not enabled, start the server with non-secure HTTP. + + // Bind TCP listener for non-secure HTTP connections + let listener = TcpListener::bind(socket_addr) + .await + .expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling + + println!("🔓 Server started with HTTP at: http://{}:{}", ip, port); + + // Serve the app using the non-secure TCP listener (HTTP) + serve(listener, app.into_make_service()) + .await + .expect("❌ Server failed to start without HTTPS."); + } +} diff --git a/src/middlewares/auth.rs b/src/middlewares/auth.rs new file mode 100644 index 0000000..c28e868 --- /dev/null +++ b/src/middlewares/auth.rs @@ -0,0 +1,211 @@ +// Standard library imports for working with HTTP, environment variables, and other necessary utilities +use axum::{ + body::Body, + response::IntoResponse, + extract::{Request, Json}, // Extractor for request and JSON body + http::{self, Response, StatusCode}, // HTTP response and status codes + middleware::Next, // For adding middleware layers to the request handling pipeline +}; + +// Importing `State` for sharing application state (such as a database connection) across request handlers +use axum::extract::State; + +// Importing necessary libraries for password hashing, JWT handling, and date/time management +use std::env; // For accessing environment variables +use argon2::{ + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification + Argon2, +}; + +use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.) +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens +use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data +use serde_json::json; // For constructing JSON data +use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously + +// Importing custom database query functions +use crate::database::get_users::get_user_by_email; + +// Define the structure for JWT claims to be included in the token payload +#[derive(Serialize, Deserialize)] +pub struct Claims { + pub exp: usize, // Expiration timestamp (in seconds) + pub iat: usize, // Issued-at timestamp (in seconds) + pub email: String, // User's email +} + +// Custom error type for handling authentication errors +pub struct AuthError { + message: String, + status_code: StatusCode, // HTTP status code to be returned with the error +} + +// Function to verify a password against a stored hash using the Argon2 algorithm +pub fn verify_password(password: &str, hash: &str) -> Result { + let parsed_hash = PasswordHash::new(hash)?; // Parse the hash + // Verify the password using Argon2 + Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) +} + +// Function to hash a password using Argon2 and a salt retrieved from the environment variables +pub fn hash_password(password: &str) -> Result { + // Get the salt from environment variables (must be set) + let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set"); + let salt = SaltString::from_b64(&salt).unwrap(); // Convert base64 string to SaltString + let argon2 = Argon2::default(); // Create an Argon2 instance + // Hash the password with the salt + let password_hash = argon2.hash_password(password.as_bytes(), &salt)?.to_string(); + Ok(password_hash) +} + +// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + let body = Json(json!( { "error": self.message } )); // Create a JSON response body with the error message + + // Return a response with the appropriate status code and error message + (self.status_code, body).into_response() + } +} + +// Function to encode a JWT token for the given email address +pub fn encode_jwt(email: String) -> Result { + let jwt_token: String = "randomstring".to_string(); // Secret key for JWT (should be more secure in production) + + let now = Utc::now(); // Get current time + let expire = Duration::hours(24); // Set token expiration to 24 hours + let exp: usize = (now + expire).timestamp() as usize; // Expiration timestamp + let iat: usize = now.timestamp() as usize; // Issued-at timestamp + + let claim = Claims { iat, exp, email }; // Create JWT claims with timestamps and user email + let secret = jwt_token.clone(); // Secret key to sign the token + + // Encode the claims into a JWT token + encode( + &Header::default(), + &claim, + &EncodingKey::from_secret(secret.as_ref()), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails +} + +// Function to decode a JWT token and extract the claims +pub fn decode_jwt(jwt: String) -> Result, StatusCode> { + let secret = "randomstring".to_string(); // Secret key to verify the JWT (should be more secure in production) + + // Decode the JWT token using the secret key and extract the claims + decode( + &jwt, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::default(), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if decoding fails +} + +// Middleware for role-based access control (RBAC) +// Ensures that only users with specific roles are authorized to access certain resources +pub async fn authorize( + mut req: Request, + next: Next, + allowed_roles: Vec, // Accept a vector of allowed roles +) -> Result, AuthError> { + // Retrieve the database pool from request extensions (shared application state) + let pool = req.extensions().get::().expect("Database pool not found in request extensions"); + + // Retrieve the Authorization header from the request + let auth_header = req.headers().get(http::header::AUTHORIZATION); + + // Ensure the header exists and is correctly formatted + let auth_header = match auth_header { + Some(header) => header.to_str().map_err(|_| AuthError { + message: "Invalid header format".to_string(), + status_code: StatusCode::FORBIDDEN, + })?, + None => return Err(AuthError { + message: "Authorization header missing.".to_string(), + status_code: StatusCode::FORBIDDEN, + }), + }; + + // Extract the token from the Authorization header (Bearer token format) + let mut header = auth_header.split_whitespace(); + let (_, token) = (header.next(), header.next()); + + // Decode the JWT token + let token_data = match decode_jwt(token.unwrap().to_string()) { + Ok(data) => data, + Err(_) => return Err(AuthError { + message: "Unable to decode token.".to_string(), + status_code: StatusCode::UNAUTHORIZED, + }), + }; + + // Fetch the user from the database using the email from the decoded token + let current_user = match get_user_by_email(&pool, token_data.claims.email).await { + Ok(user) => user, + Err(_) => return Err(AuthError { + message: "Unauthorized user.".to_string(), + status_code: StatusCode::UNAUTHORIZED, + }), + }; + + // Check if the user's role is in the list of allowed roles + if !allowed_roles.contains(¤t_user.role_id) { + return Err(AuthError { + message: "Forbidden: insufficient role.".to_string(), + status_code: StatusCode::FORBIDDEN, + }); + } + + // Insert the current user into the request extensions for use in subsequent handlers + req.extensions_mut().insert(current_user); + + // Proceed to the next middleware or handler + Ok(next.run(req).await) +} + +// Structure to hold the data from the sign-in request +#[derive(Deserialize)] +pub struct SignInData { + pub email: String, + pub password: String, +} + +// Handler for user sign-in (authentication) +pub async fn sign_in( + State(pool): State, // Database connection pool injected as state + Json(user_data): Json, // Deserialize the JSON body into SignInData +) -> Result, (StatusCode, Json)> { + + // 1. Retrieve user from the database using the provided email + let user = match get_user_by_email(&pool, user_data.email).await { + Ok(user) => user, + Err(_) => return Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ "error": "Incorrect credentials." })) + )), + }; + + // 2. Verify the password using the stored hash + if !verify_password(&user_data.password, &user.password_hash) + .map_err(|_| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Internal server error." })) + ))? + { + return Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ "error": "Incorrect credentials." })) + )); + } + + // 3. Generate a JWT token for the authenticated user + let token = encode_jwt(user.email) + .map_err(|_| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Internal server error." })) + ))?; + + // 4. Return the JWT token to the client + Ok(Json(json!({ "token": token }))) +} diff --git a/src/middlewares/mod.rs b/src/middlewares/mod.rs new file mode 100644 index 0000000..ff8efe8 --- /dev/null +++ b/src/middlewares/mod.rs @@ -0,0 +1,2 @@ +// Module declarations +pub mod auth; \ No newline at end of file diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..04b23c0 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,5 @@ +/// Module for to-do related models. +pub mod todo; +/// Module for user related models. +pub mod user; + diff --git a/src/models/role.rs b/src/models/role.rs new file mode 100644 index 0000000..c206f8e --- /dev/null +++ b/src/models/role.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a user role in the system. +#[derive(Deserialize, Debug, Serialize, FromRow, Clone)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct Role { + /// ID of the role. + pub id: i32, + /// Level of the role. + pub level: i32, + /// System name of the role. + pub role: String, + /// The name of the role. + pub name: String, + /// Description of the role + pub Description: String, +} \ No newline at end of file diff --git a/src/models/todo.rs b/src/models/todo.rs new file mode 100644 index 0000000..60bbabc --- /dev/null +++ b/src/models/todo.rs @@ -0,0 +1,16 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a to-do item. +#[derive(Deserialize, Debug, Serialize, FromRow)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct Todo { + /// The unique identifier for the to-do item. + pub id: i32, + /// The task description. + pub task: String, + /// An optional detailed description of the task. + pub description: Option, + /// The unique identifier of the user who created the to-do item. + pub user_id: i32, +} \ No newline at end of file diff --git a/src/models/user.rs b/src/models/user.rs new file mode 100644 index 0000000..2721499 --- /dev/null +++ b/src/models/user.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a user in the system. +#[derive(Deserialize, Debug, Serialize, FromRow, Clone)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct User { + /// The unique identifier for the user. + pub id: i32, + /// The username of the user. + pub username: String, + /// The email of the user. + pub email: String, + /// The hashed password for the user. + pub password_hash: String, + /// The TOTP secret for the user. + pub totp_secret: Option, + /// Current role of the user.. + pub role_id: i32, +} \ No newline at end of file diff --git a/src/routes/get_health.rs b/src/routes/get_health.rs new file mode 100644 index 0000000..a519641 --- /dev/null +++ b/src/routes/get_health.rs @@ -0,0 +1,200 @@ +use axum::{response::IntoResponse, Json, extract::State}; +use serde_json::json; +use sqlx::PgPool; +use sysinfo::{System, RefreshKind, Disks}; +use tokio::{task, join}; +use std::sync::{Arc, Mutex}; + +pub async fn get_health(State(database_connection): State) -> impl IntoResponse { + // Use Arc and Mutex to allow sharing System between tasks + let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything()))); + + // Run checks in parallel + let (cpu_result, mem_result, disk_result, process_result, db_result, net_result) = join!( + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_cpu_usage(&mut system) // Pass the mutable reference + } + }), + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_memory(&mut system) // Pass the mutable reference + } + }), + task::spawn_blocking({ + move || { + check_disk_usage() // Does not need a system reference. + } + }), + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_processes(&mut system, &["postgres", "Code"]) // Pass the mutable reference + } + }), + check_database_connection(&database_connection), // Async function + task::spawn_blocking(check_network_connection) // Blocking, okay in spawn_blocking + ); + + let mut status = "healthy"; + let mut details = json!({}); + + // Process CPU result + if let Ok(Ok(cpu_details)) = cpu_result { + details["cpu_usage"] = json!(cpu_details); + if cpu_details["status"] == "low" { + status = "degraded"; + } + } else { + details["cpu_usage"] = json!({ "status": "error", "error": "Failed to retrieve CPU usage" }); + status = "degraded"; + } + + // Process Memory result + if let Ok(Ok(mem_details)) = mem_result { + details["memory"] = json!(mem_details); + if mem_details["status"] == "low" { + status = "degraded"; + } + } else { + details["memory"] = json!({ "status": "error", "error": "Failed to retrieve memory information" }); + status = "degraded"; + } + + // Process Disk result + if let Ok(Ok(disk_details)) = disk_result { + details["disk_usage"] = json!(disk_details); + if disk_details["status"] == "critical" { + status = "degraded"; + } + } else { + details["disk_usage"] = json!({ "status": "error", "error": "Failed to retrieve disk usage" }); + status = "degraded"; + } + + // Process Process result + if let Ok(Ok(process_details)) = process_result { + details["important_processes"] = json!(process_details); + if process_details.iter().any(|p| p["status"] == "not running") { + status = "degraded"; + } + } else { + details["important_processes"] = json!({ "status": "error", "error": "Failed to retrieve process information" }); + status = "degraded"; + } + + // Process Database result + if let Ok(db_status) = db_result { + details["database"] = json!({ "status": if db_status { "ok" } else { "degraded" } }); + if !db_status { + status = "degraded"; + } + } else { + details["database"] = json!({ "status": "error", "error": "Failed to retrieve database status" }); + status = "degraded"; + } + + // Process Network result + if let Ok(Ok(net_status)) = net_result { + details["network"] = json!({ "status": if net_status { "ok" } else { "degraded" } }); + if !net_status { + status = "degraded"; + } + } else { + details["network"] = json!({ "status": "error", "error": "Failed to retrieve network status" }); + status = "degraded"; + } + + Json(json!({ + "status": status, + "details": details, + })) +} + +// Helper functions + +fn check_cpu_usage(system: &mut System) -> Result { + system.refresh_cpu_usage(); + let usage = system.global_cpu_usage(); + let available = 100.0 - usage; + Ok(json!( { + "usage_percentage": format!("{:.2}", usage), + "available_percentage": format!("{:.2}", available), + "status": if available < 10.0 { "low" } else { "normal" }, + })) +} + +fn check_memory(system: &mut System) -> Result { + system.refresh_memory(); + let available = system.available_memory() / 1024 / 1024; // Convert to MB + Ok(json!( { + "available_mb": available, + "status": if available < 512 { "low" } else { "normal" }, + })) +} + +fn check_disk_usage() -> Result { + // Create a new Disks object and refresh the disk information + let mut disks = Disks::new(); + disks.refresh(false); // Refresh disk information without performing a full refresh + + // Iterate through the list of disks and check the usage for each one + let usage: Vec<_> = disks.list().iter().map(|disk| { + let total = disk.total_space() as f64; + let available = disk.available_space() as f64; + let used_percentage = ((total - available) / total) * 100.0; + used_percentage + }).collect(); + + // Get the maximum usage percentage + let max_usage = usage.into_iter() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(0.0); + + // Return the result as a JSON object + Ok(json!( { + "used_percentage": format!("{:.2}", max_usage), + "status": if max_usage > 90.0 { "critical" } else { "ok" }, + })) +} + +fn check_processes(system: &mut System, processes: &[&str]) -> Result, ()> { + system.refresh_processes(sysinfo::ProcessesToUpdate::All, true); + + let process_statuses: Vec<_> = processes.iter().map(|&name| { + // Adjust process names based on the platform and check if they are running + let adjusted_name = if cfg!(target_os = "windows") { + match name { + "postgres" => "postgres.exe", // Postgres on Windows + "Code" => "Code.exe", // Visual Studio Code on Windows + _ => name, // For other platforms, use the name as is + } + } else { + name // For non-Windows platforms, use the name as is + }; + + // Check if the translated (adjusted) process is running + let is_running = system.processes().iter().any(|(_, proc)| proc.name() == adjusted_name); + + // Return a JSON object for each process with its status + json!({ + "name": name, + "status": if is_running { "running" } else { "not running" } + }) + }).collect(); + + Ok(process_statuses) +} + +async fn check_database_connection(pool: &PgPool) -> Result { + sqlx::query("SELECT 1").fetch_one(pool).await.map(|_| true).or_else(|_| Ok(false)) +} + +fn check_network_connection() -> Result { + Ok(std::net::TcpStream::connect("8.8.8.8:53").is_ok()) +} diff --git a/src/routes/get_todos.rs b/src/routes/get_todos.rs new file mode 100644 index 0000000..abecfb7 --- /dev/null +++ b/src/routes/get_todos.rs @@ -0,0 +1,43 @@ +use axum::extract::{State, Path}; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::todo::*; + +// Get all todos +pub async fn get_all_todos(State(pool): State,) -> impl IntoResponse { + let todos = sqlx::query_as!(Todo, "SELECT * FROM todos") // Your table name + .fetch_all(&pool) // Borrow the connection pool + .await; + + match todos { + Ok(todos) => Ok(Json(todos)), // Return all todos as JSON + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching todos: {}", err), + )), + } +} + + +// Get a single todo by id +pub async fn get_todos_by_id( + State(pool): State, + Path(id): Path, // Use Path extractor here +) -> impl IntoResponse { + let todo = sqlx::query_as!(Todo, "SELECT * FROM todos WHERE id = $1", id) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match todo { + Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("Todo with id {} not found", id), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching todo: {}", err), + )), + } +} \ No newline at end of file diff --git a/src/routes/get_users.rs b/src/routes/get_users.rs new file mode 100644 index 0000000..85f071b --- /dev/null +++ b/src/routes/get_users.rs @@ -0,0 +1,87 @@ +use axum::extract::{State, Path}; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::user::*; // Import the User struct + +// Get all users +pub async fn get_all_users(State(pool): State,) -> impl IntoResponse { + let users = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users") // Your table name + .fetch_all(&pool) // Borrow the connection pool + .await; + + match users { + Ok(users) => Ok(Json(users)), // Return all users as JSON + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching users: {}", err), + )), + } +} + +// Get a single user by id +pub async fn get_users_by_id( + State(pool): State, + Path(id): Path, // Use Path extractor here + ) -> impl IntoResponse { + + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE id = $1", id) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with id {} not found", id), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} + +// Get a single user by username +pub async fn get_user_by_username( + State(pool): State, + Path(username): Path, // Use Path extractor here for username +) -> impl IntoResponse { + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE username = $1", username) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with username {} not found", username), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} + +// Get a single user by email +pub async fn get_user_by_email( + State(pool): State, + Path(email): Path, // Use Path extractor here for email +) -> impl IntoResponse { + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE email = $1", email) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with email {} not found", email), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..92c4c9a --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,59 @@ +// Module declarations for different route handlers +pub mod get_todos; +pub mod get_users; +pub mod post_todos; +pub mod post_users; +pub mod get_health; +pub mod protected; + +// Re-exporting modules to make their contents available at this level +pub use get_todos::*; +pub use get_users::*; +pub use post_todos::*; +pub use post_users::*; +pub use get_health::*; +pub use protected::*; + +use axum::{ + Router, + routing::{get, post}, +}; + +use sqlx::PgPool; + +use crate::middlewares::auth::{sign_in, authorize}; + +/// Function to create and configure all routes +pub fn create_routes(database_connection: PgPool) -> Router { + // Authentication routes + let auth_routes = Router::new() + .route("/signin", post(sign_in)) + .route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| { + let allowed_roles = vec![1, 2]; + authorize(req, next, allowed_roles) + }))); + + // User-related routes + let user_routes = Router::new() + .route("/all", get(get_all_users)) + .route("/{id}", get(get_users_by_id)) + .route("/", post(post_user)); + + // Todo-related routes + let todo_routes = Router::new() + .route("/all", get(get_all_todos)) + .route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| { + let allowed_roles = vec![1, 2]; + authorize(req, next, allowed_roles) + }))) + .route("/{id}", get(get_todos_by_id)); + + // Combine all routes and add middleware + Router::new() + .merge(auth_routes) // Add authentication routes + .nest("/users", user_routes) // Add user routes under /users + .nest("/todos", todo_routes) // Add todo routes under /todos + .route("/health", get(get_health)) // Add health check route + .layer(axum::Extension(database_connection.clone())) // Add database connection to all routes + .with_state(database_connection) // Add database connection as state +} diff --git a/src/routes/post_todos.rs b/src/routes/post_todos.rs new file mode 100644 index 0000000..d2e6c37 --- /dev/null +++ b/src/routes/post_todos.rs @@ -0,0 +1,53 @@ +use axum::{extract::{State, Extension}, Json}; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::todo::*; +use crate::models::user::*; +use serde::Deserialize; +use axum::http::StatusCode; +use serde_json::json; + +#[derive(Deserialize)] +pub struct TodoBody { + pub task: String, + pub description: Option, + pub user_id: i32, +} + +// Add a new todo +pub async fn post_todo( + State(pool): State, + Extension(user): Extension, // Extract current user from the request extensions + Json(todo): Json +) -> impl IntoResponse { + // Ensure the user_id from the request matches the current user's id + if todo.user_id != user.id { + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": "User is not authorized to create a todo for another user" })) + )); + } + + // Insert the todo into the database + let row = sqlx::query!( + "INSERT INTO todos (task, description, user_id) VALUES ($1, $2, $3) RETURNING id, task, description, user_id", + todo.task, + todo.description, + todo.user_id + ) + .fetch_one(&pool) + .await; + + match row { + Ok(row) => Ok(Json(Todo { + id: row.id, + task: row.task, + description: row.description, + user_id: row.user_id, + })), + Err(err) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Error: {}", err) })) + )), + } +} \ No newline at end of file diff --git a/src/routes/post_users.rs b/src/routes/post_users.rs new file mode 100644 index 0000000..51c4978 --- /dev/null +++ b/src/routes/post_users.rs @@ -0,0 +1,44 @@ +use axum::extract::State; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::user::*; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct UserBody { + pub username: String, + pub email: String, + pub password_hash: String, + pub totp_secret: String, + pub role_id: i32, +} + +// Add a new user +pub async fn post_user(State(pool): State, Json(user): Json, ) -> impl IntoResponse { + let row = sqlx::query!( + "INSERT INTO users (username, email, password_hash, totp_secret, role_id) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, totp_secret, role_id", + user.username, + user.email, + user.password_hash, + user.totp_secret, + user.role_id + ) + .fetch_one(&pool) // Use `&pool` to borrow the connection pool + .await; + + match row { + Ok(row) => Ok(Json(User { + id: row.id, + username: row.username, + email: row.email, + password_hash: row.password_hash, + totp_secret: row.totp_secret, + role_id: row.role_id, + })), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error: {}", err), + )), + } +} \ No newline at end of file diff --git a/src/routes/protected.rs b/src/routes/protected.rs new file mode 100644 index 0000000..277521e --- /dev/null +++ b/src/routes/protected.rs @@ -0,0 +1,18 @@ +use axum::{Extension, Json, response::IntoResponse}; +use serde::{Serialize, Deserialize}; +use crate::models::user::User; + +#[derive(Serialize, Deserialize)] +struct UserResponse { + id: i32, + username: String, + email: String +} + +pub async fn protected(Extension(user): Extension) -> impl IntoResponse { + Json(UserResponse { + id: user.id, + username: user.username, + email: user.email + }) +} \ No newline at end of file