first commit

This commit is contained in:
Rik Heijmann 2025-01-30 22:43:30 +01:00
commit e20f21bc8b
33 changed files with 1667 additions and 0 deletions

35
.dockerignore Normal file
View File

@ -0,0 +1,35 @@
# Include any files or directories that you don't want to be copied to your
# container here (e.g., local build artifacts, temporary files, etc.).
#
# For more help, visit the .dockerignore file reference guide at
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
**/.DS_Store
**/.classpath
**/.dockerignore
**/.env
**/.git
**/.gitignore
**/.project
**/.settings
**/.toolstarget
**/.vs
**/.vscode
**/*.*proj.user
**/*.dbmdl
**/*.jfm
**/charts
**/docker-compose*
**/compose*
**/Dockerfile*
**/node_modules
**/npm-debug.log
**/secrets.dev.yaml
**/values.dev.yaml
/bin
/target
LICENSE
README.md

68
.env.example Normal file
View File

@ -0,0 +1,68 @@
# ==============================
# 📌 DATABASE CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
# ==============================
# 🌍 SERVER CONFIGURATION
# ==============================
# IP address the server will bind to (0.0.0.0 allows all network interfaces)
SERVER_IP="0.0.0.0"
# Port the server will listen on
SERVER_PORT="3000"
# Enable tracing for debugging/logging (true/false)
SERVER_TRACE_ENABLED=true
# ==============================
# 🔒 HTTPS CONFIGURATION
# ==============================
# Enable HTTPS (true/false)
SERVER_HTTPS_ENABLED=false
# Enable HTTP/2 when using HTTPS (true/false)
SERVER_HTTPS_HTTP2_ENABLED=true
# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_CERT_FILE_PATH=cert.pem
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_KEY_FILE_PATH=key.pem
# ==============================
# 🚦 RATE LIMIT CONFIGURATION
# ==============================
# Maximum number of requests allowed per period
SERVER_RATE_LIMIT=5
# Time period (in seconds) for rate limiting
SERVER_RATE_LIMIT_PERIOD=1
# ==============================
# 📦 COMPRESSION CONFIGURATION
# ==============================
# Enable Brotli compression (true/false)
SERVER_COMPRESSION_ENABLED=true
# Compression level (valid range: 0-11, where 11 is the highest compression)
SERVER_COMPRESSION_LEVEL=6
# ==============================
# 🔑 AUTHENTICATION CONFIGURATION
# ==============================
# Argon2 salt for password hashing (must be kept secret!)
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
.env
/target
cert.pem
key.pem
Cargo.lock

45
Cargo.toml Normal file
View File

@ -0,0 +1,45 @@
[package]
name = "rustapi"
version = "0.1.0"
edition = "2021"
[dependencies]
# Web framework and server
axum = { version = "0.8.1", features = ["json"] }
# hyper = { version = "1.5.2", features = ["full"] }
# Database interaction
sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate"] }
# Serialization and deserialization
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.137"
# Authentication and security
jsonwebtoken = "9.3.0"
argon2 = "0.5.3"
# Asynchronous runtime and traits
tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] }
# Configuration and environment
dotenvy = "0.15.7"
# Middleware and server utilities
tower = { version = "0.5.2", features = ["limit"] }
tower-http = { version = "0.6.2", features = ["trace", "cors", "compression-br"] }
# Logging and monitoring
tracing = "0.1.41"
tracing-subscriber = "0.3.19"
# System information
sysinfo = "0.33.1"
# Date and time handling
chrono = "0.4.39"
# SSL / TLS
rustls = "0.23.21"
tokio-rustls = "0.26.1"
rustls-pemfile = "2.2.0"

73
Dockerfile Normal file
View File

@ -0,0 +1,73 @@
# syntax=docker/dockerfile:1

# Comments are provided throughout this file to help you get started.
# If you need more help, visit the Dockerfile reference guide at
# https://docs.docker.com/engine/reference/builder/

################################################################################
# Create a stage for building the application.

ARG RUST_VERSION=1.78.0
# APP_NAME must match the [package] name in Cargo.toml ("rustapi"): cargo
# produces target/release/<package name>. The previous value "backend" made
# the `cp` below fail because no such binary exists.
ARG APP_NAME=rustapi
FROM rust:${RUST_VERSION}-slim-bullseye AS build
ARG APP_NAME
WORKDIR /app

# Build the application.
# Leverage a cache mount to /usr/local/cargo/registry/
# for downloaded dependencies and a cache mount to /app/target/ for
# compiled dependencies which will speed up subsequent builds.
# Leverage a bind mount to the src directory to avoid having to copy the
# source code into the container. Once built, copy the executable to an
# output directory before the cache mounted /app/target is unmounted.
RUN --mount=type=bind,source=src,target=src \
    # --mount=type=bind,source=configuration.yaml,target=configuration.yaml \
    --mount=type=bind,source=Cargo.toml,target=Cargo.toml \
    --mount=type=bind,source=Cargo.lock,target=Cargo.lock \
    --mount=type=cache,target=/app/target/ \
    --mount=type=cache,target=/usr/local/cargo/registry/ \
    <<EOF
set -e
cargo build --locked --release
cp ./target/release/$APP_NAME /bin/server
EOF

# COPY /src/web /bin/web

################################################################################
# Create a new stage for running the application that contains the minimal
# runtime dependencies for the application. This often uses a different base
# image from the build stage where the necessary files are copied from the build
# stage.
#
# The example below uses the debian bullseye image as the foundation for running the app.
# By specifying the "bullseye-slim" tag, it will also use whatever happens to be the
# most recent version of that tag when you build your Dockerfile. If
# reproducibility is important, consider using a digest
# (e.g., debian@sha256:ac707220fbd7b67fc19b112cee8170b41a9e97f703f588b2cdbbcdcecdd8af57).
FROM debian:bullseye-slim AS final

# Create a non-privileged user that the app will run under.
# See https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#user
ARG UID=10001
RUN adduser \
    --disabled-password \
    --gecos "" \
    --home "/nonexistent" \
    --shell "/sbin/nologin" \
    --no-create-home \
    --uid "${UID}" \
    appuser
USER appuser

# Copy the executable from the "build" stage.
COPY --from=build /bin/server /bin/

# Expose the port that the application listens on.
# NOTE(review): the app defaults to SERVER_PORT=3000, and the non-root
# `appuser` cannot bind ports below 1024 — either set SERVER_PORT explicitly
# for this container or expose/publish port 3000 instead. Confirm against the
# deployment setup before changing.
EXPOSE 80

# What the container should run when it is started.
CMD ["/bin/server"]

168
README.md Normal file
View File

