From e20f21bc8b35cf6c3411565548a95c45ae4807e3 Mon Sep 17 00:00:00 2001 From: Rik Heijmann Date: Thu, 30 Jan 2025 22:43:30 +0100 Subject: [PATCH] first commit --- .dockerignore | 35 +++ .env.example | 68 ++++++ .gitignore | 5 + Cargo.toml | 45 ++++ Dockerfile | 73 ++++++ README.md | 168 ++++++++++++++ compose.yaml | 32 +++ generate_ssl_key.bat | 1 + .../20250128160043_create_roles_table.sql | 19 ++ .../20250128160119_create_users_table.sql | 21 ++ .../20250128160204_create_todos_table.sql | 6 + src/core/config.rs | 55 +++++ src/core/mod.rs | 4 + src/core/server.rs | 38 ++++ src/core/tls.rs | 145 ++++++++++++ src/database/connect.rs | 51 +++++ src/database/get_users.rs | 25 +++ src/database/mod.rs | 3 + src/handlers/mod.rs | 6 + src/main.rs | 91 ++++++++ src/middlewares/auth.rs | 211 ++++++++++++++++++ src/middlewares/mod.rs | 2 + src/models/mod.rs | 5 + src/models/role.rs | 18 ++ src/models/todo.rs | 16 ++ src/models/user.rs | 20 ++ src/routes/get_health.rs | 200 +++++++++++++++++ src/routes/get_todos.rs | 43 ++++ src/routes/get_users.rs | 87 ++++++++ src/routes/mod.rs | 59 +++++ src/routes/post_todos.rs | 53 +++++ src/routes/post_users.rs | 44 ++++ src/routes/protected.rs | 18 ++ 33 files changed, 1667 insertions(+) create mode 100644 .dockerignore create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 compose.yaml create mode 100644 generate_ssl_key.bat create mode 100644 migrations/20250128160043_create_roles_table.sql create mode 100644 migrations/20250128160119_create_users_table.sql create mode 100644 migrations/20250128160204_create_todos_table.sql create mode 100644 src/core/config.rs create mode 100644 src/core/mod.rs create mode 100644 src/core/server.rs create mode 100644 src/core/tls.rs create mode 100644 src/database/connect.rs create mode 100644 src/database/get_users.rs create mode 100644 src/database/mod.rs create mode 100644 src/handlers/mod.rs create mode 100644 src/main.rs create mode 100644 src/middlewares/auth.rs create mode 100644 src/middlewares/mod.rs create mode 100644 src/models/mod.rs create mode 100644 src/models/role.rs create mode 100644 src/models/todo.rs create mode 100644 src/models/user.rs create mode 100644 src/routes/get_health.rs create mode 100644 src/routes/get_todos.rs create mode 100644 src/routes/get_users.rs create mode 100644 src/routes/mod.rs create mode 100644 src/routes/post_todos.rs create mode 100644 src/routes/post_users.rs create mode 100644 src/routes/protected.rs diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..e181198 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,35 @@ + + # Include any files or directories that you don't want to be copied to your + # container here (e.g., local build artifacts, temporary files, etc.). + # + # For more help, visit the .dockerignore file reference guide at + # https://docs.docker.com/engine/reference/builder/#dockerignore-file + + **/.DS_Store + **/.classpath + **/.dockerignore + **/.env + **/.git + **/.gitignore + **/.project + **/.settings + **/.toolstarget + **/.vs + **/.vscode + **/*.*proj.user + **/*.dbmdl + **/*.jfm + **/charts + **/docker-compose* + **/compose* + **/Dockerfile* + **/node_modules + **/npm-debug.log + **/secrets.dev.yaml + **/values.dev.yaml + /bin + /target + LICENSE + README.md + + \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e315459 --- /dev/null +++ b/.env.example @@ -0,0 +1,68 @@ +# ============================== +# 📌 DATABASE CONFIGURATION +# ============================== + +# PostgreSQL connection URL (format: postgres://user:password@host/database) +DATABASE_URL="postgres://postgres:1234@localhost/database_name" + +# Maximum number of connections in the database pool +DATABASE_MAX_CONNECTIONS=20 + +# Minimum number of connections in the database pool +DATABASE_MIN_CONNECTIONS=5 + +# ============================== +# 🌍 SERVER CONFIGURATION +# ============================== + +# IP address the server will bind to (0.0.0.0 allows all network interfaces) +SERVER_IP="0.0.0.0" + +# Port the server will listen on +SERVER_PORT="3000" + +# Enable tracing for debugging/logging (true/false) +SERVER_TRACE_ENABLED=true + +# ============================== +# 🔒 HTTPS CONFIGURATION +# ============================== + +# Enable HTTPS (true/false) +SERVER_HTTPS_ENABLED=false + +# Enable HTTP/2 when using HTTPS (true/false) +SERVER_HTTPS_HTTP2_ENABLED=true + +# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_CERT_FILE_PATH=cert.pem + +# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_KEY_FILE_PATH=key.pem + +# ============================== +# 🚦 RATE LIMIT CONFIGURATION +# ============================== + +# Maximum number of requests allowed per period +SERVER_RATE_LIMIT=5 + +# Time period (in seconds) for rate limiting +SERVER_RATE_LIMIT_PERIOD=1 + +# ============================== +# 📦 COMPRESSION CONFIGURATION +# ============================== + +# Enable Brotli compression (true/false) +SERVER_COMPRESSION_ENABLED=true + +# Compression level (valid range: 0-11, where 11 is the highest compression) +SERVER_COMPRESSION_LEVEL=6 + +# ============================== +# 🔑 AUTHENTICATION CONFIGURATION +# ============================== + +# Argon2 salt for password hashing (must be kept secret!) +AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi" \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..800b6dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.env +/target +cert.pem +key.pem +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..86bc5fd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "rustapi" +version = "0.1.0" +edition = "2021" + +[dependencies] +# Web framework and server +axum = { version = "0.8.1", features = ["json"] } +# hyper = { version = "1.5.2", features = ["full"] } + +# Database interaction +sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate"] } + +# Serialization and deserialization +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.137" + +# Authentication and security +jsonwebtoken = "9.3.0" +argon2 = "0.5.3" + +# Asynchronous runtime and traits +tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] } + +# Configuration and environment +dotenvy = "0.15.7" + +# Middleware and server utilities +tower = { version = "0.5.2", features = ["limit"] } +tower-http = { version = "0.6.2", features = ["trace", "cors", "compression-br"] } + +# Logging and monitoring +tracing = "0.1.41" +tracing-subscriber = "0.3.19" + +# System information +sysinfo = "0.33.1" + +# Date and time handling +chrono = "0.4.39" + +# SSL / TLS +rustls = "0.23.21" +tokio-rustls = "0.26.1" +rustls-pemfile = "2.2.0" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..d7266d5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,73 @@ + + # syntax=docker/dockerfile:1 + + # Comments are provided throughout this file to help you get started. + # If you need more help, visit the Dockerfile reference guide at + # https://docs.docker.com/engine/reference/builder/ + + ################################################################################ + # Create a stage for building the application. + + ARG RUST_VERSION=1.78.0 + ARG APP_NAME=backend + FROM rust:${RUST_VERSION}-slim-bullseye AS build + ARG APP_NAME + WORKDIR /app + + # Build the application. + # Leverage a cache mount to /usr/local/cargo/registry/ + # for downloaded dependencies and a cache mount to /app/target/ for + # compiled dependencies which will speed up subsequent builds. + # Leverage a bind mount to the src directory to avoid having to copy the + # source code into the container. Once built, copy the executable to an + # output directory before the cache mounted /app/target is unmounted. + RUN --mount=type=bind,source=src,target=src \ + # --mount=type=bind,source=configuration.yaml,target=configuration.yaml \ + --mount=type=bind,source=Cargo.toml,target=Cargo.toml \ + --mount=type=bind,source=Cargo.lock,target=Cargo.lock \ + --mount=type=cache,target=/app/target/ \ + --mount=type=cache,target=/usr/local/cargo/registry/ \ + <, // Injected user + Json(todo): Json + ) -> impl IntoResponse { + if todo.user_id != user.id { + return Err((StatusCode::FORBIDDEN, Json(json!({ + "error": "Cannot create todos for others" + })))); + } + ``` +- **Modern protocols ** - HTTP/2 with secure TLS defaults +- **Observability** - Integrated tracing +- **Optimized for performance** - Brotli compression +- **Easy configuration** - `.env` and environment variables +- **Documented codebase** - Extensive inline comments for easy modification and readability +- **Latest dependencies** - Regularly updated Rust ecosystem crates + +## 📦 Installation & Usage +```bash +# Clone and setup +git clone https://github.com/Riktastic/rustapi.git +cd rustapi && cp .env.example .env + +# Database setup +sqlx database create && sqlx migrate run + +# Start server +cargo run --release +``` + +### 🔐 Default Accounts + +**Warning:** These accounts should only be used for initial testing. Always change or disable them in production environments. + +| Email | Password | Role | +|---------------------|----------|----------------| +| `user@test.com` | `test` | User | +| `admin@test.com` | `test` | Administrator | + +⚠️ **Security Recommendations:** +1. Rotate passwords immediately after initial setup +2. Disable default accounts before deploying to production +3. Implement proper user management endpoints + + +## ⚙️ Configuration +```env +# ============================== +# 📌 DATABASE CONFIGURATION +# ============================== + +# PostgreSQL connection URL (format: postgres://user:password@host/database) +DATABASE_URL="postgres://postgres:1234@localhost/database_name" + +# Maximum number of connections in the database pool +DATABASE_MAX_CONNECTIONS=20 + +# Minimum number of connections in the database pool +DATABASE_MIN_CONNECTIONS=5 + +# ============================== +# 🌍 SERVER CONFIGURATION +# ============================== + +# IP address the server will bind to (0.0.0.0 allows all network interfaces) +SERVER_IP="0.0.0.0" + +# Port the server will listen on +SERVER_PORT="3000" + +# Enable tracing for debugging/logging (true/false) +SERVER_TRACE_ENABLED=true + +# ============================== +# 🔒 HTTPS CONFIGURATION +# ============================== + +# Enable HTTPS (true/false) +SERVER_HTTPS_ENABLED=false + +# Enable HTTP/2 when using HTTPS (true/false) +SERVER_HTTPS_HTTP2_ENABLED=true + +# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_CERT_FILE_PATH=cert.pem + +# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true) +SERVER_HTTPS_KEY_FILE_PATH=key.pem + +# ============================== +# 🚦 RATE LIMIT CONFIGURATION +# ============================== + +# Maximum number of requests allowed per period +SERVER_RATE_LIMIT=5 + +# Time period (in seconds) for rate limiting +SERVER_RATE_LIMIT_PERIOD=1 + +# ============================== +# 📦 COMPRESSION CONFIGURATION +# ============================== + +# Enable Brotli compression (true/false) +SERVER_COMPRESSION_ENABLED=true + +# Compression level (valid range: 0-11, where 11 is the highest compression) +SERVER_COMPRESSION_LEVEL=6 + +# ============================== +# 🔑 AUTHENTICATION CONFIGURATION +# ============================== + +# Argon2 salt for password hashing (must be kept secret!) +AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi" +``` + +## 📂 Project Structure +``` +rustapi/ +├── migrations/ # SQL schema versions +├── src/ +│ ├── core/ # Config, TLS, server setup +│ ├── database/ # Query handling +│ ├── middlewares/ # Auth system +│ ├── models/ # Data structures +│ └── routes/ # API endpoints +└── Dockerfile # Containerization +``` + +## 🛠️ Technology Stack +| Category | Key Technologies | +|-----------------------|---------------------------------| +| Web Framework | Axum 0.8 + Tower | +| Database | PostgreSQL + SQLx 0.8 | +| Security | JWT + Argon2 + Rustls | +| Monitoring | Tracing + Sysinfo | diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..a46bd80 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,32 @@ + +services: + server: + build: + context: . + target: final + ports: + - 80:80 + depends_on: + - db_image + networks: + - common-net + + db_image: + image: postgres:latest + environment: + POSTGRES_PORT: 3306 + POSTGRES_DATABASE: database_name + POSTGRES_USER: user + POSTGRES_PASSWORD: database_password + POSTGRES_ROOT_PASSWORD: strong_database_password + expose: + - 3306 + ports: + - "3307:3306" + networks: + - common-net + +networks: + common-net: {} + + \ No newline at end of file diff --git a/generate_ssl_key.bat b/generate_ssl_key.bat new file mode 100644 index 0000000..a40bb7e --- /dev/null +++ b/generate_ssl_key.bat @@ -0,0 +1 @@ +openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" -addext "subjectAltName = DNS:localhost, IP:127.0.0.1" \ No newline at end of file diff --git a/migrations/20250128160043_create_roles_table.sql b/migrations/20250128160043_create_roles_table.sql new file mode 100644 index 0000000..c1546e5 --- /dev/null +++ b/migrations/20250128160043_create_roles_table.sql @@ -0,0 +1,19 @@ +-- Create the roles table +CREATE TABLE IF NOT EXISTS roles ( + id SERIAL PRIMARY KEY, + level INT NOT NULL, + role VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + description VARCHAR(255), + CONSTRAINT unique_role UNIQUE (role) -- Add a unique constraint to the 'role' column +); + +-- Insert a role into the roles table (this assumes role_id=1 is for 'user') +INSERT INTO roles (level, role, name, description) +VALUES (1, 'user', 'User', 'A regular user with basic access.') +ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists + +-- Insert a role into the roles table (this assumes role_id=2 is for 'admin') +INSERT INTO roles (level, role, name, description) +VALUES (2, 'admin', 'Administrator', 'An administrator.') +ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists diff --git a/migrations/20250128160119_create_users_table.sql b/migrations/20250128160119_create_users_table.sql new file mode 100644 index 0000000..ffc6f12 --- /dev/null +++ b/migrations/20250128160119_create_users_table.sql @@ -0,0 +1,21 @@ +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + email VARCHAR(255) NOT NULL UNIQUE, + password_hash VARCHAR(255) NOT NULL, + totp_secret VARCHAR(255), + role_id INT NOT NULL DEFAULT 1 REFERENCES roles(id), -- Default role_id is set to 1 + CONSTRAINT unique_username UNIQUE (username) -- Ensure that username is unique +); + +-- Insert the example 'user' into the users table with a conflict check for username +INSERT INTO users (username, email, password_hash, role_id) +VALUES + ('user', 'user@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 1) +ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists + +-- Insert the example 'admin' into the users table with a conflict check for username +INSERT INTO users (username, email, password_hash, role_id) +VALUES + ('admin', 'admin@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 2) +ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists \ No newline at end of file diff --git a/migrations/20250128160204_create_todos_table.sql b/migrations/20250128160204_create_todos_table.sql new file mode 100644 index 0000000..4d9d06e --- /dev/null +++ b/migrations/20250128160204_create_todos_table.sql @@ -0,0 +1,6 @@ +CREATE TABLE todos ( + id SERIAL PRIMARY KEY, -- Auto-incrementing primary key + task TEXT NOT NULL, -- Task description, cannot be null + description TEXT, -- Optional detailed description + user_id INT NOT NULL REFERENCES users(id) -- Foreign key to link to users table +); \ No newline at end of file diff --git a/src/core/config.rs b/src/core/config.rs new file mode 100644 index 0000000..c7b4a76 --- /dev/null +++ b/src/core/config.rs @@ -0,0 +1,55 @@ +// Import the standard library's environment module +use std::env; + +/// Retrieves the value of an environment variable as a `String`. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// +/// # Returns +/// * The value of the environment variable if it exists. +/// * Panics if the environment variable is missing. +pub fn get_env(key: &str) -> String { + env::var(key).unwrap_or_else(|_| panic!("Missing required environment variable: {}", key)) +} + +/// Retrieves the value of an environment variable as a `String`, with a default value if not found. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found. +/// +/// # Returns +/// * The value of the environment variable if it exists. +/// * The `default` value if the environment variable is missing. +pub fn get_env_with_default(key: &str, default: &str) -> String { + env::var(key).unwrap_or_else(|_| default.to_string()) +} + +/// Retrieves the value of an environment variable as a `bool`, with a default value if not found. +/// +/// The environment variable is considered `true` if its value is "true" (case-insensitive), otherwise `false`. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found. +/// +/// # Returns +/// * `true` if the environment variable is "true" (case-insensitive). +/// * `false` otherwise, or if the variable is missing, the `default` value is returned. +pub fn get_env_bool(key: &str, default: bool) -> bool { + env::var(key).map(|v| v.to_lowercase() == "true").unwrap_or(default) +} + +/// Retrieves the value of an environment variable as a `u16`, with a default value if not found. +/// +/// # Arguments +/// * `key` - The name of the environment variable to retrieve. +/// * `default` - The value to return if the environment variable is not found or cannot be parsed. +/// +/// # Returns +/// * The parsed `u16` value of the environment variable if it exists and is valid. +/// * The `default` value if the variable is missing or invalid. +pub fn get_env_u16(key: &str, default: u16) -> u16 { + env::var(key).unwrap_or_else(|_| default.to_string()).parse().unwrap_or(default) +} diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..b8dcadd --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,4 @@ +pub mod config; +pub mod server; +pub mod tls; + diff --git a/src/core/server.rs b/src/core/server.rs new file mode 100644 index 0000000..3c9da33 --- /dev/null +++ b/src/core/server.rs @@ -0,0 +1,38 @@ +// Axum for web server and routing +use axum::Router; + +// Middleware layers from tower_http +use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression +use tower_http::trace::TraceLayer; // For HTTP request/response tracing + +// Local crate imports for database connection and configuration +use crate::database::connect::connect_to_database; // Function to connect to the database +use crate::config; // Environment configuration helper + +/// Function to create and configure the Axum server. +pub async fn create_server() -> Router { + // Establish a connection to the database + let db = connect_to_database().await.expect("Failed to connect to database."); + + // Initialize the routes for the server + let mut app = crate::routes::create_routes(db); + + // Enable tracing middleware if configured + if config::get_env_bool("SERVER_TRACE_ENABLED", true) { + app = app.layer(TraceLayer::new_for_http()); + println!("✔️ Trace hads been enabled."); + } + + // Enable compression middleware if configured + if config::get_env_bool("SERVER_COMPRESSION_ENABLED", true) { + // Parse compression level from environment or default to level 6 + let level = config::get_env("SERVER_COMPRESSION_LEVEL").parse().unwrap_or(6); + // Apply compression layer with Brotli (br) enabled and the specified compression level + app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level))); + println!("✔️ Brotli compression enabled with compression quality level {}.", level); + + } + + // Return the fully configured application + app +} diff --git a/src/core/tls.rs b/src/core/tls.rs new file mode 100644 index 0000000..a52a6d7 --- /dev/null +++ b/src/core/tls.rs @@ -0,0 +1,145 @@ +// Standard library imports +use std::{ + future::Future, + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + io::BufReader, + fs::File, + iter, +}; + +// External crate imports +use axum::serve::Listener; +use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}}; +use rustls_pemfile::{Item, read_one, certs}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing; + +// Local crate imports +use crate::config; // Import env config helper + +// Function to load TLS configuration from files +pub fn load_tls_config() -> ServerConfig { + // Get certificate and key file paths from the environment + let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH"); + let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH"); + + // Open the certificate and key files + let cert_file = File::open(cert_path).expect("❌ Failed to open certificate file."); + let key_file = File::open(key_path).expect("❌ Failed to open private key file."); + + // Read the certificate chain and private key from the files + let mut cert_reader = BufReader::new(cert_file); + let mut key_reader = BufReader::new(key_file); + + // Read and parse the certificate chain + let cert_chain: Vec = certs(&mut cert_reader) + .map(|cert| cert.expect("❌ Failed to read certificate.")) + .map(CertificateDer::from) + .collect(); + + // Ensure certificates are found + if cert_chain.is_empty() { + panic!("❌ No valid certificates found."); + } + + // Read the private key from the file + let key = iter::from_fn(|| read_one(&mut key_reader).transpose()) + .find_map(|item| match item.unwrap() { + Item::Pkcs1Key(key) => Some(PrivateKeyDer::from(key)), + Item::Pkcs8Key(key) => Some(PrivateKeyDer::from(key)), + Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)), + _ => None, + }) + .expect("Failed to read a valid private key."); + + // Build and return the TLS server configuration + ServerConfig::builder() + .with_no_client_auth() // No client authentication + .with_single_cert(cert_chain, key) // Use the provided cert and key + .expect("Failed to create TLS configuration") +} + +// Custom listener that implements axum::serve::Listener +#[derive(Clone)] +pub struct TlsListener { + pub inner: Arc, // Inner TCP listener + pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor for handling TLS handshakes +} + +impl Listener for TlsListener { + type Io = TlsStreamWrapper; // Type of I/O stream + type Addr = SocketAddr; // Type of address (Socket address) + + // Method to accept incoming connections and establish a TLS handshake + fn accept(&mut self) -> impl Future + Send { + let acceptor = self.acceptor.clone(); // Clone the acceptor for async use + + async move { + loop { + // Accept a TCP connection + let (stream, addr) = match self.inner.accept().await { + Ok((stream, addr)) => (stream, addr), + Err(e) => { + tracing::error!("❌ Error accepting TCP connection: {}", e); + continue; // Retry on error + } + }; + + // Perform TLS handshake + match acceptor.accept(stream).await { + Ok(tls_stream) => { + tracing::info!("Successful TLS handshake with {}", addr); + return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address + }, + Err(e) => { + tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e); + continue; // Retry on error + } + } + } + } + } + + // Method to retrieve the local address of the listener + fn local_addr(&self) -> std::io::Result { + self.inner.local_addr() + } +} + +// Wrapper for a TLS stream, implementing AsyncRead and AsyncWrite +#[derive(Debug)] +pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream); + +impl AsyncRead for TlsStreamWrapper { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream + } +} + +impl AsyncWrite for TlsStreamWrapper { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream + } +} + +// Allow the TLS stream wrapper to be used in non-blocking contexts (needed for async operations) +impl Unpin for TlsStreamWrapper {} diff --git a/src/database/connect.rs b/src/database/connect.rs new file mode 100644 index 0000000..484d333 --- /dev/null +++ b/src/database/connect.rs @@ -0,0 +1,51 @@ +use dotenvy::dotenv; +use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions}; +use std::fs; +use std::env; +use std::path::Path; + +/// Connects to the database using the DATABASE_URL environment variable. +pub async fn connect_to_database() -> Result { + dotenv().ok(); + let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount."); + + // Read max and min connection values from environment variables, with defaults + let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS") + .unwrap_or_else(|_| "10".to_string()) // Default to 10 + .parse() + .expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number."); + + let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS") + .unwrap_or_else(|_| "2".to_string()) // Default to 2 + .parse() + .expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number."); + + // Create and configure the connection pool + let pool = PgPoolOptions::new() + .max_connections(max_connections) + .min_connections(min_connections) + .connect(&database_url) + .await?; + + Ok(pool) +} + +/// Run database migrations +pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> { + // Define the path to the migrations folder + let migrations_path = Path::new("./migrations"); + + // Check if the migrations folder exists, and if not, create it + if !migrations_path.exists() { + fs::create_dir_all(migrations_path).expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions."); + println!("✔️ Created migrations directory: {:?}", migrations_path); + } + + // Create a migrator instance that looks for migrations in the `./migrations` folder + let migrator = Migrator::new(migrations_path).await?; + + // Run all pending migrations + migrator.run(pool).await?; + + Ok(()) +} \ No newline at end of file diff --git a/src/database/get_users.rs b/src/database/get_users.rs new file mode 100644 index 0000000..7493cdd --- /dev/null +++ b/src/database/get_users.rs @@ -0,0 +1,25 @@ +use sqlx::postgres::PgPool; +use crate::models::user::*; // Import the User struct + +// Get all users +pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result { + // Use a string literal directly in the macro + let user = sqlx::query_as!( + User, // Struct type to map the query result + r#" + SELECT id, username, email, password_hash, totp_secret, role_id + FROM users + WHERE email = $1 + "#, + email // Bind the `email` parameter + ) + .fetch_optional(pool) + .await + .map_err(|e| format!("Database error: {}", e))?; // Handle database errors + + // Handle optional result + match user { + Some(user) => Ok(user), + None => Err(format!("User with email '{}' not found.", email)), + } +} \ No newline at end of file diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..4066dcc --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,3 @@ +// Module declarations +pub mod connect; +pub mod get_users; diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs new file mode 100644 index 0000000..67bad8d --- /dev/null +++ b/src/handlers/mod.rs @@ -0,0 +1,6 @@ +// Module declarations +// pub mod auth; + + +// Re-exporting modules +// pub use auth::*; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..36dde6a --- /dev/null +++ b/src/main.rs @@ -0,0 +1,91 @@ +#[allow(dead_code)] + +// Core modules for the configuration, TLS setup, and server creation +mod core; +use core::{config, tls, server}; +use core::tls::TlsListener; + +// Other modules for database, routes, models, and middlewares +mod database; +mod routes; +mod models; +mod middlewares; + +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use axum::serve; + +#[tokio::main] +async fn main() { + dotenvy::dotenv().ok(); // Load environment variables from a .env file + tracing_subscriber::fmt::init(); // Initialize the logging system + + // Print a cool startup message with ASCII art and emojis + println!("{}", r#" + ##### ## ## + ## ## ## + ## ## ## ## ##### ####### ##### #### #### + ##### ## ## ## ## ## ## ## ## ## + #### ## ## ## ## ## ## ## ## ## +## ## ## ### ## ## ## ### ## ## ## +## ## ### ## ##### ### ### ## ##### ###### + ## + + Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL + GitHub: https://github.com/Riktastic/rustapi + "#); + + println!("🚀 Starting Rustapi..."); + + // Retrieve server IP and port from the environment, default to 0.0.0.0:3000 + let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0") + .parse() + .expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1."); + let port: u16 = config::get_env_u16("SERVER_PORT", 3000); + let socket_addr = SocketAddr::new(ip, port); + + // Create the Axum app instance using the server configuration + let app = server::create_server().await; + + // Check if HTTPS is enabled in the environment configuration + if config::get_env_bool("SERVER_HTTPS_ENABLED", false) { + // If HTTPS is enabled, start the server with secure HTTPS. + + // Bind TCP listener for incoming connections + let tcp_listener = TcpListener::bind(socket_addr) + .await + .expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling + + // Load the TLS configuration for secure HTTPS connections + let tls_config = tls::load_tls_config(); + let acceptor = TlsAcceptor::from(Arc::new(tls_config)); // Create a TLS acceptor + let listener = TlsListener { + inner: Arc::new(tcp_listener), // Wrap TCP listener in TlsListener + acceptor: acceptor, + }; + + println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port); + + // Serve the app using the TLS listener (HTTPS) + serve(listener, app.into_make_service()) + .await + .expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?"); + + } else { + // If HTTPS is not enabled, start the server with non-secure HTTP. + + // Bind TCP listener for non-secure HTTP connections + let listener = TcpListener::bind(socket_addr) + .await + .expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling + + println!("🔓 Server started with HTTP at: http://{}:{}", ip, port); + + // Serve the app using the non-secure TCP listener (HTTP) + serve(listener, app.into_make_service()) + .await + .expect("❌ Server failed to start without HTTPS."); + } +} diff --git a/src/middlewares/auth.rs b/src/middlewares/auth.rs new file mode 100644 index 0000000..c28e868 --- /dev/null +++ b/src/middlewares/auth.rs @@ -0,0 +1,211 @@ +// Standard library imports for working with HTTP, environment variables, and other necessary utilities +use axum::{ + body::Body, + response::IntoResponse, + extract::{Request, Json}, // Extractor for request and JSON body + http::{self, Response, StatusCode}, // HTTP response and status codes + middleware::Next, // For adding middleware layers to the request handling pipeline +}; + +// Importing `State` for sharing application state (such as a database connection) across request handlers +use axum::extract::State; + +// Importing necessary libraries for password hashing, JWT handling, and date/time management +use std::env; // For accessing environment variables +use argon2::{ + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification + Argon2, +}; + +use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.) +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens +use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data +use serde_json::json; // For constructing JSON data +use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously + +// Importing custom database query functions +use crate::database::get_users::get_user_by_email; + +// Define the structure for JWT claims to be included in the token payload +#[derive(Serialize, Deserialize)] +pub struct Claims { + pub exp: usize, // Expiration timestamp (in seconds) + pub iat: usize, // Issued-at timestamp (in seconds) + pub email: String, // User's email +} + +// Custom error type for handling authentication errors +pub struct AuthError { + message: String, + status_code: StatusCode, // HTTP status code to be returned with the error +} + +// Function to verify a password against a stored hash using the Argon2 algorithm +pub fn verify_password(password: &str, hash: &str) -> Result { + let parsed_hash = PasswordHash::new(hash)?; // Parse the hash + // Verify the password using Argon2 + Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok()) +} + +// Function to hash a password using Argon2 and a salt retrieved from the environment variables +pub fn hash_password(password: &str) -> Result { + // Get the salt from environment variables (must be set) + let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set"); + let salt = SaltString::from_b64(&salt).unwrap(); // Convert base64 string to SaltString + let argon2 = Argon2::default(); // Create an Argon2 instance + // Hash the password with the salt + let password_hash = argon2.hash_password(password.as_bytes(), &salt)?.to_string(); + Ok(password_hash) +} + +// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler +impl IntoResponse for AuthError { + fn into_response(self) -> Response { + let body = Json(json!( { "error": self.message } )); // Create a JSON response body with the error message + + // Return a response with the appropriate status code and error message + (self.status_code, body).into_response() + } +} + +// Function to encode a JWT token for the given email address +pub fn encode_jwt(email: String) -> Result { + let jwt_token: String = "randomstring".to_string(); // Secret key for JWT (should be more secure in production) + + let now = Utc::now(); // Get current time + let expire = Duration::hours(24); // Set token expiration to 24 hours + let exp: usize = (now + expire).timestamp() as usize; // Expiration timestamp + let iat: usize = now.timestamp() as usize; // Issued-at timestamp + + let claim = Claims { iat, exp, email }; // Create JWT claims with timestamps and user email + let secret = jwt_token.clone(); // Secret key to sign the token + + // Encode the claims into a JWT token + encode( + &Header::default(), + &claim, + &EncodingKey::from_secret(secret.as_ref()), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails +} + +// Function to decode a JWT token and extract the claims +pub fn decode_jwt(jwt: String) -> Result, StatusCode> { + let secret = "randomstring".to_string(); // Secret key to verify the JWT (should be more secure in production) + + // Decode the JWT token using the secret key and extract the claims + decode( + &jwt, + &DecodingKey::from_secret(secret.as_ref()), + &Validation::default(), + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if decoding fails +} + +// Middleware for role-based access control (RBAC) +// Ensures that only users with specific roles are authorized to access certain resources +pub async fn authorize( + mut req: Request, + next: Next, + allowed_roles: Vec, // Accept a vector of allowed roles +) -> Result, AuthError> { + // Retrieve the database pool from request extensions (shared application state) + let pool = req.extensions().get::().expect("Database pool not found in request extensions"); + + // Retrieve the Authorization header from the request + let auth_header = req.headers().get(http::header::AUTHORIZATION); + + // Ensure the header exists and is correctly formatted + let auth_header = match auth_header { + Some(header) => header.to_str().map_err(|_| AuthError { + message: "Invalid header format".to_string(), + status_code: StatusCode::FORBIDDEN, + })?, + None => return Err(AuthError { + message: "Authorization header missing.".to_string(), + status_code: StatusCode::FORBIDDEN, + }), + }; + + // Extract the token from the Authorization header (Bearer token format) + let mut header = auth_header.split_whitespace(); + let (_, token) = (header.next(), header.next()); + + // Decode the JWT token + let token_data = match decode_jwt(token.unwrap().to_string()) { + Ok(data) => data, + Err(_) => return Err(AuthError { + message: "Unable to decode token.".to_string(), + status_code: StatusCode::UNAUTHORIZED, + }), + }; + + // Fetch the user from the database using the email from the decoded token + let current_user = match get_user_by_email(&pool, token_data.claims.email).await { + Ok(user) => user, + Err(_) => return Err(AuthError { + message: "Unauthorized user.".to_string(), + status_code: StatusCode::UNAUTHORIZED, + }), + }; + + // Check if the user's role is in the list of allowed roles + if !allowed_roles.contains(¤t_user.role_id) { + return Err(AuthError { + message: "Forbidden: insufficient role.".to_string(), + status_code: StatusCode::FORBIDDEN, + }); + } + + // Insert the current user into the request extensions for use in subsequent handlers + req.extensions_mut().insert(current_user); + + // Proceed to the next middleware or handler + Ok(next.run(req).await) +} + +// Structure to hold the data from the sign-in request +#[derive(Deserialize)] +pub struct SignInData { + pub email: String, + pub password: String, +} + +// Handler for user sign-in (authentication) +pub async fn sign_in( + State(pool): State, // Database connection pool injected as state + Json(user_data): Json, // Deserialize the JSON body into SignInData +) -> Result, (StatusCode, Json)> { + + // 1. Retrieve user from the database using the provided email + let user = match get_user_by_email(&pool, user_data.email).await { + Ok(user) => user, + Err(_) => return Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ "error": "Incorrect credentials." })) + )), + }; + + // 2. Verify the password using the stored hash + if !verify_password(&user_data.password, &user.password_hash) + .map_err(|_| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Internal server error." })) + ))? + { + return Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ "error": "Incorrect credentials." })) + )); + } + + // 3. Generate a JWT token for the authenticated user + let token = encode_jwt(user.email) + .map_err(|_| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Internal server error." })) + ))?; + + // 4. Return the JWT token to the client + Ok(Json(json!({ "token": token }))) +} diff --git a/src/middlewares/mod.rs b/src/middlewares/mod.rs new file mode 100644 index 0000000..ff8efe8 --- /dev/null +++ b/src/middlewares/mod.rs @@ -0,0 +1,2 @@ +// Module declarations +pub mod auth; \ No newline at end of file diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..04b23c0 --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,5 @@ +/// Module for to-do related models. +pub mod todo; +/// Module for user related models. +pub mod user; + diff --git a/src/models/role.rs b/src/models/role.rs new file mode 100644 index 0000000..c206f8e --- /dev/null +++ b/src/models/role.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a user role in the system. +#[derive(Deserialize, Debug, Serialize, FromRow, Clone)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct Role { + /// ID of the role. + pub id: i32, + /// Level of the role. + pub level: i32, + /// System name of the role. + pub role: String, + /// The name of the role. + pub name: String, + /// Description of the role + pub Description: String, +} \ No newline at end of file diff --git a/src/models/todo.rs b/src/models/todo.rs new file mode 100644 index 0000000..60bbabc --- /dev/null +++ b/src/models/todo.rs @@ -0,0 +1,16 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a to-do item. +#[derive(Deserialize, Debug, Serialize, FromRow)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct Todo { + /// The unique identifier for the to-do item. + pub id: i32, + /// The task description. + pub task: String, + /// An optional detailed description of the task. + pub description: Option, + /// The unique identifier of the user who created the to-do item. + pub user_id: i32, +} \ No newline at end of file diff --git a/src/models/user.rs b/src/models/user.rs new file mode 100644 index 0000000..2721499 --- /dev/null +++ b/src/models/user.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +/// Represents a user in the system. +#[derive(Deserialize, Debug, Serialize, FromRow, Clone)] +#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL +pub struct User { + /// The unique identifier for the user. + pub id: i32, + /// The username of the user. + pub username: String, + /// The email of the user. + pub email: String, + /// The hashed password for the user. + pub password_hash: String, + /// The TOTP secret for the user. + pub totp_secret: Option, + /// Current role of the user.. + pub role_id: i32, +} \ No newline at end of file diff --git a/src/routes/get_health.rs b/src/routes/get_health.rs new file mode 100644 index 0000000..a519641 --- /dev/null +++ b/src/routes/get_health.rs @@ -0,0 +1,200 @@ +use axum::{response::IntoResponse, Json, extract::State}; +use serde_json::json; +use sqlx::PgPool; +use sysinfo::{System, RefreshKind, Disks}; +use tokio::{task, join}; +use std::sync::{Arc, Mutex}; + +pub async fn get_health(State(database_connection): State) -> impl IntoResponse { + // Use Arc and Mutex to allow sharing System between tasks + let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything()))); + + // Run checks in parallel + let (cpu_result, mem_result, disk_result, process_result, db_result, net_result) = join!( + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_cpu_usage(&mut system) // Pass the mutable reference + } + }), + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_memory(&mut system) // Pass the mutable reference + } + }), + task::spawn_blocking({ + move || { + check_disk_usage() // Does not need a system reference. + } + }), + task::spawn_blocking({ + let system = Arc::clone(&system); + move || { + let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference + check_processes(&mut system, &["postgres", "Code"]) // Pass the mutable reference + } + }), + check_database_connection(&database_connection), // Async function + task::spawn_blocking(check_network_connection) // Blocking, okay in spawn_blocking + ); + + let mut status = "healthy"; + let mut details = json!({}); + + // Process CPU result + if let Ok(Ok(cpu_details)) = cpu_result { + details["cpu_usage"] = json!(cpu_details); + if cpu_details["status"] == "low" { + status = "degraded"; + } + } else { + details["cpu_usage"] = json!({ "status": "error", "error": "Failed to retrieve CPU usage" }); + status = "degraded"; + } + + // Process Memory result + if let Ok(Ok(mem_details)) = mem_result { + details["memory"] = json!(mem_details); + if mem_details["status"] == "low" { + status = "degraded"; + } + } else { + details["memory"] = json!({ "status": "error", "error": "Failed to retrieve memory information" }); + status = "degraded"; + } + + // Process Disk result + if let Ok(Ok(disk_details)) = disk_result { + details["disk_usage"] = json!(disk_details); + if disk_details["status"] == "critical" { + status = "degraded"; + } + } else { + details["disk_usage"] = json!({ "status": "error", "error": "Failed to retrieve disk usage" }); + status = "degraded"; + } + + // Process Process result + if let Ok(Ok(process_details)) = process_result { + details["important_processes"] = json!(process_details); + if process_details.iter().any(|p| p["status"] == "not running") { + status = "degraded"; + } + } else { + details["important_processes"] = json!({ "status": "error", "error": "Failed to retrieve process information" }); + status = "degraded"; + } + + // Process Database result + if let Ok(db_status) = db_result { + details["database"] = json!({ "status": if db_status { "ok" } else { "degraded" } }); + if !db_status { + status = "degraded"; + } + } else { + details["database"] = json!({ "status": "error", "error": "Failed to retrieve database status" }); + status = "degraded"; + } + + // Process Network result + if let Ok(Ok(net_status)) = net_result { + details["network"] = json!({ "status": if net_status { "ok" } else { "degraded" } }); + if !net_status { + status = "degraded"; + } + } else { + details["network"] = json!({ "status": "error", "error": "Failed to retrieve network status" }); + status = "degraded"; + } + + Json(json!({ + "status": status, + "details": details, + })) +} + +// Helper functions + +fn check_cpu_usage(system: &mut System) -> Result { + system.refresh_cpu_usage(); + let usage = system.global_cpu_usage(); + let available = 100.0 - usage; + Ok(json!( { + "usage_percentage": format!("{:.2}", usage), + "available_percentage": format!("{:.2}", available), + "status": if available < 10.0 { "low" } else { "normal" }, + })) +} + +fn check_memory(system: &mut System) -> Result { + system.refresh_memory(); + let available = system.available_memory() / 1024 / 1024; // Convert to MB + Ok(json!( { + "available_mb": available, + "status": if available < 512 { "low" } else { "normal" }, + })) +} + +fn check_disk_usage() -> Result { + // Create a new Disks object and refresh the disk information + let mut disks = Disks::new(); + disks.refresh(false); // Refresh disk information without performing a full refresh + + // Iterate through the list of disks and check the usage for each one + let usage: Vec<_> = disks.list().iter().map(|disk| { + let total = disk.total_space() as f64; + let available = disk.available_space() as f64; + let used_percentage = ((total - available) / total) * 100.0; + used_percentage + }).collect(); + + // Get the maximum usage percentage + let max_usage = usage.into_iter() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(0.0); + + // Return the result as a JSON object + Ok(json!( { + "used_percentage": format!("{:.2}", max_usage), + "status": if max_usage > 90.0 { "critical" } else { "ok" }, + })) +} + +fn check_processes(system: &mut System, processes: &[&str]) -> Result, ()> { + system.refresh_processes(sysinfo::ProcessesToUpdate::All, true); + + let process_statuses: Vec<_> = processes.iter().map(|&name| { + // Adjust process names based on the platform and check if they are running + let adjusted_name = if cfg!(target_os = "windows") { + match name { + "postgres" => "postgres.exe", // Postgres on Windows + "Code" => "Code.exe", // Visual Studio Code on Windows + _ => name, // For other platforms, use the name as is + } + } else { + name // For non-Windows platforms, use the name as is + }; + + // Check if the translated (adjusted) process is running + let is_running = system.processes().iter().any(|(_, proc)| proc.name() == adjusted_name); + + // Return a JSON object for each process with its status + json!({ + "name": name, + "status": if is_running { "running" } else { "not running" } + }) + }).collect(); + + Ok(process_statuses) +} + +async fn check_database_connection(pool: &PgPool) -> Result { + sqlx::query("SELECT 1").fetch_one(pool).await.map(|_| true).or_else(|_| Ok(false)) +} + +fn check_network_connection() -> Result { + Ok(std::net::TcpStream::connect("8.8.8.8:53").is_ok()) +} diff --git a/src/routes/get_todos.rs b/src/routes/get_todos.rs new file mode 100644 index 0000000..abecfb7 --- /dev/null +++ b/src/routes/get_todos.rs @@ -0,0 +1,43 @@ +use axum::extract::{State, Path}; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::todo::*; + +// Get all todos +pub async fn get_all_todos(State(pool): State,) -> impl IntoResponse { + let todos = sqlx::query_as!(Todo, "SELECT * FROM todos") // Your table name + .fetch_all(&pool) // Borrow the connection pool + .await; + + match todos { + Ok(todos) => Ok(Json(todos)), // Return all todos as JSON + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching todos: {}", err), + )), + } +} + + +// Get a single todo by id +pub async fn get_todos_by_id( + State(pool): State, + Path(id): Path, // Use Path extractor here +) -> impl IntoResponse { + let todo = sqlx::query_as!(Todo, "SELECT * FROM todos WHERE id = $1", id) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match todo { + Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("Todo with id {} not found", id), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching todo: {}", err), + )), + } +} \ No newline at end of file diff --git a/src/routes/get_users.rs b/src/routes/get_users.rs new file mode 100644 index 0000000..85f071b --- /dev/null +++ b/src/routes/get_users.rs @@ -0,0 +1,87 @@ +use axum::extract::{State, Path}; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::user::*; // Import the User struct + +// Get all users +pub async fn get_all_users(State(pool): State,) -> impl IntoResponse { + let users = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users") // Your table name + .fetch_all(&pool) // Borrow the connection pool + .await; + + match users { + Ok(users) => Ok(Json(users)), // Return all users as JSON + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching users: {}", err), + )), + } +} + +// Get a single user by id +pub async fn get_users_by_id( + State(pool): State, + Path(id): Path, // Use Path extractor here + ) -> impl IntoResponse { + + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE id = $1", id) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with id {} not found", id), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} + +// Get a single user by username +pub async fn get_user_by_username( + State(pool): State, + Path(username): Path, // Use Path extractor here for username +) -> impl IntoResponse { + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE username = $1", username) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with username {} not found", username), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} + +// Get a single user by email +pub async fn get_user_by_email( + State(pool): State, + Path(email): Path, // Use Path extractor here for email +) -> impl IntoResponse { + let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE email = $1", email) + .fetch_optional(&pool) // Borrow the connection pool + .await; + + match user { + Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found + Ok(None) => Err(( + axum::http::StatusCode::NOT_FOUND, + format!("User with email {} not found", email), + )), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error fetching user: {}", err), + )), + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..92c4c9a --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,59 @@ +// Module declarations for different route handlers +pub mod get_todos; +pub mod get_users; +pub mod post_todos; +pub mod post_users; +pub mod get_health; +pub mod protected; + +// Re-exporting modules to make their contents available at this level +pub use get_todos::*; +pub use get_users::*; +pub use post_todos::*; +pub use post_users::*; +pub use get_health::*; +pub use protected::*; + +use axum::{ + Router, + routing::{get, post}, +}; + +use sqlx::PgPool; + +use crate::middlewares::auth::{sign_in, authorize}; + +/// Function to create and configure all routes +pub fn create_routes(database_connection: PgPool) -> Router { + // Authentication routes + let auth_routes = Router::new() + .route("/signin", post(sign_in)) + .route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| { + let allowed_roles = vec![1, 2]; + authorize(req, next, allowed_roles) + }))); + + // User-related routes + let user_routes = Router::new() + .route("/all", get(get_all_users)) + .route("/{id}", get(get_users_by_id)) + .route("/", post(post_user)); + + // Todo-related routes + let todo_routes = Router::new() + .route("/all", get(get_all_todos)) + .route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| { + let allowed_roles = vec![1, 2]; + authorize(req, next, allowed_roles) + }))) + .route("/{id}", get(get_todos_by_id)); + + // Combine all routes and add middleware + Router::new() + .merge(auth_routes) // Add authentication routes + .nest("/users", user_routes) // Add user routes under /users + .nest("/todos", todo_routes) // Add todo routes under /todos + .route("/health", get(get_health)) // Add health check route + .layer(axum::Extension(database_connection.clone())) // Add database connection to all routes + .with_state(database_connection) // Add database connection as state +} diff --git a/src/routes/post_todos.rs b/src/routes/post_todos.rs new file mode 100644 index 0000000..d2e6c37 --- /dev/null +++ b/src/routes/post_todos.rs @@ -0,0 +1,53 @@ +use axum::{extract::{State, Extension}, Json}; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::todo::*; +use crate::models::user::*; +use serde::Deserialize; +use axum::http::StatusCode; +use serde_json::json; + +#[derive(Deserialize)] +pub struct TodoBody { + pub task: String, + pub description: Option, + pub user_id: i32, +} + +// Add a new todo +pub async fn post_todo( + State(pool): State, + Extension(user): Extension, // Extract current user from the request extensions + Json(todo): Json +) -> impl IntoResponse { + // Ensure the user_id from the request matches the current user's id + if todo.user_id != user.id { + return Err(( + StatusCode::FORBIDDEN, + Json(json!({ "error": "User is not authorized to create a todo for another user" })) + )); + } + + // Insert the todo into the database + let row = sqlx::query!( + "INSERT INTO todos (task, description, user_id) VALUES ($1, $2, $3) RETURNING id, task, description, user_id", + todo.task, + todo.description, + todo.user_id + ) + .fetch_one(&pool) + .await; + + match row { + Ok(row) => Ok(Json(Todo { + id: row.id, + task: row.task, + description: row.description, + user_id: row.user_id, + })), + Err(err) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Error: {}", err) })) + )), + } +} \ No newline at end of file diff --git a/src/routes/post_users.rs b/src/routes/post_users.rs new file mode 100644 index 0000000..51c4978 --- /dev/null +++ b/src/routes/post_users.rs @@ -0,0 +1,44 @@ +use axum::extract::State; +use axum::Json; +use axum::response::IntoResponse; +use sqlx::postgres::PgPool; +use crate::models::user::*; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct UserBody { + pub username: String, + pub email: String, + pub password_hash: String, + pub totp_secret: String, + pub role_id: i32, +} + +// Add a new user +pub async fn post_user(State(pool): State, Json(user): Json, ) -> impl IntoResponse { + let row = sqlx::query!( + "INSERT INTO users (username, email, password_hash, totp_secret, role_id) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, totp_secret, role_id", + user.username, + user.email, + user.password_hash, + user.totp_secret, + user.role_id + ) + .fetch_one(&pool) // Use `&pool` to borrow the connection pool + .await; + + match row { + Ok(row) => Ok(Json(User { + id: row.id, + username: row.username, + email: row.email, + password_hash: row.password_hash, + totp_secret: row.totp_secret, + role_id: row.role_id, + })), + Err(err) => Err(( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("Error: {}", err), + )), + } +} \ No newline at end of file diff --git a/src/routes/protected.rs b/src/routes/protected.rs new file mode 100644 index 0000000..277521e --- /dev/null +++ b/src/routes/protected.rs @@ -0,0 +1,18 @@ +use axum::{Extension, Json, response::IntoResponse}; +use serde::{Serialize, Deserialize}; +use crate::models::user::User; + +#[derive(Serialize, Deserialize)] +struct UserResponse { + id: i32, + username: String, + email: String +} + +pub async fn protected(Extension(user): Extension) -> impl IntoResponse { + Json(UserResponse { + id: user.id, + username: user.username, + email: user.email + }) +} \ No newline at end of file