@ -0,0 +1,168 @@
# 🦀 RustAPI
**An example API built with Rust, Axum, SQLx, and PostgreSQL**
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
## 🚀 Core Features
- **Rust API Template** - Production-ready starter template with modern practices
- **PostgreSQL Integration** - Full database support with SQLx migrations
- **Comprehensive Health Monitoring**
Docker-compatible endpoint with system metrics:
```json
{
"details": {
"cpu_usage": {"available_percentage": "9.85", "status": "low"},
"database": {"status": "ok"},
"disk_usage": {"status": "ok", "used_percentage": "74.00"},
"memory": {"available_mb": 21613, "status": "normal"}
},
"status": "degraded"
}
```
- **JWT Authentication** - Secure token-based auth with Argon2 password hashing
- **Granular Access Control** - Role-based endpoint protection:
```rust
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
```
- **User Context Injection** - Automatic user profile handling in endpoints:
```rust
pub async fn post_todo(
Extension(user): Extension<User>, // Injected user
Json(todo): Json<TodoBody>
) -> impl IntoResponse {
if todo.user_id != user.id {
return Err((StatusCode::FORBIDDEN, Json(json!({
"error": "Cannot create todos for others"
}))));
}
```
- **Modern protocols** - HTTP/2 with secure TLS defaults
- **Observability** - Integrated tracing
- **Optimized for performance** - Brotli compression
- **Easy configuration** - `.env` and environment variables
- **Documented codebase** - Extensive inline comments for easy modification and readability
- **Latest dependencies** - Regularly updated Rust ecosystem crates
## 📦 Installation & Usage
```bash
# Clone and setup
git clone https://github.com/Riktastic/rustapi.git
cd rustapi && cp .env.example .env
# Database setup
sqlx database create && sqlx migrate run
# Start server
cargo run --release
```
### 🔐 Default Accounts
**Warning:** These accounts should only be used for initial testing. Always change or disable them in production environments.
| Email | Password | Role |
|---------------------|----------|----------------|
| `user@test.com` | `test` | User |
| `admin@test.com` | `test` | Administrator |
⚠️ **Security Recommendations:**
1. Rotate passwords immediately after initial setup
2. Disable default accounts before deploying to production
3. Implement proper user management endpoints
## ⚙️ Configuration
```env
# ==============================
# 📌 DATABASE CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
# ==============================
# 🌍 SERVER CONFIGURATION
# ==============================
# IP address the server will bind to (0.0.0.0 allows all network interfaces)
SERVER_IP="0.0.0.0"
# Port the server will listen on
SERVER_PORT="3000"
# Enable tracing for debugging/logging (true/false)
SERVER_TRACE_ENABLED=true
# ==============================
# 🔒 HTTPS CONFIGURATION
# ==============================
# Enable HTTPS (true/false)
SERVER_HTTPS_ENABLED=false
# Enable HTTP/2 when using HTTPS (true/false)
SERVER_HTTPS_HTTP2_ENABLED=true
# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_CERT_FILE_PATH=cert.pem
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_KEY_FILE_PATH=key.pem
# ==============================
# 🚦 RATE LIMIT CONFIGURATION
# ==============================
# Maximum number of requests allowed per period
SERVER_RATE_LIMIT=5
# Time period (in seconds) for rate limiting
SERVER_RATE_LIMIT_PERIOD=1
# ==============================
# 📦 COMPRESSION CONFIGURATION
# ==============================
# Enable Brotli compression (true/false)
SERVER_COMPRESSION_ENABLED=true
# Compression level (valid range: 0-11, where 11 is the highest compression)
SERVER_COMPRESSION_LEVEL=6
# ==============================
# 🔑 AUTHENTICATION CONFIGURATION
# ==============================
# Argon2 salt for password hashing (must be kept secret!)
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
```
## 📂 Project Structure
```
rustapi/
├── migrations/ # SQL schema versions
├── src/
│ ├── core/ # Config, TLS, server setup
│ ├── database/ # Query handling
│ ├── middlewares/ # Auth system
│ ├── models/ # Data structures
│ └── routes/ # API endpoints
└── Dockerfile # Containerization
```
## 🛠️ Technology Stack
| Category | Key Technologies |
|-----------------------|---------------------------------|
| Web Framework | Axum 0.8 + Tower |
| Database | PostgreSQL + SQLx 0.8 |
| Security | JWT + Argon2 + Rustls |
| Monitoring | Tracing + Sysinfo |

32
compose.yaml Normal file
View File

@ -0,0 +1,32 @@
services:
  server:
    build:
      context: .
      target: final
    ports:
      - 80:80
    depends_on:
      - db_image
    networks:
      - common-net
  db_image:
    image: postgres:latest
    environment:
      # The official postgres image only recognizes POSTGRES_DB,
      # POSTGRES_USER and POSTGRES_PASSWORD (plus a few others). The previous
      # POSTGRES_PORT / POSTGRES_DATABASE / POSTGRES_ROOT_PASSWORD variables
      # are silently ignored by the image, and 3306 is MySQL's port —
      # PostgreSQL listens on 5432.
      POSTGRES_DB: database_name
      POSTGRES_USER: user
      POSTGRES_PASSWORD: database_password
    expose:
      - 5432
    ports:
      - "5432:5432"
    networks:
      - common-net
networks:
  common-net: {}

1
generate_ssl_key.bat Normal file
View File

@ -0,0 +1 @@
openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" -addext "subjectAltName = DNS:localhost, IP:127.0.0.1"

View File

@ -0,0 +1,19 @@
-- Create the roles table.
-- 'level' is a numeric rank used for authorization checks, 'role' is the
-- unique machine-readable key, and 'name'/'description' are human-readable.
CREATE TABLE IF NOT EXISTS roles (
    id SERIAL PRIMARY KEY,
    level INT NOT NULL,
    role VARCHAR(255) NOT NULL,
    name VARCHAR(255) NOT NULL,
    description VARCHAR(255),
    CONSTRAINT unique_role UNIQUE (role) -- Unique constraint on 'role'; required for the ON CONFLICT clauses below
);
-- Seed the 'user' role (this assumes role_id=1 is for 'user').
INSERT INTO roles (level, role, name, description)
VALUES (1, 'user', 'User', 'A regular user with basic access.')
ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists
-- Seed the 'admin' role (this assumes role_id=2 is for 'admin').
INSERT INTO roles (level, role, name, description)
VALUES (2, 'admin', 'Administrator', 'An administrator.')
ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists

View File

@ -0,0 +1,21 @@
-- Create the users table. IF NOT EXISTS added for idempotency, matching the
-- style of the roles migration.
CREATE TABLE IF NOT EXISTS users (
    id SERIAL PRIMARY KEY,
    username VARCHAR(255) NOT NULL,
    email VARCHAR(255) NOT NULL UNIQUE,
    password_hash VARCHAR(255) NOT NULL,
    totp_secret VARCHAR(255),
    role_id INT NOT NULL DEFAULT 1 REFERENCES roles(id), -- Default role_id 1 = 'user'
    -- Single named unique constraint on username. The original declared UNIQUE
    -- both on the column and as this named constraint, which creates two
    -- redundant unique indexes; one is sufficient (and ON CONFLICT still works).
    CONSTRAINT unique_username UNIQUE (username)
);
-- Insert the example 'user' account (password: "test"); skip if it already exists.
INSERT INTO users (username, email, password_hash, role_id)
VALUES
    ('user', 'user@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 1)
ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists
-- Insert the example 'admin' account (password: "test"); skip if it already exists.
INSERT INTO users (username, email, password_hash, role_id)
VALUES
    ('admin', 'admin@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 2)
ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists

View File

@ -0,0 +1,6 @@
-- Create the todos table. IF NOT EXISTS added for idempotency, matching the
-- style of the roles migration.
CREATE TABLE IF NOT EXISTS todos (
    id SERIAL PRIMARY KEY,                     -- Auto-incrementing primary key
    task TEXT NOT NULL,                        -- Task description, cannot be null
    description TEXT,                          -- Optional detailed description
    user_id INT NOT NULL REFERENCES users(id)  -- Foreign key linking to the owning user
);

55
src/core/config.rs Normal file
View File

@ -0,0 +1,55 @@
// Import the standard library's environment module
use std::env;
/// Returns the value of the environment variable `key` as a `String`.
///
/// # Panics
/// Panics with a descriptive message when the variable is not set.
pub fn get_env(key: &str) -> String {
    match env::var(key) {
        Ok(value) => value,
        Err(_) => panic!("Missing required environment variable: {}", key),
    }
}
/// Returns the value of the environment variable `key`, or `default` when the
/// variable is not set.
pub fn get_env_with_default(key: &str, default: &str) -> String {
    match env::var(key) {
        Ok(value) => value,
        Err(_) => default.to_string(),
    }
}
/// Reads the environment variable `key` as a boolean.
///
/// A set variable counts as `true` exactly when its value is "true"
/// (case-insensitive); any other value is `false`. A missing variable yields
/// `default`.
pub fn get_env_bool(key: &str, default: bool) -> bool {
    match env::var(key) {
        Ok(value) => value.to_lowercase() == "true",
        Err(_) => default,
    }
}
/// Reads the environment variable `key` as a `u16`.
///
/// Returns `default` when the variable is missing or cannot be parsed as a
/// `u16`. (The original stringified `default` and re-parsed it — a needless
/// round-trip; behavior is unchanged.)
pub fn get_env_u16(key: &str, default: u16) -> u16 {
    env::var(key)
        .ok()
        .and_then(|value| value.parse().ok())
        .unwrap_or(default)
}

4
src/core/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod config;
pub mod server;
pub mod tls;

38
src/core/server.rs Normal file
View File

@ -0,0 +1,38 @@
// Axum for web server and routing
use axum::Router;
// Middleware layers from tower_http
use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression
use tower_http::trace::TraceLayer; // For HTTP request/response tracing
// Local crate imports for database connection and configuration
use crate::database::connect::connect_to_database; // Function to connect to the database
use crate::config; // Environment configuration helper
/// Builds the Axum application: connects to the database, mounts the routes,
/// and applies optional tracing and Brotli-compression middleware driven by
/// environment configuration.
pub async fn create_server() -> Router {
    // Establish a connection to the database; the server cannot run without it.
    let db = connect_to_database().await.expect("Failed to connect to database.");
    // Initialize the routes for the server.
    let mut app = crate::routes::create_routes(db);
    // Enable tracing middleware if configured (defaults to enabled).
    if config::get_env_bool("SERVER_TRACE_ENABLED", true) {
        app = app.layer(TraceLayer::new_for_http());
        println!("✔️ Trace has been enabled.");
    }
    // Enable compression middleware if configured (defaults to enabled).
    if config::get_env_bool("SERVER_COMPRESSION_ENABLED", true) {
        // Read the compression level, defaulting to 6 when the variable is
        // missing or not a valid number. The original used get_env here, which
        // panics when SERVER_COMPRESSION_LEVEL is unset — contradicting the
        // intended "default to 6" behavior.
        let level = config::get_env_with_default("SERVER_COMPRESSION_LEVEL", "6")
            .parse()
            .unwrap_or(6);
        // Apply Brotli compression at the configured quality level.
        app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level)));
        println!("✔️ Brotli compression enabled with compression quality level {}.", level);
    }
    // Return the fully configured application.
    app
}

145
src/core/tls.rs Normal file
View File

@ -0,0 +1,145 @@
// Standard library imports
use std::{
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
io::BufReader,
fs::File,
iter,
};
// External crate imports
use axum::serve::Listener;
use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}};
use rustls_pemfile::{Item, read_one, certs};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing;
// Local crate imports
use crate::config; // Import env config helper
// Loads the TLS server configuration (certificate chain + private key) from
// the PEM files named by SERVER_HTTPS_CERT_FILE_PATH and
// SERVER_HTTPS_KEY_FILE_PATH. Panics with a descriptive message on any
// failure, since HTTPS cannot start without a valid configuration.
pub fn load_tls_config() -> ServerConfig {
    // Get certificate and key file paths from the environment.
    let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH");
    let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH");
    // Open the certificate and key files.
    let cert_file = File::open(cert_path).expect("❌ Failed to open certificate file.");
    let key_file = File::open(key_path).expect("❌ Failed to open private key file.");
    // Buffer the readers for the PEM parser.
    let mut cert_reader = BufReader::new(cert_file);
    let mut key_reader = BufReader::new(key_file);
    // Read and parse the certificate chain. rustls-pemfile's `certs` already
    // yields `CertificateDer` items, so the original `.map(CertificateDer::from)`
    // was an identity no-op and has been removed.
    let cert_chain: Vec<CertificateDer> = certs(&mut cert_reader)
        .map(|cert| cert.expect("❌ Failed to read certificate."))
        .collect();
    // Ensure at least one certificate was found.
    if cert_chain.is_empty() {
        panic!("❌ No valid certificates found.");
    }
    // Scan the key file for the first supported private-key format
    // (PKCS#1, PKCS#8, or SEC1).
    let key = iter::from_fn(|| read_one(&mut key_reader).transpose())
        .find_map(|item| match item.expect("❌ Failed to parse private key file.") {
            Item::Pkcs1Key(key) => Some(PrivateKeyDer::from(key)),
            Item::Pkcs8Key(key) => Some(PrivateKeyDer::from(key)),
            Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)),
            _ => None, // Skip non-key PEM items (e.g. certificates)
        })
        .expect("Failed to read a valid private key.");
    // Build and return the TLS server configuration.
    ServerConfig::builder()
        .with_no_client_auth() // No client authentication
        .with_single_cert(cert_chain, key) // Use the provided cert and key
        .expect("Failed to create TLS configuration")
}
// Custom listener that implements axum::serve::Listener, pairing a plain TCP
// listener with a rustls acceptor so every accepted connection is upgraded to
// TLS before being handed to axum.
#[derive(Clone)]
pub struct TlsListener {
    pub inner: Arc<tokio::net::TcpListener>, // Inner TCP listener (Arc so clones share one socket)
    pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor performing the handshake per connection
}
impl Listener for TlsListener {
    type Io = TlsStreamWrapper; // Stream type handed to axum after a successful handshake
    type Addr = SocketAddr;     // Peer address type

    // Accepts TCP connections in a loop until one completes a TLS handshake,
    // then yields the encrypted stream plus the peer address. Failed accepts
    // and failed handshakes are logged and retried rather than propagated, so
    // a single misbehaving client cannot take the listener down.
    fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send {
        let acceptor = self.acceptor.clone(); // Clone so the async block owns its acceptor
        async move {
            loop {
                // Accept a TCP connection.
                let (stream, addr) = match self.inner.accept().await {
                    Ok((stream, addr)) => (stream, addr),
                    Err(e) => {
                        tracing::error!("❌ Error accepting TCP connection: {}", e);
                        continue; // Retry on error
                    }
                };
                // Perform the TLS handshake on the accepted connection.
                match acceptor.accept(stream).await {
                    Ok(tls_stream) => {
                        tracing::info!("Successful TLS handshake with {}", addr);
                        return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address
                    },
                    Err(e) => {
                        tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e);
                        continue; // Retry on error
                    }
                }
            }
        }
    }

    // Reports the local address the inner TCP listener is bound to.
    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        self.inner.local_addr()
    }
}
// Newtype wrapper around a server-side TLS stream. Exists so we can implement
// axum's required I/O traits (AsyncRead + AsyncWrite + Unpin) on our own type;
// every operation simply delegates to the wrapped tokio-rustls stream.
#[derive(Debug)]
pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream<tokio::net::TcpStream>);

impl AsyncRead for TlsStreamWrapper {
    // Delegate reads to the underlying TLS stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream
    }
}

impl AsyncWrite for TlsStreamWrapper {
    // Delegate writes to the underlying TLS stream.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream
    }
}

// Explicit Unpin marker so the wrapper satisfies axum's Io bounds.
// NOTE(review): the original comment ("non-blocking contexts") misdescribed
// Unpin; it only asserts the type can be moved after being pinned. This manual
// impl is likely redundant since the inner TlsStream<TcpStream> should be
// Unpin already — confirm before removing.
impl Unpin for TlsStreamWrapper {}

51
src/database/connect.rs Normal file
View File

@ -0,0 +1,51 @@
use dotenvy::dotenv;
use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions};
use std::fs;
use std::env;
use std::path::Path;
/// Connects to the database using the DATABASE_URL environment variable.
pub async fn connect_to_database() -> Result<PgPool, sqlx::Error> {
dotenv().ok();
let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount.");
// Read max and min connection values from environment variables, with defaults
let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
.unwrap_or_else(|_| "10".to_string()) // Default to 10
.parse()
.expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number.");
let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS")
.unwrap_or_else(|_| "2".to_string()) // Default to 2
.parse()
.expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number.");
// Create and configure the connection pool
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.min_connections(min_connections)
.connect(&database_url)
.await?;
Ok(pool)
}
/// Runs all pending SQL migrations found in the `./migrations` directory,
/// creating the directory first if it does not exist.
pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
    let migrations_path = Path::new("./migrations");
    // Ensure the migrations directory exists before the migrator scans it.
    if !migrations_path.exists() {
        fs::create_dir_all(migrations_path)
            .expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions.");
        println!("✔️ Created migrations directory: {:?}", migrations_path);
    }
    // Discover every migration in the directory, then apply the pending ones.
    Migrator::new(migrations_path).await?.run(pool).await?;
    Ok(())
}

25
src/database/get_users.rs Normal file
View File

@ -0,0 +1,25 @@
use sqlx::postgres::PgPool;
use crate::models::user::*; // Import the User struct
// Fetch a single user by email address.
//
// Returns Ok(User) when a matching row exists, or Err(String) describing
// either a database failure or a missing user. (The original header comment
// said "Get all users", which was inaccurate.)
pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result<User, String> {
    // query_as! validates this SQL against the database schema at compile time
    // and maps the selected columns onto the User struct.
    let user = sqlx::query_as!(
        User, // Struct type to map the query result
        r#"
        SELECT id, username, email, password_hash, totp_secret, role_id
        FROM users
        WHERE email = $1
        "#,
        email // Bind the `email` parameter
    )
    .fetch_optional(pool)
    .await
    .map_err(|e| format!("Database error: {}", e))?; // Stringify database errors
    // fetch_optional yields None when no row matched; turn that into an error.
    match user {
        Some(user) => Ok(user),
        None => Err(format!("User with email '{}' not found.", email)),
    }
}

3
src/database/mod.rs Normal file
View File

@ -0,0 +1,3 @@
// Module declarations
pub mod connect;
pub mod get_users;

6
src/handlers/mod.rs Normal file
View File

@ -0,0 +1,6 @@
// Module declarations
// pub mod auth;
// Re-exporting modules
// pub use auth::*;

91
src/main.rs Normal file
View File

@ -0,0 +1,91 @@
#[allow(dead_code)]
// Core modules for the configuration, TLS setup, and server creation
mod core;
use core::{config, tls, server};
use core::tls::TlsListener;
// Other modules for database, routes, models, and middlewares
mod database;
mod routes;
mod models;
mod middlewares;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use axum::serve;
#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok(); // Load environment variables from a .env file
    tracing_subscriber::fmt::init(); // Initialize the logging system
    // Print a cool startup message with ASCII art and emojis
    println!("{}", r#"
##### ## ##
## ## ##
## ## ## ## ##### ####### ##### #### ####
##### ## ## ## ## ## ## ## ## ##
#### ## ## ## ## ## ## ## ## ##
## ## ## ### ## ## ## ### ## ## ##
## ## ### ## ##### ### ### ## ##### ######
##
Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL
GitHub: https://github.com/Riktastic/rustapi
"#);
    println!("🚀 Starting Rustapi...");
    // Retrieve server IP and port from the environment, defaulting to 0.0.0.0:3000.
    let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0")
        .parse()
        .expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1.");
    let port: u16 = config::get_env_u16("SERVER_PORT", 3000);
    let socket_addr = SocketAddr::new(ip, port);
    // Build the Axum application (routes + middleware + database pool).
    let app = server::create_server().await;
    // Serve over HTTPS or plain HTTP depending on configuration.
    if config::get_env_bool("SERVER_HTTPS_ENABLED", false) {
        // Bind the TCP listener for incoming connections.
        // ("allready" typo in the original error messages fixed.)
        let tcp_listener = TcpListener::bind(socket_addr)
            .await
            .expect("❌ Failed to bind to socket. Port might already be in use.");
        // Load the TLS configuration and wrap the listener so every connection
        // completes a TLS handshake before axum sees it.
        let tls_config = tls::load_tls_config();
        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
        let listener = TlsListener {
            inner: Arc::new(tcp_listener),
            acceptor, // field-init shorthand (was the redundant `acceptor: acceptor`)
        };
        println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port);
        // Serve the app using the TLS listener (HTTPS).
        serve(listener, app.into_make_service())
            .await
            .expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?");
    } else {
        // Bind a plain TCP listener for non-secure HTTP connections.
        let listener = TcpListener::bind(socket_addr)
            .await
            .expect("❌ Failed to bind to socket. Port might already be in use.");
        println!("🔓 Server started with HTTP at: http://{}:{}", ip, port);
        // Serve the app over plain HTTP.
        serve(listener, app.into_make_service())
            .await
            .expect("❌ Server failed to start without HTTPS.");
    }
}

211
src/middlewares/auth.rs Normal file
View File

@ -0,0 +1,211 @@
// Standard library imports for working with HTTP, environment variables, and other necessary utilities
use axum::{
body::Body,
response::IntoResponse,
extract::{Request, Json}, // Extractor for request and JSON body
http::{self, Response, StatusCode}, // HTTP response and status codes
middleware::Next, // For adding middleware layers to the request handling pipeline
};
// Importing `State` for sharing application state (such as a database connection) across request handlers
use axum::extract::State;
// Importing necessary libraries for password hashing, JWT handling, and date/time management
use std::env; // For accessing environment variables
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification
Argon2,
};
use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.)
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens
use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data
use serde_json::json; // For constructing JSON data
use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously
// Importing custom database query functions
use crate::database::get_users::get_user_by_email;
// JWT claim set carried in every issued token; serialized into the token
// payload by encode_jwt and parsed back out by decode_jwt.
#[derive(Serialize, Deserialize)]
pub struct Claims {
    pub exp: usize,    // Expiration timestamp (seconds since Unix epoch)
    pub iat: usize,    // Issued-at timestamp (seconds since Unix epoch)
    pub email: String, // Email of the authenticated user (token subject)
}
// Custom error type for authentication failures; rendered as a JSON response
// via its IntoResponse implementation.
pub struct AuthError {
    message: String,         // Human-readable error description (sent to the client as JSON)
    status_code: StatusCode, // HTTP status code to be returned with the error
}
// Checks `password` against a stored Argon2 `hash`.
// Returns Ok(true)/Ok(false) for match/mismatch, or Err when the stored hash
// string itself cannot be parsed as a PHC-format hash.
pub fn verify_password(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
    let parsed_hash = PasswordHash::new(hash)?;
    let matches = Argon2::default()
        .verify_password(password.as_bytes(), &parsed_hash)
        .is_ok();
    Ok(matches)
}
// Hashes a password with Argon2, using a fixed salt taken from the
// AUTHENTICATION_ARGON2_SALT environment variable.
//
// SECURITY(review): a single static salt for every password defeats the
// purpose of salting — identical passwords produce identical hashes. Consider
// generating a per-password random salt (SaltString::generate) and storing it
// in the PHC hash string, which verify_password already parses.
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
    // The salt is a deliberate startup requirement; panic if unset.
    let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set");
    // Propagate an invalid-salt error instead of panicking — the original
    // called .unwrap() here even though this function returns Result.
    let salt = SaltString::from_b64(&salt)?;
    let argon2 = Argon2::default(); // Default Argon2 parameters
    // Hash the password with the salt and return the PHC-format string.
    Ok(argon2.hash_password(password.as_bytes(), &salt)?.to_string())
}
// Implement IntoResponse so an AuthError can be returned directly from
// handlers/middleware; the client receives `{"error": "<message>"}` with the
// error's status code.
impl IntoResponse for AuthError {
    fn into_response(self) -> Response<Body> {
        let body = Json(json!( { "error": self.message } )); // JSON body carrying the error message
        // Pair the status code with the JSON body to form the final response.
        (self.status_code, body).into_response()
    }
}
/// Create a signed JWT for the given email, valid for 24 hours.
///
/// Returns `Err(StatusCode::INTERNAL_SERVER_ERROR)` if signing fails.
///
/// SECURITY NOTE(review): the signing secret is hard-coded; load it from
/// configuration (and keep it in sync with `decode_jwt`) before production use.
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
    // One &str is enough — the original allocated two Strings and cloned one.
    let secret = "randomstring"; // Shared HMAC secret (see security note above)
    let now = Utc::now(); // Get current time
    let exp: usize = (now + Duration::hours(24)).timestamp() as usize; // Expires in 24 hours
    let iat: usize = now.timestamp() as usize; // Issued-at timestamp
    let claim = Claims { iat, exp, email }; // Claims carried in the token payload
    // Sign the claims; Header::default() selects the HS256 algorithm.
    encode(
        &Header::default(),
        &claim,
        &EncodingKey::from_secret(secret.as_bytes()),
    )
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails
}
/// Decode and validate a JWT, returning its claims.
///
/// A token that fails to decode or validate is a *client* problem, so failures
/// now map to `StatusCode::UNAUTHORIZED` rather than the original 500 — an
/// expired or forged token is not a server error.
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
    let secret = "randomstring"; // Must match the secret used by `encode_jwt` (see its security note)
    decode(
        &jwt,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(), // Default validation: HS256 + `exp` check
    )
    .map_err(|_| StatusCode::UNAUTHORIZED) // Bad/expired token -> 401, not 500
}
// Middleware for role-based access control (RBAC)
// Ensures that only users with specific roles are authorized to access certain resources.
//
// Fixes over the original: no `expect`/`unwrap` panics — a missing pool or a
// malformed Authorization header (e.g. just "Bearer" with no token) now yields
// an error response instead of crashing the task — and the Bearer scheme is
// actually checked.
pub async fn authorize(
    mut req: Request<Body>,
    next: Next,
    allowed_roles: Vec<i32>, // Accept a vector of allowed roles
) -> Result<Response<Body>, AuthError> {
    // Retrieve the database pool from request extensions (shared application state).
    // Cloning a PgPool is cheap (it is reference-counted internally).
    let pool = req
        .extensions()
        .get::<PgPool>()
        .cloned()
        .ok_or_else(|| AuthError {
            message: "Database pool not found in request extensions.".to_string(),
            status_code: StatusCode::INTERNAL_SERVER_ERROR,
        })?;
    // Retrieve and decode the Authorization header.
    let auth_header = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .ok_or_else(|| AuthError {
            message: "Authorization header missing.".to_string(),
            status_code: StatusCode::FORBIDDEN,
        })?
        .to_str()
        .map_err(|_| AuthError {
            message: "Invalid header format".to_string(),
            status_code: StatusCode::FORBIDDEN,
        })?;
    // Expect exactly "Bearer <token>"; reject anything else without panicking.
    let mut parts = auth_header.split_whitespace();
    let token = match (parts.next(), parts.next()) {
        (Some(scheme), Some(token)) if scheme.eq_ignore_ascii_case("bearer") => token,
        _ => {
            return Err(AuthError {
                message: "Invalid Authorization header; expected 'Bearer <token>'.".to_string(),
                status_code: StatusCode::FORBIDDEN,
            })
        }
    };
    // Decode and validate the JWT.
    let token_data = decode_jwt(token.to_string()).map_err(|_| AuthError {
        message: "Unable to decode token.".to_string(),
        status_code: StatusCode::UNAUTHORIZED,
    })?;
    // Fetch the user from the database using the email from the decoded token.
    let current_user = get_user_by_email(&pool, token_data.claims.email)
        .await
        .map_err(|_| AuthError {
            message: "Unauthorized user.".to_string(),
            status_code: StatusCode::UNAUTHORIZED,
        })?;
    // Check if the user's role is in the list of allowed roles.
    if !allowed_roles.contains(&current_user.role_id) {
        return Err(AuthError {
            message: "Forbidden: insufficient role.".to_string(),
            status_code: StatusCode::FORBIDDEN,
        });
    }
    // Expose the authenticated user to downstream handlers via extensions.
    req.extensions_mut().insert(current_user);
    // Proceed to the next middleware or handler.
    Ok(next.run(req).await)
}
// Structure to hold the data from the sign-in request body
#[derive(Deserialize)]
pub struct SignInData {
    pub email: String, // Email used to look the account up
    pub password: String, // Plaintext password, verified against the stored Argon2 hash
}
// Handler for user sign-in (authentication)
pub async fn sign_in(
State(pool): State<PgPool>, // Database connection pool injected as state
Json(user_data): Json<SignInData>, // Deserialize the JSON body into SignInData
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
// 1. Retrieve user from the database using the provided email
let user = match get_user_by_email(&pool, user_data.email).await {
Ok(user) => user,
Err(_) => return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Incorrect credentials." }))
)),
};
// 2. Verify the password using the stored hash
if !verify_password(&user_data.password, &user.password_hash)
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?
{
return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Incorrect credentials." }))
));
}
// 3. Generate a JWT token for the authenticated user
let token = encode_jwt(user.email)
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
// 4. Return the JWT token to the client
Ok(Json(json!({ "token": token })))
}

2
src/middlewares/mod.rs Normal file
View File

@ -0,0 +1,2 @@
// Module declarations
pub mod auth;

5
src/models/mod.rs Normal file
View File

@ -0,0 +1,5 @@
/// Module for to-do related models.
pub mod todo;
/// Module for user related models.
pub mod user;

18
src/models/role.rs Normal file
View File

@ -0,0 +1,18 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// Represents a user role in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct Role {
    /// ID of the role.
    pub id: i32,
    /// Level of the role.
    pub level: i32,
    /// System name of the role.
    pub role: String,
    /// The name of the role.
    pub name: String,
    /// Description of the role.
    // Renamed from `Description`: Rust fields are snake_case (the original name
    // triggered a non_snake_case warning), and the capitalized field also made
    // serde emit the JSON key as "Description", inconsistent with every other field.
    pub description: String,
}

16
src/models/todo.rs Normal file
View File

@ -0,0 +1,16 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// Represents a to-do item.
#[derive(Deserialize, Debug, Serialize, FromRow)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct Todo {
    /// The unique identifier for the to-do item.
    pub id: i32,
    /// The task description.
    pub task: String,
    /// An optional detailed description of the task.
    pub description: Option<String>,
    /// The unique identifier of the user who created the to-do item.
    /// (Presumably a foreign key into `users` — confirm against the schema.)
    pub user_id: i32,
}

20
src/models/user.rs Normal file
View File

@ -0,0 +1,20 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
/// Represents a user in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct User {
    /// The unique identifier for the user.
    pub id: i32,
    /// The username of the user.
    pub username: String,
    /// The email of the user.
    pub email: String,
    /// The hashed password for the user.
    ///
    /// NOTE(review): because `User` derives `Serialize`, this field (and
    /// `totp_secret`) is included whenever a whole `User` is returned as JSON —
    /// see the handlers in routes/get_users.rs.
    pub password_hash: String,
    /// The TOTP secret for the user.
    pub totp_secret: Option<String>,
    /// Current role of the user.
    pub role_id: i32,
}

200
src/routes/get_health.rs Normal file
View File

@ -0,0 +1,200 @@
use axum::{response::IntoResponse, Json, extract::State};
use serde_json::json;
use sqlx::PgPool;
use sysinfo::{System, RefreshKind, Disks};
use tokio::{task, join};
use std::sync::{Arc, Mutex};
// Health-check endpoint: runs CPU, memory, disk, process, database, and network
// probes concurrently and reports overall status ("healthy"/"degraded") plus
// per-check details as JSON.
pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
    // Use Arc and Mutex to allow sharing System between tasks.
    // NOTE(review): every spawn_blocking closure below holds this one Mutex for
    // its whole body, so the CPU/memory/process checks largely run one at a
    // time despite join! — confirm whether per-check System instances were intended.
    let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything())));
    // Run checks in parallel. Spawned entries yield Result<Result<_, _>, JoinError>
    // (outer = task panic, inner = check outcome); the database check is a plain future.
    let (cpu_result, mem_result, disk_result, process_result, db_result, net_result) = join!(
        task::spawn_blocking({
            let system = Arc::clone(&system);
            move || {
                let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
                check_cpu_usage(&mut system) // Pass the mutable reference
            }
        }),
        task::spawn_blocking({
            let system = Arc::clone(&system);
            move || {
                let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
                check_memory(&mut system) // Pass the mutable reference
            }
        }),
        task::spawn_blocking({
            move || {
                check_disk_usage() // Does not need a system reference.
            }
        }),
        task::spawn_blocking({
            let system = Arc::clone(&system);
            move || {
                let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
                check_processes(&mut system, &["postgres", "Code"]) // Pass the mutable reference
            }
        }),
        check_database_connection(&database_connection), // Async function
        task::spawn_blocking(check_network_connection) // Blocking, okay in spawn_blocking
    );
    // Aggregate: any failed or below-threshold sub-check downgrades the overall status.
    let mut status = "healthy";
    let mut details = json!({});
    // Process CPU result (Ok(Ok(..)) = task didn't panic AND the check succeeded)
    if let Ok(Ok(cpu_details)) = cpu_result {
        details["cpu_usage"] = json!(cpu_details);
        if cpu_details["status"] == "low" {
            status = "degraded";
        }
    } else {
        details["cpu_usage"] = json!({ "status": "error", "error": "Failed to retrieve CPU usage" });
        status = "degraded";
    }
    // Process Memory result
    if let Ok(Ok(mem_details)) = mem_result {
        details["memory"] = json!(mem_details);
        if mem_details["status"] == "low" {
            status = "degraded";
        }
    } else {
        details["memory"] = json!({ "status": "error", "error": "Failed to retrieve memory information" });
        status = "degraded";
    }
    // Process Disk result
    if let Ok(Ok(disk_details)) = disk_result {
        details["disk_usage"] = json!(disk_details);
        if disk_details["status"] == "critical" {
            status = "degraded";
        }
    } else {
        details["disk_usage"] = json!({ "status": "error", "error": "Failed to retrieve disk usage" });
        status = "degraded";
    }
    // Process Process result: degraded if ANY watched process is not running
    if let Ok(Ok(process_details)) = process_result {
        details["important_processes"] = json!(process_details);
        if process_details.iter().any(|p| p["status"] == "not running") {
            status = "degraded";
        }
    } else {
        details["important_processes"] = json!({ "status": "error", "error": "Failed to retrieve process information" });
        status = "degraded";
    }
    // Process Database result (plain Result — no JoinHandle layer)
    if let Ok(db_status) = db_result {
        details["database"] = json!({ "status": if db_status { "ok" } else { "degraded" } });
        if !db_status {
            status = "degraded";
        }
    } else {
        details["database"] = json!({ "status": "error", "error": "Failed to retrieve database status" });
        status = "degraded";
    }
    // Process Network result
    if let Ok(Ok(net_status)) = net_result {
        details["network"] = json!({ "status": if net_status { "ok" } else { "degraded" } });
        if !net_status {
            status = "degraded";
        }
    } else {
        details["network"] = json!({ "status": "error", "error": "Failed to retrieve network status" });
        status = "degraded";
    }
    Json(json!({
        "status": status,
        "details": details,
    }))
}
// Helper functions
// Report global CPU usage; "low" means less than 10% headroom remains.
fn check_cpu_usage(system: &mut System) -> Result<serde_json::Value, ()> {
    system.refresh_cpu_usage();
    let usage = system.global_cpu_usage();
    let headroom = 100.0 - usage;
    let status = if headroom < 10.0 { "low" } else { "normal" };
    Ok(json!({
        "usage_percentage": format!("{:.2}", usage),
        "available_percentage": format!("{:.2}", headroom),
        "status": status,
    }))
}
// Report available memory in MB; "low" when less than 512 MB remains.
fn check_memory(system: &mut System) -> Result<serde_json::Value, ()> {
    system.refresh_memory();
    let available_mb = system.available_memory() / (1024 * 1024); // bytes -> MB
    let status = if available_mb < 512 { "low" } else { "normal" };
    Ok(json!({
        "available_mb": available_mb,
        "status": status,
    }))
}
/// Report the highest used-space percentage across all detected disks;
/// "critical" when any disk is more than 90% full.
fn check_disk_usage() -> Result<serde_json::Value, ()> {
    // Create a new Disks object and refresh the disk information.
    let mut disks = Disks::new();
    disks.refresh(false); // Refresh disk information without performing a full refresh
    // Compute the used percentage per disk lazily and take the maximum in one
    // pass. Zero-sized entries (pseudo-filesystems) are skipped — the original
    // divided by zero for them, producing NaN.
    let max_usage = disks
        .list()
        .iter()
        .filter_map(|disk| {
            let total = disk.total_space() as f64;
            if total > 0.0 {
                Some(((total - disk.available_space() as f64) / total) * 100.0)
            } else {
                None
            }
        })
        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap_or(0.0); // No disks at all -> report 0% used
    Ok(json!({
        "used_percentage": format!("{:.2}", max_usage),
        "status": if max_usage > 90.0 { "critical" } else { "ok" },
    }))
}
// Report a running/not-running status for each named process.
// Names are translated to their Windows executable equivalents when built for Windows.
fn check_processes(system: &mut System, processes: &[&str]) -> Result<Vec<serde_json::Value>, ()> {
    system.refresh_processes(sysinfo::ProcessesToUpdate::All, true);
    let statuses = processes
        .iter()
        .map(|&name| {
            // On Windows the known processes carry an .exe suffix.
            let target = if cfg!(target_os = "windows") {
                match name {
                    "postgres" => "postgres.exe", // Postgres on Windows
                    "Code" => "Code.exe", // Visual Studio Code on Windows
                    other => other,
                }
            } else {
                name // Non-Windows platforms use the name as-is
            };
            // Scan the process table for the (translated) name.
            let running = system
                .processes()
                .values()
                .any(|process| process.name() == target);
            json!({
                "name": name,
                "status": if running { "running" } else { "not running" }
            })
        })
        .collect();
    Ok(statuses)
}
// Probe the database with a trivial query. Always returns Ok: Ok(true) when the
// query succeeds, Ok(false) when it fails (the error type is never produced).
async fn check_database_connection(pool: &PgPool) -> Result<bool, sqlx::Error> {
    let probe = sqlx::query("SELECT 1").fetch_one(pool).await;
    Ok(probe.is_ok())
}
/// Best-effort internet reachability probe against a public DNS server (8.8.8.8:53).
///
/// Uses a bounded connect timeout so a black-holed route cannot stall the
/// health endpoint — a plain `TcpStream::connect` can block for the OS default,
/// which is often tens of seconds.
fn check_network_connection() -> Result<bool, ()> {
    let reachable = "8.8.8.8:53"
        .parse::<std::net::SocketAddr>()
        .map(|addr| {
            std::net::TcpStream::connect_timeout(&addr, std::time::Duration::from_secs(2)).is_ok()
        })
        .unwrap_or(false); // The literal address always parses; treat failure as unreachable anyway
    Ok(reachable)
}

43
src/routes/get_todos.rs Normal file
View File

@ -0,0 +1,43 @@
use axum::extract::{State, Path};
use axum::Json;
use axum::response::IntoResponse;
use sqlx::postgres::PgPool;
use crate::models::todo::*;
// Get all todos
pub async fn get_all_todos(State(pool): State<PgPool>,) -> impl IntoResponse {
let todos = sqlx::query_as!(Todo, "SELECT * FROM todos") // Your table name
.fetch_all(&pool) // Borrow the connection pool
.await;
match todos {
Ok(todos) => Ok(Json(todos)), // Return all todos as JSON
Err(err) => Err((
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Error fetching todos: {}", err),
)),
}
}
// Get a single todo by id; 404 when no row matches.
pub async fn get_todos_by_id(
    State(pool): State<PgPool>,
    Path(id): Path<i32>, // Use Path extractor here
) -> impl IntoResponse {
    let lookup = sqlx::query_as!(Todo, "SELECT * FROM todos WHERE id = $1", id)
        .fetch_optional(&pool) // Borrow the connection pool
        .await;
    match lookup {
        Err(err) => Err((
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error fetching todo: {}", err),
        )),
        Ok(None) => Err((
            axum::http::StatusCode::NOT_FOUND,
            format!("Todo with id {} not found", id),
        )),
        Ok(Some(todo)) => Ok(Json(todo)), // Found: return it as JSON
    }
}

87
src/routes/get_users.rs Normal file
View File

@ -0,0 +1,87 @@
use axum::extract::{State, Path};
use axum::Json;
use axum::response::IntoResponse;
use sqlx::postgres::PgPool;
use crate::models::user::*; // Import the User struct
// Get all users
pub async fn get_all_users(State(pool): State<PgPool>,) -> impl IntoResponse {
let users = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users") // Your table name
.fetch_all(&pool) // Borrow the connection pool
.await;
match users {
Ok(users) => Ok(Json(users)), // Return all users as JSON
Err(err) => Err((
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Error fetching users: {}", err),
)),
}
}
// Get a single user by id; 404 when no row matches.
// SECURITY NOTE(review): response includes password_hash/totp_secret — verify exposure is intended.
pub async fn get_users_by_id(
    State(pool): State<PgPool>,
    Path(id): Path<i32>, // Use Path extractor here
) -> impl IntoResponse {
    let lookup = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE id = $1", id)
        .fetch_optional(&pool) // Borrow the connection pool
        .await;
    match lookup {
        Err(err) => Err((
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error fetching user: {}", err),
        )),
        Ok(None) => Err((
            axum::http::StatusCode::NOT_FOUND,
            format!("User with id {} not found", id),
        )),
        Ok(Some(user)) => Ok(Json(user)), // Found: return it as JSON
    }
}
// Get a single user by username; 404 when no row matches.
// SECURITY NOTE(review): response includes password_hash/totp_secret — verify exposure is intended.
pub async fn get_user_by_username(
    State(pool): State<PgPool>,
    Path(username): Path<String>, // Use Path extractor here for username
) -> impl IntoResponse {
    let lookup = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE username = $1", username)
        .fetch_optional(&pool) // Borrow the connection pool
        .await;
    match lookup {
        Err(err) => Err((
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error fetching user: {}", err),
        )),
        Ok(None) => Err((
            axum::http::StatusCode::NOT_FOUND,
            format!("User with username {} not found", username),
        )),
        Ok(Some(user)) => Ok(Json(user)), // Found: return it as JSON
    }
}
// Get a single user by email; 404 when no row matches.
// SECURITY NOTE(review): response includes password_hash/totp_secret — verify exposure is intended.
pub async fn get_user_by_email(
    State(pool): State<PgPool>,
    Path(email): Path<String>, // Use Path extractor here for email
) -> impl IntoResponse {
    let lookup = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE email = $1", email)
        .fetch_optional(&pool) // Borrow the connection pool
        .await;
    match lookup {
        Err(err) => Err((
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error fetching user: {}", err),
        )),
        Ok(None) => Err((
            axum::http::StatusCode::NOT_FOUND,
            format!("User with email {} not found", email),
        )),
        Ok(Some(user)) => Ok(Json(user)), // Found: return it as JSON
    }
}

59
src/routes/mod.rs Normal file
View File

@ -0,0 +1,59 @@
// Module declarations for different route handlers
pub mod get_todos;
pub mod get_users;
pub mod post_todos;
pub mod post_users;
pub mod get_health;
pub mod protected;
// Re-exporting modules to make their contents available at this level
pub use get_todos::*;
pub use get_users::*;
pub use post_todos::*;
pub use post_users::*;
pub use get_health::*;
pub use protected::*;
use axum::{
Router,
routing::{get, post},
};
use sqlx::PgPool;
use crate::middlewares::auth::{sign_in, authorize};
/// Function to create and configure all routes.
///
/// Wires authentication, user, todo, and health-check routes together, applies
/// the role-based `authorize` middleware to the protected endpoints, and makes
/// the PgPool available both as an `Extension` and as router state.
pub fn create_routes(database_connection: PgPool) -> Router {
    // Authentication routes
    let auth_routes = Router::new()
        .route("/signin", post(sign_in))
        // route_layer runs only when this route matches, so unmatched paths
        // still produce 404 rather than an auth error.
        .route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| {
            let allowed_roles = vec![1, 2]; // Role ids permitted here — presumably admin/user; verify against the roles table
            authorize(req, next, allowed_roles)
        })));
    // User-related routes
    let user_routes = Router::new()
        .route("/all", get(get_all_users))
        .route("/{id}", get(get_users_by_id))
        .route("/", post(post_user)); // NOTE(review): user creation is NOT behind `authorize` — confirm this is intended
    // Todo-related routes
    let todo_routes = Router::new()
        .route("/all", get(get_all_todos))
        // NOTE(review): this uses `.layer` while /protected uses `.route_layer`;
        // `.layer` also wraps this sub-router's not-found path — confirm which is intended.
        .route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
            let allowed_roles = vec![1, 2]; // Same role list duplicated from above — could be a shared constant
            authorize(req, next, allowed_roles)
        })))
        .route("/{id}", get(get_todos_by_id));
    // Combine all routes and add middleware
    Router::new()
        .merge(auth_routes) // Add authentication routes
        .nest("/users", user_routes) // Add user routes under /users
        .nest("/todos", todo_routes) // Add todo routes under /todos
        .route("/health", get(get_health)) // Add health check route
        .layer(axum::Extension(database_connection.clone())) // Pool as Extension (read by `authorize`)
        .with_state(database_connection) // Pool as state (read by the State extractors)
}

53
src/routes/post_todos.rs Normal file
View File

@ -0,0 +1,53 @@
use axum::{extract::{State, Extension}, Json};
use axum::response::IntoResponse;
use sqlx::postgres::PgPool;
use crate::models::todo::*;
use crate::models::user::*;
use serde::Deserialize;
use axum::http::StatusCode;
use serde_json::json;
// Request payload for creating a todo.
#[derive(Deserialize)]
pub struct TodoBody {
    pub task: String, // Short task description
    pub description: Option<String>, // Optional longer description
    pub user_id: i32, // Must equal the authenticated user's id (checked in post_todo)
}
// Add a new todo owned by the authenticated user.
pub async fn post_todo(
    State(pool): State<PgPool>,
    Extension(user): Extension<User>, // Extract current user from the request extensions
    Json(todo): Json<TodoBody>
) -> impl IntoResponse {
    // Users may only create todos for themselves.
    if todo.user_id != user.id {
        return Err((
            StatusCode::FORBIDDEN,
            Json(json!({ "error": "User is not authorized to create a todo for another user" }))
        ));
    }
    // Insert the todo and echo the stored row back to the client.
    sqlx::query!(
        "INSERT INTO todos (task, description, user_id) VALUES ($1, $2, $3) RETURNING id, task, description, user_id",
        todo.task,
        todo.description,
        todo.user_id
    )
    .fetch_one(&pool)
    .await
    .map(|row| Json(Todo {
        id: row.id,
        task: row.task,
        description: row.description,
        user_id: row.user_id,
    }))
    .map_err(|err| (
        StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": format!("Error: {}", err) }))
    ))
}

44
src/routes/post_users.rs Normal file
View File

@ -0,0 +1,44 @@
use axum::extract::State;
use axum::Json;
use axum::response::IntoResponse;
use sqlx::postgres::PgPool;
use crate::models::user::*;
use serde::Deserialize;
// Request payload for creating a user.
#[derive(Deserialize)]
pub struct UserBody {
    pub username: String,
    pub email: String,
    // NOTE(review): the client supplies the hash itself — the server never
    // hashes in post_user. Confirm callers pre-hash, otherwise plaintext is stored.
    pub password_hash: String,
    // NOTE(review): required here, yet `User.totp_secret` is Option<String> —
    // confirm this asymmetry with the schema.
    pub totp_secret: String,
    pub role_id: i32,
}
// Add a new user and echo the stored row (including sensitive columns) back.
// SECURITY NOTE(review): accepts a client-supplied `password_hash` and `role_id`
// with no authorization check visible here — verify this endpoint's exposure.
pub async fn post_user(State(pool): State<PgPool>, Json(user): Json<UserBody>, ) -> impl IntoResponse {
    let inserted = sqlx::query!(
        "INSERT INTO users (username, email, password_hash, totp_secret, role_id) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, totp_secret, role_id",
        user.username,
        user.email,
        user.password_hash,
        user.totp_secret,
        user.role_id
    )
    .fetch_one(&pool) // Use `&pool` to borrow the connection pool
    .await;
    match inserted {
        Err(err) => Err((
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Error: {}", err),
        )),
        Ok(row) => Ok(Json(User {
            id: row.id,
            username: row.username,
            email: row.email,
            password_hash: row.password_hash,
            totp_secret: row.totp_secret,
            role_id: row.role_id,
        })),
    }
}

18
src/routes/protected.rs Normal file
View File

@ -0,0 +1,18 @@
use axum::{Extension, Json, response::IntoResponse};
use serde::{Serialize, Deserialize};
use crate::models::user::User;
// Public view of a `User` returned by the protected endpoint — carries only
// id/username/email, omitting password_hash, totp_secret, and role_id.
#[derive(Serialize, Deserialize)]
struct UserResponse {
    id: i32, // User id
    username: String, // Username
    email: String // Email address
}
pub async fn protected(Extension(user): Extension<User>) -> impl IntoResponse {
Json(UserResponse {
id: user.id,
username: user.username,
email: user.email
})
}