Better API security, key rotation, delete endpoints, refactoring into an easier-to-understand format, performance optimization, caching, better tracing, better logging.

This commit is contained in:
Rik Heijmann 2025-02-15 12:44:40 +01:00
parent e20f21bc8b
commit 40ab25987c
40 changed files with 2253 additions and 289 deletions

316
Bruno.json Normal file
View File

@ -0,0 +1,316 @@
{
"name": "Axium",
"version": "1",
"items": [
{
"type": "http",
"name": "Health",
"seq": 3,
"request": {
"url": "{{base_url}}/health",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Protected",
"seq": 4,
"request": {
"url": "{{base_url}}/protected",
"method": "GET",
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true
}
],
"params": [],
"body": {
"mode": "json",
"json": "",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Sign-in",
"seq": 1,
"request": {
"url": "{{base_url}}/signin",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"email\":\"user@test.com\",\n \"password\":\"test\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "folder",
"name": "To do's",
"root": {
"request": {
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true,
"uid": "PRiX2eBEKKPlsc1xxRHeN"
}
]
},
"meta": {
"name": "To do's"
}
},
"items": [
{
"type": "http",
"name": "Get all",
"seq": 1,
"request": {
"url": "{{base_url}}/todos/all",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Get by ID",
"seq": 2,
"request": {
"url": "{{base_url}}/todos/1",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Post new",
"seq": 3,
"request": {
"url": "{{base_url}}/todos/",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"task\": \"Finish Rust project.\",\n \"description\": \"Complete the API endpoints for the todo app.\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
}
]
},
{
"type": "folder",
"name": "Users",
"root": {
"request": {
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true,
"uid": "Dv1ZS2orRQaKpVNKRBmLf"
}
]
},
"meta": {
"name": "Users"
}
},
"items": [
{
"type": "http",
"name": "Get all",
"seq": 1,
"request": {
"url": "{{base_url}}/users/all",
"method": "GET",
"headers": [
{
"name": "",
"value": "",
"enabled": true
}
],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Post new",
"seq": 3,
"request": {
"url": "{{base_url}}/users/",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"username\": \"MyNewUser\",\n \"email\": \"MyNewUser@test.com\",\n \"password\": \"MyNewUser\",\n \"totp\": \"true\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Get by ID",
"seq": 2,
"request": {
"url": "{{base_url}}/users/1",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
}
]
}
],
"activeEnvironmentUid": "6LVIlBNVHmWamnS5xdrf0",
"environments": [
{
"variables": [
{
"name": "base_url",
"value": "http://127.0.0.1:3000",
"enabled": true,
"secret": false,
"type": "text"
},
{
"name": "token",
"value": "",
"enabled": true,
"secret": true,
"type": "text"
}
],
"name": "Default"
}
],
"root": {
"request": {
"vars": {}
}
},
"brunoConfig": {
"version": "1",
"name": "Axium",
"type": "collection",
"ignore": [
"node_modules",
".git"
]
}
}

View File

@ -1,5 +1,5 @@
[package]
-name = "rustapi"
name = "Axium"
version = "0.1.0"
edition = "2021"
@ -9,7 +9,10 @@ axum = { version = "0.8.1", features = ["json"] }
# hyper = { version = "1.5.2", features = ["full"] }
# Database interaction
-sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate"] }
sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate", "uuid", "chrono"] }
uuid = { version = "1.12.1", features = ["serde"] }
rand = "0.8.5"
rand_core = "0.6.4" # 2024-2-3: SQLx 0.8.3 does not support 0.9.
# Serialization and deserialization
serde = { version = "1.0.217", features = ["derive"] }
@ -18,6 +21,9 @@ serde_json = "1.0.137"
# Authentication and security
jsonwebtoken = "9.3.0"
argon2 = "0.5.3"
totp-rs = { version = "5.6.0", features = ["gen_secret"] }
base64 = "0.22.1"
bcrypt = "0.17.0"
# Asynchronous runtime and traits
tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] }
@ -37,9 +43,17 @@ tracing-subscriber = "0.3.19"
sysinfo = "0.33.1"
# Date and time handling
-chrono = "0.4.39"
chrono = { version = "0.4.39", features = ["serde"] }
# SSL / TLS
rustls = "0.23.21"
tokio-rustls = "0.26.1"
rustls-pemfile = "2.2.0"
# Input validation
validator = { version = "0.20.0", features = ["derive"] }
regex = "1.11.1"
# Documentation
utoipa = { version = "5.3.1", features = ["axum_extras", "chrono", "uuid"] }
utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] }
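The new `validator` and `utoipa` dependencies are used by the request models and route handlers added later in this commit. As a rough, hedged sketch (the struct name and the specific rules below are illustrative, not taken from the commit), the two derives can sit side by side on a payload type:

```rust
use serde::Deserialize;
use utoipa::ToSchema;
use validator::Validate;

// Hypothetical payload: `Validate` enforces input rules, `ToSchema` exposes the
// type in the generated OpenAPI document.
#[derive(Deserialize, Validate, ToSchema)]
pub struct SignUpRequest {
    #[validate(length(min = 3, max = 255))]
    pub username: String,
    #[validate(email)]
    pub email: String,
    #[validate(length(min = 8))]
    pub password: String,
}
```

Calling `request.validate()` then returns an `Err(ValidationErrors)` describing every failed rule.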

118
README.md
View File

@ -1,12 +1,15 @@
-```markdown
-# 🦀 RustAPI
# 🦀 Axium
**An example API built with Rust, Axum, SQLx, and PostgreSQL**
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
## 🚀 Core Features
-- **Rust API Template** - Production-ready starter template with modern practices
- **Rust API template** - Production-ready starter template with modern practices,
-- **PostgreSQL Integration** - Full database support with SQLx migrations
- **PostgreSQL integration** - Full database support with SQLx migrations,
-- **Comprehensive Health Monitoring**
- **Easy to secure** - HTTP/2 with secure TLS defaults (AWS-LC, FIPS 140-3),
- **Easy to configure** - `.env` and environment variables,
- **JWT authentication** - Secure token-based auth with Argon2 password hashing,
- **Optimized for performance** - Brotli compression,
- **Comprehensive health monitoring**
Docker-compatible endpoint with system metrics:
```json
{
@ -19,15 +22,14 @@
"status": "degraded" "status": "degraded"
} }
``` ```
- **JWT Authentication** - Secure token-based auth with Argon2 password hashing - **Granular access control** - Role-based endpoint protection:
- **Granular Access Control** - Role-based endpoint protection:
```rust ```rust
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| { .route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2]; let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles) authorize(req, next, allowed_roles)
}))) })))
``` ```
- **User Context Injection** - Automatic user profile handling in endpoints: - **User context injection** - Automatic user profile handling in endpoints:
```rust ```rust
pub async fn post_todo( pub async fn post_todo(
Extension(user): Extension<User>, // Injected user Extension(user): Extension<User>, // Injected user
@ -39,18 +41,73 @@
}))));
}
```
-- **Modern protocols ** - HTTP/2 with secure TLS defaults
-- **Observability** - Integrated tracing
-- **Optimized for performance** - Brotli compression
-- **Easy configuration** - `.env` and environment variables
-- **Documented codebase** - Extensive inline comments for easy modification and readability
-- **Latest dependencies** - Regularly updated Rust ecosystem crates
- **Observability** - Integrated tracing,
- **Documented codebase** - Extensive inline comments for easy modification and readability,
- **Latest dependencies** - Regularly updated Rust ecosystem crates,
## 🛠️ Technology stack
| Category | Key Technologies |
|-----------------------|---------------------------------|
| Web Framework | Axum 0.8 + Tower |
| Database | PostgreSQL + SQLx 0.8 |
| Security | JWT + Argon2 + Rustls |
| Monitoring | Tracing + Sysinfo |
## 📂 Project structure
```
Axium/
├── migrations/ # SQL schema migrations. Creates the required tables and inserts demo data.
├── src/
│ ├── core/ # Core modules: for reading configuration files, starting the server and configuring HTTPS/
│ ├── database/ # Database connectivity, getters and setters for the database.
│ ├── middlewares/ # Currently just the authentication system.
│ ├── models/ # Data structures
│ └── routes/ # API endpoints
│ └── mod.rs # API endpoint router.
│ └── .env # Configuration file.
└── Dockerfile # Builds a docker container for the application.
└── compose.yaml # Docker-compose.yaml. Runs container for the application (also includes a PostgreSQL-container).
```
## 🌐 Default API endpoints
| Method | Endpoint | Auth Required | Allowed Roles | Description |
|--------|------------------------|---------------|---------------|--------------------------------------|
| POST | `/signin` | No | | Authenticate user and get JWT token |
| GET | `/protected` | Yes | 1, 2 | Test endpoint for authenticated users |
| GET | `/health` | No | | System health check with metrics |
| | | | | |
| **User routes** | | | | |
| GET | `/users/all` | No* | | Get all users |
| GET | `/users/{id}` | No* | | Get user by ID |
| POST | `/users/` | No* | | Create new user |
| | | | | |
| **Todo routes** | | | | |
| GET | `/todos/all` | No* | | Get all todos |
| POST | `/todos/` | Yes | 1, 2 | Create new todo |
| GET | `/todos/{id}` | No* | | Get todo by ID |
**Key:**
🔒 = Requires JWT in `Authorization: Bearer <token>` header
\* Currently unprotected - recommend adding authentication for production
**Roles:** 1 = User, 2 = Administrator
**Security notes:**
- All POST endpoints expect JSON payloads
- User creation endpoint should be protected in production
- Consider adding rate limiting to authentication endpoints
**Notes:**
- 🔒 = Requires JWT in `Authorization: Bearer <token>` header
- Roles: `1` = Regular User, `2` = Administrator
- *Marked endpoints currently unprotected - recommend adding middleware for production use
- All POST endpoints expect JSON payloads
## 📦 Installation & Usage
```bash
# Clone and setup
-git clone https://github.com/Riktastic/rustapi.git
git clone https://github.com/Riktastic/Axium.git
-cd rustapi && cp .env.example .env
cd Axium && cp .env.example .env
# Database setup
sqlx database create && sqlx migrate run
@ -59,7 +116,7 @@ sqlx database create && sqlx migrate run
cargo run --release
```
-### 🔐 Default Accounts
### 🔐 Default accounts
**Warning:** These accounts should only be used for initial testing. Always change or disable them in production environments.
@ -68,13 +125,15 @@ cargo run --release
| `user@test.com` | `test` | User |
| `admin@test.com` | `test` | Administrator |
-⚠️ **Security Recommendations:**
⚠️ **Security recommendations:**
1. Rotate passwords immediately after initial setup
2. Disable default accounts before deploying to production
3. Implement proper user management endpoints
-## ⚙️ Configuration
### ⚙️ Configuration
Create a .env file in the root of the project or configure the application using environment variables.
```env
# ==============================
# 📌 DATABASE CONFIGURATION
@ -145,24 +204,3 @@ SERVER_COMPRESSION_LEVEL=6
# Argon2 salt for password hashing (must be kept secret!)
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
```
-## 📂 Project Structure
-```
-rustapi/
-├── migrations/ # SQL schema versions
-├── src/
-│ ├── core/ # Config, TLS, server setup
-│ ├── database/ # Query handling
-│ ├── middlewares/ # Auth system
-│ ├── models/ # Data structures
-│ └── routes/ # API endpoints
-└── Dockerfile # Containerization
-```
-## 🛠️ Technology Stack
-| Category | Key Technologies |
-|-----------------------|---------------------------------|
-| Web Framework | Axum 0.8 + Tower |
-| Database | PostgreSQL + SQLx 0.8 |
-| Security | JWT + Argon2 + Rustls |
-| Monitoring | Tracing + Sysinfo |

View File

@ -1,10 +1,11 @@
-- Create the roles table
CREATE TABLE IF NOT EXISTS roles (
-id SERIAL PRIMARY KEY,
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
level INT NOT NULL,
role VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
description VARCHAR(255),
creation_date DATE NOT NULL DEFAULT CURRENT_DATE, -- Default to the current date
CONSTRAINT unique_role UNIQUE (role) -- Add a unique constraint to the 'role' column
);

View File

@ -0,0 +1,22 @@
-- Create the tiers table
CREATE TABLE IF NOT EXISTS tiers (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
level INT NOT NULL,
name VARCHAR(255) NOT NULL,
description VARCHAR(255),
requests_per_day INT NOT NULL,
creation_date DATE NOT NULL DEFAULT CURRENT_DATE, -- Default to the current date
CONSTRAINT unique_name UNIQUE (name) -- Add a unique constraint to the 'name' column
);
INSERT INTO tiers (level, name, description, requests_per_day)
VALUES (1, 'Low', 'Lowest amount of requests.', 1000)
ON CONFLICT (name) DO NOTHING; -- Prevent duplicate insertions if role already exists
INSERT INTO tiers (level, name, description, requests_per_day)
VALUES (2, 'Medium', 'Medium amount of requests.', 5000)
ON CONFLICT (name) DO NOTHING; -- Prevent duplicate insertions if role already exists
INSERT INTO tiers (level, name, description, requests_per_day)
VALUES (3, 'Max', 'Max amount of requests.', 10000)
ON CONFLICT (name) DO NOTHING; -- Prevent duplicate insertions if role already exists

View File

@ -1,21 +1,24 @@
CREATE TABLE users (
-id SERIAL PRIMARY KEY,
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
username VARCHAR(255) NOT NULL UNIQUE,
email VARCHAR(255) NOT NULL UNIQUE,
password_hash VARCHAR(255) NOT NULL,
totp_secret VARCHAR(255),
-role_id INT NOT NULL DEFAULT 1 REFERENCES roles(id), -- Default role_id is set to 1
role_level INT NOT NULL DEFAULT 1, -- Default role_id is set to 1
tier_level INT NOT NULL DEFAULT 1, -- Default role_id is set to 1
creation_date DATE NOT NULL DEFAULT CURRENT_DATE, -- Default to the current date
disabled BOOLEAN NOT NULL DEFAULT FALSE, -- Default to false
CONSTRAINT unique_username UNIQUE (username) -- Ensure that username is unique
);
-- Insert the example 'user' into the users table with a conflict check for username
-INSERT INTO users (username, email, password_hash, role_id)
INSERT INTO users (username, email, password_hash, role_level)
VALUES
('user', 'user@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 1)
-ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists
ON CONFLICT (username) DO NOTHING;
-- Insert the example 'admin' into the users table with a conflict check for username
-INSERT INTO users (username, email, password_hash, role_id)
INSERT INTO users (username, email, password_hash, role_level)
VALUES
('admin', 'admin@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 2)
-ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists
ON CONFLICT (username) DO NOTHING;

View File

@ -0,0 +1,12 @@
CREATE TABLE apikeys (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
key_hash VARCHAR(255) NOT NULL,
user_id UUID NOT NULL REFERENCES users(id),
description VARCHAR(255),
creation_date DATE NOT NULL DEFAULT CURRENT_DATE, -- Default to the current date
expiration_date DATE,
disabled BOOLEAN NOT NULL DEFAULT FALSE, -- Default to false
access_read BOOLEAN NOT NULL DEFAULT TRUE, -- Default to true
access_modify BOOLEAN NOT NULL DEFAULT FALSE, -- Default to false
CONSTRAINT unique_key_hash UNIQUE (key_hash) -- Add a unique constraint to the 'key_hash' column
);
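Only the Argon2 hash of a key is stored in `key_hash`, so rotation amounts to inserting a fresh hash and handing the plaintext to the caller once. A minimal sketch, assuming the `generate_api_key()` and `hash_password()` helpers added to `src/middlewares/auth.rs` in this commit; the `issue_api_key` wrapper itself is hypothetical and not part of the commit:

```rust
// Hypothetical helper: create a key, persist only its hash, return the plaintext once.
async fn issue_api_key(pool: &sqlx::PgPool, user_id: uuid::Uuid) -> Result<String, String> {
    let key = generate_api_key();                                   // random, dash-separated hex
    let key_hash = hash_password(&key).map_err(|e| e.to_string())?; // Argon2 hash for key_hash
    sqlx::query!(
        "INSERT INTO apikeys (key_hash, user_id) VALUES ($1, $2)",
        key_hash,
        user_id
    )
    .execute(pool)
    .await
    .map_err(|e| e.to_string())?;
    Ok(key) // the plaintext key is never written to the database
}
```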

View File

@ -0,0 +1,6 @@
CREATE TABLE usage (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
endpoint VARCHAR(255) NOT NULL,
user_id UUID NOT NULL REFERENCES users(id),
creation_date DATE NOT NULL DEFAULT CURRENT_DATE -- Default to the current date
);

View File

@ -1,6 +1,9 @@
CREATE TABLE todos (
-id SERIAL PRIMARY KEY, -- Auto-incrementing primary key
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), -- Auto-incrementing primary key
task TEXT NOT NULL, -- Task description, cannot be null
description TEXT, -- Optional detailed description
-user_id INT NOT NULL REFERENCES users(id) -- Foreign key to link to users table
user_id UUID NOT NULL REFERENCES users(id), -- Foreign key to link to users table
creation_date DATE NOT NULL DEFAULT CURRENT_DATE, -- Default to the current date
completion_date DATE, -- Date the task was completed
completed BOOLEAN DEFAULT FALSE -- Default to false
);

View File

@ -2,17 +2,20 @@
use axum::Router;
// Middleware layers from tower_http
-use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression
use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression.
-use tower_http::trace::TraceLayer; // For HTTP request/response tracing
use tower_http::trace::TraceLayer; // For HTTP request/response tracing.
// Local crate imports for database connection and configuration
-use crate::database::connect::connect_to_database; // Function to connect to the database
use crate::database::connect::connect_to_database; // Function to connect to the database.
use crate::database::connect::run_database_migrations; // Function to run database migrations.
use crate::config; // Environment configuration helper
/// Function to create and configure the Axum server.
pub async fn create_server() -> Router {
// Establish a connection to the database
-let db = connect_to_database().await.expect("Failed to connect to database.");
let db = connect_to_database().await.expect("❌ Failed to connect to database.");
run_database_migrations(&db).await.expect("❌ Failed to run database migrations.");
// Initialize the routes for the server
let mut app = crate::routes::create_routes(db);

View File

@ -53,13 +53,13 @@ pub fn load_tls_config() -> ServerConfig {
Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)),
_ => None,
})
.expect("Failed to read a valid private key.");
// Build and return the TLS server configuration
ServerConfig::builder()
.with_no_client_auth() // No client authentication
.with_single_cert(cert_chain, key) // Use the provided cert and key
-.expect("Failed to create TLS configuration")
.expect("Failed to create TLS configuration.")
}
// Custom listener that implements axum::serve::Listener
@ -91,11 +91,11 @@ impl Listener for TlsListener {
// Perform TLS handshake
match acceptor.accept(stream).await {
Ok(tls_stream) => {
-tracing::info!("Successful TLS handshake with {}", addr);
tracing::info!("Successful TLS handshake with {}.", addr);
return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address
},
Err(e) => {
-tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e);
tracing::warn!("TLS handshake failed: {} (Client may not trust certificate).", e);
continue; // Retry on error
}
}

View File

@ -0,0 +1,16 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
use crate::models::apikey::*;
pub async fn get_active_apikeys_by_user_id(pool: &PgPool, user_id: Uuid) -> Result<Vec<ApiKeyByUserIDResponse>, sqlx::Error> {
sqlx::query_as!(ApiKeyByUserIDResponse,
r#"
SELECT id, key_hash, expiration_date::DATE
FROM apikeys
WHERE user_id = $1 AND (expiration_date IS NULL OR expiration_date > NOW()::DATE)
"#,
user_id
)
.fetch_all(pool)
.await
}

View File

@ -7,7 +7,7 @@ pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result<User, Str
let user = sqlx::query_as!(
User, // Struct type to map the query result
r#"
-SELECT id, username, email, password_hash, totp_secret, role_id
SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date
FROM users
WHERE email = $1
"#,

View File

@ -0,0 +1,16 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
pub async fn insert_usage(pool: &PgPool, user_id: Uuid, endpoint: String) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"INSERT INTO usage
(endpoint, user_id)
VALUES ($1, $2)"#,
endpoint,
user_id
)
.execute(pool)
.await?;
Ok(())
}

View File

@ -1,3 +1,5 @@
// Module declarations
pub mod connect;
pub mod get_users;
pub mod get_apikeys;
pub mod insert_usage;

View File

@ -1,6 +1,2 @@
// Module declarations
-// pub mod auth;
pub mod validate;
-// Re-exporting modules
-// pub use auth::*;

60
src/handlers/validate.rs Normal file
View File

@ -0,0 +1,60 @@
use chrono::{NaiveDate, Utc};
use validator::ValidationError;
use regex::Regex;
pub fn validate_future_date(date_str: &str) -> Result<(), ValidationError> {
// Attempt to parse the date string into a NaiveDate (date only, no time)
if let Ok(parsed_date) = NaiveDate::parse_from_str(date_str, "%Y-%m-%d") {
// Convert NaiveDate to DateTime<Utc> at the start of the day (00:00:00)
let datetime_utc = match parsed_date.and_hms_opt(0, 0, 0) {
Some(dt) => dt.and_utc(),
None => return Err(ValidationError::new("Invalid time components.")),
};
// Get the current time in UTC
let now_utc = Utc::now();
// Check if the parsed date is in the future
if datetime_utc > now_utc {
Ok(())
} else {
Err(ValidationError::new("The date should be in the future."))
}
} else {
Err(ValidationError::new("Invalid date format. Use YYYY-MM-DD."))
}
}
pub fn validate_username(username: &str) -> Result<(), ValidationError> {
let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
if !re.is_match(username) {
return Err(ValidationError::new("Invalid username format. Only alphanumeric characters, dashes, and underscores are allowed."));
}
Ok(())
}
pub fn validate_password(password: &str) -> Result<(), ValidationError> {
if password.len() < 8 {
return Err(ValidationError::new("Password too short. Minimum length is 8 characters."));
}
let re = Regex::new(r"^[a-zA-Z0-9!@#$%^&*()_+\-=\[\]{};:'\'|,.<>/?]+$").unwrap();
if !re.is_match(password) {
return Err(ValidationError::new("Password contains invalid characters. Only alphanumeric characters and special characters are allowed."));
}
if !password.chars().any(|c| c.is_uppercase()) {
return Err(ValidationError::new("Password must contain an uppercase letter."));
}
if !password.chars().any(|c| c.is_lowercase()) {
return Err(ValidationError::new("Password must contain a lowercase letter."));
}
if !password.chars().any(|c| c.is_numeric()) {
return Err(ValidationError::new("Password must contain a number."));
}
if !password.chars().any(|c| "!@#$%^&*()_+-=[]{};:'\"\\|,.<>/?".contains(c)) {
return Err(ValidationError::new("Password must contain at least one special character."));
}
Ok(())
}
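Both validators return `Result<(), ValidationError>`, so handlers can call them directly before touching the database; with `validator`'s derive macro they can also be attached to struct fields through the `custom` attribute. A small illustrative wrapper (not part of the commit):

```rust
// Hypothetical pre-flight check for a new-user payload.
fn check_new_user(username: &str, password: &str) -> Result<(), validator::ValidationError> {
    validate_username(username)?; // alphanumeric characters, dashes and underscores only
    validate_password(password)?; // length, upper/lower case, digit and special character
    Ok(())
}
```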

View File

@ -10,6 +10,7 @@ mod database;
mod routes;
mod models;
mod middlewares;
mod handlers;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
@ -22,25 +23,27 @@ async fn main() {
dotenvy::dotenv().ok(); // Load environment variables from a .env file
tracing_subscriber::fmt::init(); // Initialize the logging system
-// Print a cool startup message with ASCII art and emojis
// Print a cool startup message with ASCII art and emojis/
println!("{}", r#"
##### ## ##
## ## ## db 88
## ## ## ## ##### ####### ##### #### #### d88b ""
##### ## ## ## ## ## ## ## ## ## d8'`8b
#### ## ## ## ## ## ## ## ## ## d8' `8b 8b, ,d8 88 88 88 88,dPYba,,adPYba,
## ## ## ### ## ## ## ### ## ## ## d8YaaaaY8b `Y8, ,8P' 88 88 88 88P' "88" "8a
## ## ### ## ##### ### ### ## ##### ###### d8""""""""8b )888( 88 88 88 88 88 88
## d8' `8b ,d8" "8b, 88 "8a, ,a88 88 88 88
d8' `8b 8P' `Y8 88 `"YbbdP'Y8 88 88 88
-Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL
-GitHub: https://github.com/Riktastic/rustapi
Axium - An example API built with Rust, Axum, SQLx, and PostgreSQL
- GitHub: https://github.com/Riktastic/Axium
"#);
-println!("🚀 Starting Rustapi...");
println!("🚀 Starting Axium...");
-// Retrieve server IP and port from the environment, default to 0.0.0.0:3000
// Retrieve server IP and port from the environment, default to 127.0.0.1:3000
-let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0")
let ip: IpAddr = config::get_env_with_default("SERVER_IP", "127.0.0.1")
.parse()
.expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1.");
let port: u16 = config::get_env_u16("SERVER_PORT", 3000);

View File

@ -1,3 +1,5 @@
use std::{collections::HashSet, env};
// Standard library imports for working with HTTP, environment variables, and other necessary utilities
use axum::{
body::Body,
@ -11,7 +13,6 @@ use axum::{
use axum::extract::State;
// Importing necessary libraries for password hashing, JWT handling, and date/time management
-use std::env; // For accessing environment variables
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification
Argon2,
@ -19,19 +20,28 @@ use argon2::{
use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.)
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens
-use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data
use serde::Deserialize; // For serializing and deserializing JSON data
use serde_json::json; // For constructing JSON data
use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously
use totp_rs::{Algorithm, Secret, TOTP}; // For generating TOTP secrets and tokens
use rand::rngs::OsRng; // For generating random numbers
use uuid::Uuid; // For working with UUIDs
use rand::Rng;
use tracing::{info, warn, error, instrument}; // For logging
use utoipa::ToSchema; // Import ToSchema for OpenAPI documentation
// Importing custom database query functions
-use crate::database::get_users::get_user_by_email;
use crate::database::{get_users::get_user_by_email, get_apikeys::get_active_apikeys_by_user_id, insert_usage::insert_usage};
// Define the structure for JWT claims to be included in the token payload
-#[derive(Serialize, Deserialize)]
#[derive(serde::Serialize, serde::Deserialize)]
-pub struct Claims {
struct Claims {
-pub exp: usize, // Expiration timestamp (in seconds)
-pub iat: usize, // Issued-at timestamp (in seconds)
-pub email: String, // User's email
sub: String, // Subject (e.g., user ID or email)
iat: usize, // Issued At (timestamp)
exp: usize, // Expiration (timestamp)
iss: String, // Issuer (optional)
aud: String, // Audience (optional)
}
// Custom error type for handling authentication errors
@ -41,23 +51,52 @@ pub struct AuthError {
}
// Function to verify a password against a stored hash using the Argon2 algorithm
-pub fn verify_password(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
#[instrument]
pub fn verify_hash(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
let parsed_hash = PasswordHash::new(hash)?; // Parse the hash
// Verify the password using Argon2
Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
}
// Function to hash a password using Argon2 and a salt retrieved from the environment variables
#[instrument]
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
// Get the salt from environment variables (must be set)
-let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set");
let salt = SaltString::generate(&mut OsRng);
-let salt = SaltString::from_b64(&salt).unwrap(); // Convert base64 string to SaltString
let argon2 = Argon2::default(); // Create an Argon2 instance
// Hash the password with the salt
let password_hash = argon2.hash_password(password.as_bytes(), &salt)?.to_string();
Ok(password_hash)
}
#[instrument]
pub fn generate_totp_secret() -> String {
let totp = TOTP::new(
Algorithm::SHA512,
8,
1,
30,
Secret::generate_secret().to_bytes().unwrap(),
).expect("Failed to create TOTP.");
let token = totp.generate_current().unwrap();
token
}
#[instrument]
pub fn generate_api_key() -> String {
let mut rng = rand::thread_rng();
(0..5)
.map(|_| {
(0..8)
.map(|_| format!("{:02x}", rng.gen::<u8>()))
.collect::<String>()
})
.collect::<Vec<String>>()
.join("-")
}
// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler
impl IntoResponse for AuthError {
fn into_response(self) -> Response<Body> {
@ -69,41 +108,79 @@ impl IntoResponse for AuthError {
}
// Function to encode a JWT token for the given email address
#[instrument]
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
-let jwt_token: String = "randomstring".to_string(); // Secret key for JWT (should be more secure in production)
// Load secret key from environment variable for better security
let secret_key = env::var("JWT_SECRET_KEY")
.map_err(|_| {
error!("JWT_SECRET_KEY not set in environment variables");
StatusCode::INTERNAL_SERVER_ERROR
})?;
-let now = Utc::now(); // Get current time
let now = Utc::now();
-let expire = Duration::hours(24); // Set token expiration to 24 hours
let expire = Duration::hours(24);
-let exp: usize = (now + expire).timestamp() as usize; // Expiration timestamp
let exp: usize = (now + expire).timestamp() as usize;
-let iat: usize = now.timestamp() as usize; // Issued-at timestamp
let iat: usize = now.timestamp() as usize;
-let claim = Claims { iat, exp, email }; // Create JWT claims with timestamps and user email
-let secret = jwt_token.clone(); // Secret key to sign the token
let claim = Claims {
sub: email.clone(),
iat,
exp,
iss: "your_issuer".to_string(), // Add issuer if needed
aud: "your_audience".to_string(), // Add audience if needed
};
-// Encode the claims into a JWT token
// Use a secure HMAC algorithm (e.g., HS256) for signing the token
encode(
-&Header::default(),
&Header::new(jsonwebtoken::Algorithm::HS256),
&claim,
-&EncodingKey::from_secret(secret.as_ref()),
&EncodingKey::from_secret(secret_key.as_ref()),
)
-.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails
.map_err(|e| {
error!("Failed to encode JWT: {:?}", e);
StatusCode::INTERNAL_SERVER_ERROR
})
}
// Function to decode a JWT token and extract the claims
#[instrument]
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
-let secret = "randomstring".to_string(); // Secret key to verify the JWT (should be more secure in production)
// Load secret key from environment variable for better security
let secret_key = env::var("JWT_SECRET_KEY")
.map_err(|_| {
error!("JWT_SECRET_KEY not set in environment variables");
StatusCode::INTERNAL_SERVER_ERROR
})?;
-// Decode the JWT token using the secret key and extract the claims
// Set up validation rules (e.g., check if token has expired, is from a valid issuer, etc.)
-decode(
let mut validation = Validation::default();
// Use a HashSet for the audience and issuer validation
let mut audience_set = HashSet::new();
audience_set.insert("your_audience".to_string());
let mut issuer_set = HashSet::new();
issuer_set.insert("your_issuer".to_string());
// Set up the validation with the HashSet for audience and issuer
validation.aud = Some(audience_set);
validation.iss = Some(issuer_set);
// Decode the JWT and extract the claims
decode::<Claims>(
&jwt,
-&DecodingKey::from_secret(secret.as_ref()),
&DecodingKey::from_secret(secret_key.as_ref()),
-&Validation::default(),
&validation,
)
-.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if decoding fails
.map_err(|e| {
warn!("Failed to decode JWT: {:?}", e);
StatusCode::UNAUTHORIZED
})
}
// Middleware for role-based access control (RBAC)
// Ensures that only users with specific roles are authorized to access certain resources
#[instrument(skip(req, next))]
pub async fn authorize(
mut req: Request<Body>,
next: Next,
@ -141,7 +218,7 @@ pub async fn authorize(
};
// Fetch the user from the database using the email from the decoded token
-let current_user = match get_user_by_email(&pool, token_data.claims.email).await {
let current_user = match get_user_by_email(&pool, token_data.claims.sub).await {
Ok(user) => user,
Err(_) => return Err(AuthError {
message: "Unauthorized user.".to_string(),
@ -150,34 +227,66 @@ };
};
// Check if the user's role is in the list of allowed roles
-if !allowed_roles.contains(&current_user.role_id) {
if !allowed_roles.contains(&current_user.role_level) {
return Err(AuthError {
message: "Forbidden: insufficient role.".to_string(),
status_code: StatusCode::FORBIDDEN,
});
}
// Check rate limit.
check_rate_limit(&pool, current_user.id, current_user.tier_level).await?;
// Insert the usage record into the database
insert_usage(&pool, current_user.id, req.uri().path().to_string()).await
.map_err(|_| AuthError {
message: "Failed to insert usage record.".to_string(),
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?;
// Insert the current user into the request extensions for use in subsequent handlers
req.extensions_mut().insert(current_user);
// Proceed to the next middleware or handler
Ok(next.run(req).await)
}
-// Structure to hold the data from the sign-in request
// Handler for user sign-in (authentication)
-#[derive(Deserialize)]
#[derive(Deserialize, ToSchema)]
pub struct SignInData {
pub email: String,
pub password: String,
pub totp: Option<String>,
}
-// Handler for user sign-in (authentication)
/// User sign-in endpoint
///
/// This endpoint allows users to sign in using their email, password, and optionally a TOTP code.
///
/// # Parameters
/// - `State(pool)`: The shared database connection pool.
/// - `Json(user_data)`: The user sign-in data (email, password, and optional TOTP code).
///
/// # Returns
/// - `Ok(Json(serde_json::Value))`: A JSON response containing the JWT token if sign-in is successful.
/// - `Err((StatusCode, Json(serde_json::Value)))`: An error response if sign-in fails.
#[utoipa::path(
post,
path = "/sign_in",
request_body = SignInData,
responses(
(status = 200, description = "Successful sign-in", body = serde_json::Value),
(status = 400, description = "Bad request", body = serde_json::Value),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = serde_json::Value)
)
)]
#[instrument(skip(pool, user_data))]
pub async fn sign_in(
-State(pool): State<PgPool>, // Database connection pool injected as state
State(pool): State<PgPool>,
-Json(user_data): Json<SignInData>, // Deserialize the JSON body into SignInData
Json(user_data): Json<SignInData>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
-// 1. Retrieve user from the database using the provided email
let user = match get_user_by_email(&pool, user_data.email).await {
Ok(user) => user,
Err(_) => return Err((
@ -186,26 +295,105 @@
)),
)),
};
-// 2. Verify the password using the stored hash
let api_key_hashes = match get_active_apikeys_by_user_id(&pool, user.id).await {
-if !verify_password(&user_data.password, &user.password_hash)
Ok(hashes) => hashes,
Err(_) => return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
)),
};
// Check API key first, then password
let credentials_valid = api_key_hashes.iter().any(|api_key| {
verify_hash(&user_data.password, &api_key.key_hash).unwrap_or(false)
}) || verify_hash(&user_data.password, &user.password_hash)
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
-))?
))?;
-{
if !credentials_valid {
return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Incorrect credentials." }))
));
}
-// 3. Generate a JWT token for the authenticated user
// Check TOTP if it's set up for the user
if let Some(totp_secret) = user.totp_secret {
match user_data.totp {
Some(totp_code) => {
let totp = TOTP::new(
Algorithm::SHA512,
8,
1,
30,
totp_secret.into_bytes(),
).map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
if !totp.check_current(&totp_code).unwrap_or(false) {
return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Invalid 2FA code." }))
));
}
},
None => return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "2FA code required for this account." }))
)),
}
}
let email = user.email.clone();
let token = encode_jwt(user.email)
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
-// 4. Return the JWT token to the client
info!("User signed in: {}", email);
Ok(Json(json!({ "token": token })))
}
#[instrument(skip(pool))]
async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Result<(), AuthError> {
// Get the user's tier requests_per_day
let tier_limit = sqlx::query!(
"SELECT requests_per_day FROM tiers WHERE level = $1",
tier_level
)
.fetch_one(pool)
.await
.map_err(|_| AuthError {
message: "Failed to fetch tier information".to_string(),
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?
.requests_per_day;
// Count user's requests for today
let request_count = sqlx::query!(
"SELECT COUNT(*) as count FROM usage WHERE user_id = $1 AND creation_date > NOW() - INTERVAL '24 hours'",
user_id
)
.fetch_one(pool)
.await
.map_err(|_| AuthError {
message: "Failed to count user requests".to_string(),
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?
.count
.unwrap_or(0); // Use 0 if count is NULL
if request_count >= tier_limit as i64 {
return Err(AuthError {
message: "Rate limit exceeded".to_string(),
status_code: StatusCode::TOO_MANY_REQUESTS,
});
}
Ok(())
}
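Because the signing secret now comes from the `JWT_SECRET_KEY` environment variable and the claims carry a fixed issuer and audience, the encode/decode pair can be sanity-checked from inside this module. A hedged sketch (test name and secret value are made up; it assumes the crate's 2021 edition, where `std::env::set_var` is a safe call):

```rust
#[cfg(test)]
mod jwt_roundtrip_tests {
    use super::*;

    #[test]
    fn encode_then_decode_returns_the_same_subject() {
        // Test-only secret; in production JWT_SECRET_KEY is read from the environment.
        std::env::set_var("JWT_SECRET_KEY", "test-secret-do-not-use-in-production");
        let token = encode_jwt("user@test.com".to_string()).expect("token should encode");
        let data = decode_jwt(token).expect("token should decode");
        assert_eq!(data.claims.sub, "user@test.com");
    }
}
```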

68
src/models/apikey.rs Normal file
View File

@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use utoipa::ToSchema;
/// Represents an API key in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct ApiKey {
/// The unique id of the API key.
pub id: Uuid,
/// The hashed value of the API key.
pub key_hash: String,
/// The id of the user who owns the API key.
pub user_id: Uuid,
/// The description/name of the API key.
pub description: Option<String>,
/// The expiration date of the API key.
pub expiration_date: Option<NaiveDate>,
/// The creation date of the API key (default is the current date).
pub creation_date: NaiveDate,
/// Whether the API key is disabled (default is false).
pub disabled: bool,
/// Whether the API key has read access (default is true).
pub access_read: bool,
/// Whether the API key has modify access (default is false).
pub access_modify: bool,
}
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyNewBody {
pub description: Option<String>,
pub expiration_date: Option<NaiveDate>
}
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyResponse {
pub id: Uuid,
pub user_id: Uuid,
pub description: Option<String>,
pub expiration_date: Option<NaiveDate>,
pub creation_date: NaiveDate,
}
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyByIDResponse {
pub id: Uuid,
pub description: Option<String>,
pub expiration_date: Option<NaiveDate>,
pub creation_date: NaiveDate,
}
#[derive(sqlx::FromRow, ToSchema)]
pub struct ApiKeyByUserIDResponse {
pub id: Uuid,
pub key_hash: String,
pub expiration_date: Option<NaiveDate>,
}

View File

@ -0,0 +1,11 @@
use utoipa::ToSchema;
#[derive(ToSchema)]
pub struct SuccessResponse {
message: String,
}
#[derive(ToSchema)]
pub struct ErrorResponse {
error: String,
}

View File

@ -1,5 +1,10 @@
-/// Module for to-do related models.
-pub mod todo;
/// Module for user related models.
pub mod user;
/// Module for API key related models.
pub mod apikey;
/// Module for user role related models.
pub mod role;
/// Module for to-do related models.
pub mod todo;
/// Module for documentation related models.
pub mod documentation;

View File

@ -1,18 +1,28 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use utoipa::ToSchema;
/// Represents a user role in the system.
-#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct Role {
/// ID of the role.
-pub id: i32,
pub id: Uuid,
/// Level of the role.
pub level: i32,
/// System name of the role.
pub role: String,
/// The name of the role.
pub name: String,
-/// Description of the role
-pub Description: String,
/// Description of the role.
pub description: Option<String>,
/// Date when the role was created.
pub creation_date: Option<NaiveDate>,
} }

View File

@ -1,16 +1,31 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use utoipa::ToSchema;
/// Represents a to-do item.
-#[derive(Deserialize, Debug, Serialize, FromRow)]
#[derive(Deserialize, Debug, Serialize, FromRow, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct Todo {
/// The unique identifier for the to-do item.
-pub id: i32,
pub id: Uuid,
/// The task description.
pub task: String,
/// An optional detailed description of the task.
pub description: Option<String>,
/// The unique identifier of the user who created the to-do item.
-pub user_id: i32,
pub user_id: Uuid,
-}
/// The date the task was created.
pub creation_date: NaiveDate,
/// The date the task was completed (if any).
pub completion_date: Option<NaiveDate>,
/// Whether the task is completed.
pub completed: Option<bool>,
}

View File

@ -1,20 +1,59 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use utoipa::ToSchema;
/// Represents a user in the system.
-#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct User {
/// The unique identifier for the user.
-pub id: i32,
pub id: Uuid,
/// The username of the user.
pub username: String,
/// The email of the user.
pub email: String,
/// The hashed password for the user.
pub password_hash: String,
/// The TOTP secret for the user.
pub totp_secret: Option<String>,
-/// Current role of the user..
-pub role_id: i32,
-}
/// Current role of the user.
pub role_level: i32,
/// Current tier level of the user.
pub tier_level: i32,
/// Date when the user was created.
pub creation_date: Option<NaiveDate>, // Nullable, default value in SQL is CURRENT_DATE
}
/// Represents a user in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct UserResponse {
/// The unique identifier for the user.
pub id: Uuid,
/// The username of the user.
pub username: String,
/// The email of the user.
pub email: String,
/// Current role of the user.
pub role_level: i32,
/// Current tier level of the user.
pub tier_level: i32,
/// Date when the user was created.
pub creation_date: Option<NaiveDate>, // Nullable, default value in SQL is CURRENT_DATE
}

View File

@ -0,0 +1,57 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::user::User;
// Delete a API key by id
#[utoipa::path(
delete,
path = "/apikeys/{id}",
tag = "apikey",
params(
("id" = String, Path, description = "API key ID")
),
responses(
(status = 200, description = "API key deleted successfully", body = String),
(status = 400, description = "Invalid UUID format", body = String),
(status = 404, description = "API key not found", body = String),
(status = 500, description = "Internal server error", body = String)
)
)]
#[instrument(skip(pool))]
pub async fn delete_apikey_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
// Parse the id string to UUID
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({ "error": format!("Invalid UUID format.")}))),
};
let result = sqlx::query!("DELETE FROM apikeys WHERE id = $1 AND user_id = $2", uuid, user.id)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if res.rows_affected() == 0 {
(StatusCode::NOT_FOUND, Json(json!({ "error": format!("API key with ID '{}' not found.", id) })))
} else {
(StatusCode::OK, Json(json!({ "success": format!("API key with ID '{}' deleted.", id)})))
}
}
Err(_err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Could not delete API key '{}'.", id)}))
),
}
}

View File

@ -0,0 +1,58 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::user::User;
use crate::models::documentation::{ErrorResponse, SuccessResponse};
// Delete a todo by id
#[utoipa::path(
delete,
path = "/todos/{id}",
tag = "todo",
responses(
(status = 200, description = "Todo deleted successfully", body = SuccessResponse),
(status = 400, description = "Invalid UUID format", body = ErrorResponse),
(status = 404, description = "Todo not found", body = ErrorResponse),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
params(
("id" = Uuid, Path, description = "Todo ID"),
("user_id" = Uuid, Path, description = "User ID")
)
)]
#[instrument(skip(pool))]
pub async fn delete_todo_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })),)),
};
let result = sqlx::query!("DELETE FROM todos WHERE id = $1 AND user_id = $2", uuid, user.id)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if res.rows_affected() == 0 {
Err((StatusCode::NOT_FOUND, Json(json!({ "error": format!("Todo with ID '{}' not found.", id) })),))
} else {
Ok((StatusCode::OK, Json(json!({ "success": format!("Todo with ID '{}' deleted.", id) })),))
}
}
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not delete the todo." })),
)),
}
}

View File

@ -0,0 +1,56 @@
use axum::{
extract::{State, Path},
Json,
response::IntoResponse,
http::StatusCode
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::documentation::{ErrorResponse, SuccessResponse};
// Delete a user by id
#[utoipa::path(
delete,
path = "/users/{id}",
tag = "user",
responses(
(status = 200, description = "User deleted successfully", body = SuccessResponse),
(status = 400, description = "Invalid UUID format", body = ErrorResponse),
(status = 404, description = "User not found", body = ErrorResponse),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
params(
("id" = Uuid, Path, description = "User ID")
)
)]
#[instrument(skip(pool))]
pub async fn delete_user_by_id(
State(pool): State<PgPool>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })),)),
};
let result = sqlx::query!("DELETE FROM users WHERE id = $1", uuid)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if res.rows_affected() == 0 {
Err((StatusCode::NOT_FOUND, Json(json!({ "error": format!("User with ID '{}' not found.", id) })),))
} else {
Ok((StatusCode::OK, Json(json!({ "success": format!("User with ID '{}' deleted.", id) })),))
}
}
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not delete the user."})),
)),
}
}

src/routes/get_apikeys.rs Normal file

@ -0,0 +1,96 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::apikey::*;
use crate::models::user::*;
use crate::models::documentation::ErrorResponse;
use crate::models::apikey::ApiKeyResponse;
// Get all API keys
#[utoipa::path(
get,
path = "/apikeys",
tag = "apikey",
responses(
(status = 200, description = "Get all API keys", body = [ApiKeyResponse]),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
)
)]
#[instrument(skip(pool))]
pub async fn get_all_apikeys(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let apikeys = sqlx::query_as!(ApiKeyResponse,
"SELECT id, user_id, description, expiration_date, creation_date FROM apikeys WHERE user_id = $1",
user.id
)
.fetch_all(&pool) // Borrow the connection pool
.await;
match apikeys {
Ok(apikeys) => Ok(Json(apikeys)), // Return all API keys as JSON
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not get the API key."})),
)),
}
}
// Get a single API key by id
#[utoipa::path(
get,
path = "/apikeys/{id}",
tag = "apikey",
responses(
(status = 200, description = "Get API key by ID", body = ApiKeyByIDResponse),
(status = 400, description = "Invalid UUID format", body = ErrorResponse),
(status = 404, description = "API key not found", body = ErrorResponse),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
params(
("id" = Uuid, Path, description = "API key ID"),
("user_id" = Uuid, Path, description = "User ID")
)
)]
#[instrument(skip(pool))]
pub async fn get_apikeys_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
};
let apikeys = sqlx::query_as!(ApiKeyByIDResponse,
"SELECT id, description, expiration_date, creation_date FROM apikeys WHERE id = $1 AND user_id = $2",
uuid,
user.id
)
.fetch_optional(&pool) // Borrow the connection pool
.await;
match apikeys {
Ok(Some(apikeys)) => Ok(Json(apikeys)), // Return the API key as JSON if found
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("API key with ID '{}' not found.", id) })),
)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not get the API key."})),
)),
}
}
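For orientation, the two read endpoints above serialize to shapes like the following; every value is an illustrative placeholder, only the keys follow the SELECT lists:

use serde_json::json;

// Illustrative payloads; the UUIDs and dates are placeholders.
fn example_apikey_payloads() -> (serde_json::Value, serde_json::Value) {
    // GET /apikeys: the list view includes user_id.
    let list_item = json!({
        "id": "8c9e6679-7425-40de-944b-e07fc1f90ae7",
        "user_id": "16fd2706-8baf-433b-82eb-8c7fada847da",
        "description": "CI deploy key",
        "expiration_date": "2027-02-15",
        "creation_date": "2025-02-15"
    });
    // GET /apikeys/{id}: the by-ID view omits user_id.
    let by_id = json!({
        "id": "8c9e6679-7425-40de-944b-e07fc1f90ae7",
        "description": "CI deploy key",
        "expiration_date": "2027-02-15",
        "creation_date": "2025-02-15"
    });
    (list_item, by_id)
}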


@ -1,10 +1,63 @@
use axum::{
response::IntoResponse,
Json,
extract::State
};
use serde_json::json;
use sqlx::PgPool;
use sysinfo::{System, RefreshKind, Disks};
use tokio::{task, join};
use std::sync::{Arc, Mutex};
use tracing::instrument; // For logging
use utoipa::ToSchema;
use serde::{Deserialize, Serialize};
// Struct definitions
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct HealthResponse {
pub cpu_usage: CpuUsage,
pub database: DatabaseStatus,
pub disk_usage: DiskUsage,
pub memory: MemoryStatus,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CpuUsage {
#[serde(rename = "available_percentage")]
pub available_pct: String,
pub status: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DatabaseStatus {
pub status: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DiskUsage {
pub status: String,
#[serde(rename = "used_percentage")]
pub used_pct: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MemoryStatus {
#[serde(rename = "available_mb")]
pub available_mb: i64,
pub status: String,
}
// Health check endpoint
#[utoipa::path(
get,
path = "/health",
tag = "health",
responses(
(status = 200, description = "Successfully fetched health status", body = HealthResponse),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(database_connection))]
pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
// Use Arc and Mutex to allow sharing System between tasks
let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything())));
@ -51,7 +104,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["cpu_usage"] = json!({ "status": "error", "message": "Failed to retrieve CPU usage" });
status = "degraded";
}
@ -62,7 +115,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["memory"] = json!({ "status": "error", "message": "Failed to retrieve memory information" });
status = "degraded";
}
@ -73,7 +126,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["disk_usage"] = json!({ "status": "error", "message": "Failed to retrieve disk usage" });
status = "degraded";
}
@ -84,7 +137,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["important_processes"] = json!({ "status": "error", "message": "Failed to retrieve process information" });
status = "degraded";
}
@ -95,7 +148,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["database"] = json!({ "status": "error", "message": "Failed to retrieve database status" });
status = "degraded";
}
@ -106,7 +159,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
status = "degraded";
}
} else {
details["network"] = json!({ "status": "error", "message": "Failed to retrieve network status" });
status = "degraded";
}
@ -118,6 +171,7 @@ pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
// Helper functions
#[instrument]
fn check_cpu_usage(system: &mut System) -> Result<serde_json::Value, ()> {
system.refresh_cpu_usage();
let usage = system.global_cpu_usage();
@ -129,6 +183,7 @@ fn check_cpu_usage(system: &mut System) -> Result<serde_json::Value, ()> {
}))
}
#[instrument]
fn check_memory(system: &mut System) -> Result<serde_json::Value, ()> {
system.refresh_memory();
let available = system.available_memory() / 1024 / 1024; // Convert to MB
@ -138,6 +193,7 @@ fn check_memory(system: &mut System) -> Result<serde_json::Value, ()> {
}))
}
#[instrument]
fn check_disk_usage() -> Result<serde_json::Value, ()> {
// Create a new Disks object and refresh the disk information
let mut disks = Disks::new();
@ -163,6 +219,7 @@ fn check_disk_usage() -> Result<serde_json::Value, ()> {
}))
}
#[instrument]
fn check_processes(system: &mut System, processes: &[&str]) -> Result<Vec<serde_json::Value>, ()> {
system.refresh_processes(sysinfo::ProcessesToUpdate::All, true);
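Tying the pieces together: with the ToSchema structs declared at the top of this file, a healthy response would serialize roughly as below. This is an illustrative sketch; the percentage strings and megabyte figure are made-up values, only the keys follow the struct definitions and their serde(rename) attributes.

use serde_json::json;

// Illustrative HealthResponse-shaped payload; the numbers are placeholders.
fn example_health_payload() -> serde_json::Value {
    json!({
        "cpu_usage": { "available_percentage": "85%", "status": "ok" },
        "database": { "status": "ok" },
        "disk_usage": { "status": "ok", "used_percentage": "42%" },
        "memory": { "available_mb": 2048, "status": "ok" }
    })
}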


@ -1,43 +1,90 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::todo::*;
use crate::models::user::*;
// Get all todos
#[utoipa::path(
get,
path = "/todos/all",
tag = "todo",
responses(
(status = 200, description = "Successfully fetched all todos", body = [Todo]),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_all_todos(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let todos = sqlx::query_as!(Todo,
"SELECT id, user_id, task, description, creation_date, completion_date, completed FROM todos WHERE user_id = $1",
user.id
)
.fetch_all(&pool) // Borrow the connection pool
.await;
match todos {
Ok(todos) => Ok(Json(todos)), // Return all todos as JSON
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the todos." })),
)),
}
}
// Get a single todo by id
#[utoipa::path(
get,
path = "/todos/{id}",
tag = "todo",
params(
("id" = String, Path, description = "Todo ID")
),
responses(
(status = 200, description = "Successfully fetched todo by ID", body = Todo),
(status = 400, description = "Invalid UUID format"),
(status = 404, description = "Todo not found"),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_todos_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
};
let todo = sqlx::query_as!(Todo,
"SELECT id, user_id, task, description, creation_date, completion_date, completed FROM todos WHERE id = $1 AND user_id = $2",
uuid,
user.id
)
.fetch_optional(&pool) // Borrow the connection pool
.await;
match todo {
Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("Todo with ID '{}' not found.", id) })),
)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the details of the todo." })),
)),
}
}
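Both handlers scope their queries to the authenticated user (user_id = $1 / $2), so a caller only ever sees their own rows. For orientation, a serialized Todo looks roughly like this; the values are placeholders, only the keys follow the SELECT list above:

use serde_json::json;

// Illustrative Todo-shaped payload; the values are placeholders.
fn example_todo_payload() -> serde_json::Value {
    json!({
        "id": "3f2504e0-4f89-41d3-9a0c-0305e82c3301",
        "user_id": "16fd2706-8baf-433b-82eb-8c7fada847da",
        "task": "Write the README",
        "description": "Describe the new endpoints.",
        "creation_date": "2025-02-15T11:44:40Z",
        "completion_date": null,
        "completed": false
    })
}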

src/routes/get_usage.rs Normal file

@ -0,0 +1,84 @@
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::Serialize;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use utoipa::ToSchema;
use crate::models::user::*;
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastDay {
#[serde(rename = "requests_last_24_hours")]
pub count: i64 // or usize depending on your count type
}
// Get usage for the last 24 hours
#[utoipa::path(
get,
path = "/usage/lastday",
tag = "usage",
responses(
(status = 200, description = "Successfully fetched usage for the last 24 hours", body = UsageResponseLastDay),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_day(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let result = sqlx::query!("SELECT count(*) FROM usage WHERE user_id = $1 AND creation_date > NOW() - INTERVAL '24 hours';", user.id)
.fetch_one(&pool) // Borrow the connection pool
.await;
match result {
Ok(row) => {
let count = row.count.unwrap_or(0) as i64;
Ok(Json(json!({ "requests_last_24_hours": count })))
},
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the usage data." })),
)),
}
}
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastWeek {
#[serde(rename = "requests_last_7_days")]
pub count: i64 // or usize depending on your count type
}
// Get usage for the last 7 days
#[utoipa::path(
get,
path = "/usage/lastweek",
tag = "usage",
responses(
(status = 200, description = "Successfully fetched usage for the last 7 days", body = UsageResponseLastDay),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_week(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let result = sqlx::query!("SELECT count(*) FROM usage WHERE user_id = $1 AND creation_date > NOW() - INTERVAL '7 days';", user.id)
.fetch_one(&pool) // Borrow the connection pool
.await;
match result {
Ok(row) => {
let count = row.count.unwrap_or(0) as i64;
Ok(Json(json!({ "requests_last_7_days": count })))
},
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the usage data." })),
)),
}
}


@ -1,87 +1,121 @@
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::Json;
use axum::response::IntoResponse;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use crate::models::user::*;
// Get all users
#[utoipa::path(
get,
path = "/users/all",
tag = "user",
responses(
(status = 200, description = "Successfully fetched all users", body = [UserResponse]),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_all_users(State(pool): State<PgPool>) -> impl IntoResponse {
let users = sqlx::query_as!(UserResponse, "SELECT id, username, email, role_level, tier_level, creation_date FROM users")
.fetch_all(&pool)
.await;
match users {
Ok(users) => Ok(Json(users)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the users' details." })),
)),
}
}
// Get a single user by ID
#[utoipa::path(
get,
path = "/users/{id}",
tag = "user",
params(
("id" = String, Path, description = "User ID")
),
responses(
(status = 200, description = "Successfully fetched user by ID", body = UserResponse),
(status = 400, description = "Invalid UUID format"),
(status = 404, description = "User not found"),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_users_by_id(
State(pool): State<PgPool>,
Path(id): Path<String>,
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
};
let user = sqlx::query_as!(UserResponse, "SELECT id, username, email, role_level, tier_level, creation_date FROM users WHERE id = $1", uuid)
.fetch_optional(&pool)
.await;
match user {
Ok(Some(user)) => Ok(Json(user)),
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("User with ID '{}' not found", id) })),
)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the user's details." })),
)),
}
}
// Get a single user by username
// pub async fn get_user_by_username(
//     State(pool): State<PgPool>,
//     Path(username): Path<String>,
// ) -> impl IntoResponse {
//     let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date FROM users WHERE username = $1", username)
//         .fetch_optional(&pool)
//         .await;
//     match user {
//         Ok(Some(user)) => Ok(Json(user)),
//         Ok(None) => Err((
//             StatusCode::NOT_FOUND,
//             Json(json!({ "error": format!("User with username '{}' not found", username) })),
//         )),
//         Err(err) => Err((
//             StatusCode::INTERNAL_SERVER_ERROR,
//             Json(json!({ "error": "Could not fetch the user's details." })),
//         )),
//     }
// }
// Get a single user by email
// pub async fn get_user_by_email(
//     State(pool): State<PgPool>,
//     Path(email): Path<String>,
// ) -> impl IntoResponse {
//     let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date FROM users WHERE email = $1", email)
//         .fetch_optional(&pool)
//         .await;
//     match user {
//         Ok(Some(user)) => Ok(Json(user)),
//         Ok(None) => Err((
//             StatusCode::NOT_FOUND,
//             Json(json!({ "error": format!("User with email '{}' not found", email) })),
//         )),
//         Err(err) => Err((
//             StatusCode::INTERNAL_SERVER_ERROR,
//             Json(json!({ "error": "Could not fetch the user's details." })),
//         )),
//     }
// }

src/routes/homepage.rs Normal file

@ -0,0 +1,80 @@
use axum::response::{IntoResponse, Html};
use tracing::instrument; // For logging
#[instrument]
pub async fn homepage() -> impl IntoResponse {
Html(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Welcome to Axium!</title>
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #1e1e2e;
color: #ffffff;
text-align: center;
padding: 40px;
}
a {
color: #00bcd4;
text-decoration: none;
font-weight: bold;
}
a:hover {
text-decoration: underline;
}
.container {
max-width: 800px;
margin: auto;
padding: 20px;
background: #282a36;
border-radius: 8px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
text-align: center;
}
h1 {
font-size: 1.2em;
white-space: pre;
font-family: monospace;
}
.motto {
margin-top: 10px;
font-size: 1em;
font-style: italic;
}
ul {
list-style-type: none;
padding: 0;
text-align: left;
display: inline-block;
}
li {
margin: 10px 0;
}
</style>
</head>
<body>
<div class="container">
<h1>
db 88
d88b ""
d8'`8b
d8' `8b 8b, ,d8 88 88 88 88,dPYba,,adPYba,
d8YaaaaY8b `Y8, ,8P' 88 88 88 88P' "88" "8a
d8""""""""8b )888( 88 88 88 88 88 88
d8' `8b ,d8" "8b, 88 "8a, ,a88 88 88 88
d8' `8b 8P' `Y8 88 `"YbbdP'Y8 88 88 88
</h1>
<p class="motto">An example API built with Rust, Axum, SQLx, and PostgreSQL</p>
<ul>
<li>🚀 Check out all endpoints by visiting <a href="/swagger-ui">Swagger</a>, or import the <a href="/openapi.json">OpenAPI</a> file.</li>
<li> Do not forget to update your Docker Compose configuration with a health check. Just point it to: <a href="/health">/health</a></li>
</ul>
</div>
</body>
</html>
"#)
}


@ -1,28 +1,101 @@
// Module declarations for different route handlers
pub mod homepage;
pub mod get_todos;
pub mod get_users;
pub mod get_apikeys;
pub mod get_usage;
pub mod post_todos;
pub mod post_users;
pub mod post_apikeys;
pub mod rotate_apikeys;
pub mod get_health;
pub mod delete_users;
pub mod delete_todos;
pub mod delete_apikeys;
pub mod protected;
// Re-exporting modules to make their contents available at this level
pub use homepage::*;
pub use get_todos::*;
pub use get_users::*;
pub use get_apikeys::*;
pub use get_usage::*;
pub use rotate_apikeys::*;
pub use post_todos::*;
pub use post_users::*;
pub use post_apikeys::*;
pub use get_health::*;
pub use delete_users::*;
pub use delete_todos::*;
pub use delete_apikeys::*;
pub use protected::*;
use axum::{
Router,
routing::{get, post, delete}
};
use sqlx::PgPool;
use tower_http::trace::TraceLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use crate::middlewares::auth::{sign_in, authorize};
// Define the OpenAPI documentation structure
#[derive(OpenApi)]
#[openapi(
info(
title = "Axium",
description = "An example API built with Rust, Axum, SQLx, and PostgreSQL.",
version = "1.0.0",
contact(
url = "https://github.com/Riktastic/Axium"
),
license(
name = "MIT",
url = "https://opensource.org/licenses/MIT"
)
),
paths(
get_all_users,
get_users_by_id,
get_all_apikeys,
get_apikeys_by_id,
get_usage_last_day,
get_usage_last_week,
get_all_todos,
get_todos_by_id,
get_health,
post_user,
post_apikey,
post_todo,
rotate_apikey,
delete_user_by_id,
delete_apikey_by_id,
delete_todo_by_id,
protected,
//sign_in, // Add sign_in path
),
components(
schemas(
UserResponse,
// ApiKeyResponse,
// ApiKeyByIDResponse,
// Todo,
// SignInData,
// ...add other schemas as needed...
)
),
tags(
(name = "user", description = "User related endpoints."),
(name = "apikey", description = "API key related endpoints."),
(name = "usage", description = "Usage related endpoints."),
(name = "todo", description = "Todo related endpoints."),
(name = "health", description = "Health check endpoint."),
)
)]
struct ApiDoc;
/// Function to create and configure all routes
pub fn create_routes(database_connection: PgPool) -> Router {
// Authentication routes
@ -35,25 +108,84 @@ pub fn create_routes(database_connection: PgPool) -> Router {
// User-related routes
let user_routes = Router::new()
.route("/all", get(get_all_users).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_user).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_users_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_user_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})));
// API key-related routes
let apikey_routes = Router::new()
.route("/all", get(get_all_apikeys).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_apikey).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_apikeys_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_apikey_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/rotate/{id}", post(rotate_apikey).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})));
// Usage related routes
let usage_routes = Router::new()
.route("/lastday", get(get_usage_last_day).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/lastweek", get(get_usage_last_week).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)
})));
// Todo-related routes
let todo_routes = Router::new()
.route("/all", get(get_all_todos).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_todos_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_todo_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})));
// Documentation:
// Create OpenAPI specification
let openapi = ApiDoc::openapi();
// Create Swagger UI
let swagger_ui = SwaggerUi::new("/swagger-ui")
.url("/openapi.json", openapi.clone());
// Combine all routes and add middleware
Router::new()
.route("/", get(homepage))
.merge(auth_routes) // Add authentication routes
.merge(swagger_ui)
.nest("/users", user_routes) // Add user routes under /users
.nest("/apikeys", apikey_routes) // Add API key routes under /apikeys
.nest("/usage", usage_routes) // Add usage routes under /usage
.nest("/todos", todo_routes) // Add todo routes under /todos
.route("/health", get(get_health)) // Add health check route
.layer(axum::Extension(database_connection.clone())) // Add database connection to all routes
.with_state(database_connection) // Add database connection as state
.layer(TraceLayer::new_for_http()) // Add tracing middleware
}
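For completeness, a minimal sketch of serving create_routes (assuming an axum 0.7+/tokio setup and the sqlx Postgres pool used throughout; the connection string and bind address are placeholders, and create_routes is assumed to be imported from this routes module):

use sqlx::postgres::PgPoolOptions;

// Hypothetical entry point, not part of this commit.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder connection string.
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect("postgres://user:password@localhost/axium")
        .await?;

    let app = create_routes(pool);

    // Placeholder bind address.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    axum::serve(listener, app).await?;
    Ok(())
}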

src/routes/post_apikeys.rs Normal file

@ -0,0 +1,133 @@
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::{error, info};
use utoipa::ToSchema;
use uuid::Uuid;
use validator::Validate;
use crate::handlers::validate::validate_future_date;
use crate::middlewares::auth::{generate_api_key, hash_password};
use crate::models::user::User;
// Define the request body structure
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyBody {
#[validate(length(min = 0, max = 50))]
pub description: Option<String>,
#[validate(custom(function = "validate_future_date"))]
pub expiration_date: Option<String>,
}
// Define the response body structure
#[derive(Serialize, ToSchema)]
pub struct ApiKeyResponse {
pub id: Uuid,
pub api_key: String,
pub description: String,
pub expiration_date: String,
}
// Define the API endpoint
#[utoipa::path(
post,
path = "/apikeys",
tag = "apikey",
request_body = ApiKeyBody,
responses(
(status = 200, description = "API key created successfully", body = ApiKeyResponse),
(status = 400, description = "Validation error", body = String),
(status = 500, description = "Internal server error", body = String)
)
)]
pub async fn post_apikey(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Json(api_key_request): Json<ApiKeyBody>
) -> impl IntoResponse {
// Validate input
if let Err(errors) = api_key_request.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
info!("Received request to create API key for user: {}", user.id);
// Check if the user already has 5 or more API keys
let existing_keys_count = sqlx::query!(
"SELECT COUNT(*) as count FROM apikeys WHERE user_id = $1 AND expiration_date >= CURRENT_DATE",
user.id
)
.fetch_one(&pool)
.await;
match existing_keys_count {
Ok(row) if row.count.unwrap_or(0) >= 5 => {
info!("User {} already has 5 API keys.", user.id);
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "You already have 5 API keys. Please delete an existing key before creating a new one." }))
));
}
Err(_err) => {
error!("Failed to check the amount of API keys for user {}.", user.id);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not check the amount of API keys registered." }))
));
}
_ => {} // Proceed if the user has fewer than 5 keys
}
let current_date = Utc::now().naive_utc();
let description = api_key_request.description
.unwrap_or_else(|| format!("API key created on {}", current_date.format("%Y-%m-%d")));
let expiration_date = api_key_request.expiration_date
.and_then(|date| date.parse::<chrono::NaiveDate>().ok())
.unwrap_or_else(|| (current_date + Duration::days(365 * 2)).date());
let api_key = generate_api_key();
let key_hash = hash_password(&api_key).expect("Failed to hash password.");
let row = sqlx::query!(
"INSERT INTO apikeys (key_hash, description, expiration_date, user_id) VALUES ($1, $2, $3, $4) RETURNING id, key_hash, description, expiration_date, user_id",
key_hash,
description,
expiration_date,
user.id
)
.fetch_one(&pool)
.await;
match row {
Ok(row) => {
info!("Successfully created API key for user: {}", user.id);
Ok(Json(ApiKeyResponse {
id: row.id,
api_key: api_key,
description: description.to_string(),
expiration_date: expiration_date.to_string()
}))
},
Err(err) => {
error!("Error creating API key for user {}: {}", user.id, err);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Error creating API key: {}.", err) }))
))
},
}
}
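For reference, a request/response pair for this endpoint would look roughly as follows. All values are illustrative; the api_key string is a placeholder, since the real key comes from generate_api_key and is only shown once:

use serde_json::json;

// Illustrative POST /apikeys bodies; every value is a placeholder.
fn example_post_apikey_bodies() -> (serde_json::Value, serde_json::Value) {
    let request = json!({
        "description": "CI deploy key",
        "expiration_date": "2026-01-01"
    });
    let response = json!({
        "id": "8c9e6679-7425-40de-944b-e07fc1f90ae7",
        "api_key": "placeholder-secret-shown-once",
        "description": "CI deploy key",
        "expiration_date": "2026-01-01"
    });
    (request, response)
}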


@ -1,39 +1,62 @@
use axum::{extract::{Extension, State}, Json, response::IntoResponse};
use axum::http::StatusCode;
use serde::Deserialize;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use utoipa::ToSchema;
use validator::Validate;
use crate::models::todo::Todo;
use crate::models::user::User;
// Define the request body structure
#[derive(Deserialize, Validate, ToSchema)]
pub struct TodoBody {
#[validate(length(min = 3, max = 50))]
pub task: String,
#[validate(length(min = 3, max = 100))]
pub description: Option<String>,
}
// Define the API endpoint
#[utoipa::path(
post,
path = "/todos",
tag = "todo",
request_body = TodoBody,
responses(
(status = 200, description = "Todo created successfully", body = Todo),
(status = 400, description = "Validation error", body = String),
(status = 500, description = "Internal server error", body = String)
)
)]
#[instrument(skip(pool, user, todo))]
pub async fn post_todo(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Json(todo): Json<TodoBody>
) -> impl IntoResponse {
// Validate input
if let Err(errors) = todo.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
// Insert the todo into the database
let row = sqlx::query!(
"INSERT INTO todos (task, description, user_id)
VALUES ($1, $2, $3)
RETURNING id, task, description, user_id, creation_date, completion_date, completed",
todo.task,
todo.description,
user.id
)
.fetch_one(&pool)
.await;
@ -44,10 +67,13 @@ pub async fn post_todo(
task: row.task,
description: row.description,
user_id: row.user_id,
creation_date: row.creation_date,
completion_date: row.completion_date,
completed: row.completed,
})),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not create a new todo." }))
)),
}
}


@ -1,44 +1,98 @@
use axum::{extract::State, Json, response::IntoResponse};
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use utoipa::ToSchema;
use validator::Validate;
use crate::handlers::validate::{validate_password, validate_username};
use crate::middlewares::auth::{hash_password, generate_totp_secret};
// Define the request body structure
#[derive(Deserialize, Validate, ToSchema)]
pub struct UserBody {
#[validate(length(min = 3, max = 50), custom(function = "validate_username"))]
pub username: String,
#[validate(email)]
pub email: String,
#[validate(custom(function = "validate_password"))]
pub password: String,
pub totp: Option<String>,
}
// Define the response body structure
#[derive(Serialize, ToSchema)]
pub struct UserResponse {
pub id: Uuid,
pub username: String,
pub email: String,
pub totp_secret: Option<String>,
pub role_level: i32,
}
// Define the API endpoint
#[utoipa::path(
post,
path = "/users",
tag = "user",
request_body = UserBody,
responses(
(status = 200, description = "User created successfully", body = UserResponse),
(status = 400, description = "Validation error", body = String),
(status = 500, description = "Internal server error", body = String)
)
)]
#[instrument(skip(pool, user))]
pub async fn post_user(
State(pool): State<PgPool>,
Json(user): Json<UserBody>,
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
// Validate input
if let Err(errors) = user.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
// Hash the password before saving it
let hashed_password = hash_password(&user.password)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Failed to hash password." }))))?;
// Generate TOTP secret if totp is Some("true")
let totp_secret = if user.totp.as_deref() == Some("true") {
Some(generate_totp_secret())
} else {
None
};
let row = sqlx::query!(
"INSERT INTO users (username, email, password_hash, totp_secret, role_level)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, username, email, password_hash, totp_secret, role_level",
user.username,
user.email,
hashed_password,
totp_secret,
1, // Default role_level
)
.fetch_one(&pool)
.await
.map_err(|_err| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Could not create the user."}))))?;
Ok(Json(UserResponse {
id: row.id,
username: row.username,
email: row.email,
totp_secret: row.totp_secret,
role_level: row.role_level,
}))
}


@ -1,14 +1,27 @@
use axum::{Extension, Json, response::IntoResponse};
use serde::{Serialize, Deserialize};
use uuid::Uuid;
use crate::models::user::User;
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema)]
struct UserResponse {
id: Uuid,
username: String,
email: String
}
#[utoipa::path(
get,
path = "/protected",
tag = "protected",
responses(
(status = 200, description = "Protected endpoint accessed successfully", body = UserResponse),
(status = 401, description = "Unauthorized", body = String)
)
)]
#[instrument(skip(user))]
pub async fn protected(Extension(user): Extension<User>) -> impl IntoResponse {
Json(UserResponse {
id: user.id,


@ -0,0 +1,190 @@
use axum::{extract::{Extension, Path, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use chrono::{Duration, NaiveDate, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use utoipa::ToSchema;
use uuid::Uuid;
use validator::Validate;
use crate::handlers::validate::validate_future_date;
use crate::middlewares::auth::{generate_api_key, hash_password};
use crate::models::user::User;
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyBody {
#[validate(length(min = 0, max = 50))]
pub description: Option<String>,
#[validate(custom(function = "validate_future_date"))]
pub expiration_date: Option<String>,
}
#[derive(Serialize, ToSchema)]
pub struct ApiKeyResponse {
pub id: Uuid,
pub description: Option<String>,
}
#[utoipa::path(
post,
path = "/apikeys/rotate/{id}",
tag = "apikey",
request_body = ApiKeyBody,
responses(
(status = 200, description = "API key rotated successfully", body = ApiKeyResponse),
(status = 400, description = "Validation error", body = String),
(status = 404, description = "API key not found", body = String),
(status = 500, description = "Internal server error", body = String)
),
params(
("id" = String, Path, description = "API key identifier")
)
)]
#[instrument(skip(pool, user, apikeybody))]
pub async fn rotate_apikey(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>,
Json(apikeybody): Json<ApiKeyBody>
) -> impl IntoResponse {
// Validate input
if let Err(errors) = apikeybody.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
// Validate UUID format
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid API key identifier format" })))),
};
// Verify ownership of the API key
let existing_key = sqlx::query_as!(ApiKeyResponse,
"SELECT id, description FROM apikeys
WHERE user_id = $1 AND id = $2 AND disabled = FALSE",
user.id,
uuid
)
.fetch_optional(&pool)
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
let existing_key = existing_key.ok_or_else(||
(StatusCode::NOT_FOUND,
Json(json!({ "error": "API key not found or already disabled" })))
)?;
// Validate expiration date format
let expiration_date = match &apikeybody.expiration_date {
Some(date_str) => NaiveDate::parse_from_str(date_str, "%Y-%m-%d")
.map_err(|_| (StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid expiration date format. Use YYYY-MM-DD" }))))?,
None => (Utc::now() + Duration::days(365 * 2)).naive_utc().date(),
};
// Validate expiration date is in the future
if expiration_date <= Utc::now().naive_utc().date() {
return Err((StatusCode::BAD_REQUEST,
Json(json!({ "error": "Expiration date must be in the future" }))));
}
// Generate new secure API key
let api_key = generate_api_key();
let key_hash = hash_password(&api_key)
.map_err(|e| {
tracing::error!("Hashing error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
// Begin transaction
let mut tx = pool.begin().await.map_err(|e| {
tracing::error!("Transaction error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
// Disable old key
let disable_result = sqlx::query!(
"UPDATE apikeys SET
disabled = TRUE,
expiration_date = CURRENT_DATE + INTERVAL '1 day'
WHERE id = $1 AND user_id = $2",
uuid,
user.id
)
.execute(&mut *tx)
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
if disable_result.rows_affected() == 0 {
return Err((StatusCode::NOT_FOUND,
Json(json!({ "error": "API key not found or already disabled" }))));
}
// Create new key with automatic description
let description = apikeybody.description.unwrap_or_else(||
format!("Rotated from key {} - {}",
existing_key.id,
Utc::now().format("%Y-%m-%d"))
);
let new_key = sqlx::query!(
"INSERT INTO apikeys
(key_hash, description, expiration_date, user_id, access_read, access_modify)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, description, expiration_date",
key_hash,
description,
expiration_date,
user.id,
true, // Default read access
false // Default no modify access
)
.fetch_one(&mut *tx)
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
tx.commit().await.map_err(|e| {
tracing::error!("Transaction commit error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error" })))
})?;
Ok(Json(json!({
"id": new_key.id,
"api_key": api_key,
"description": new_key.description,
"expiration_date": new_key.expiration_date,
"warning": "Store this key securely - it won't be shown again",
"rotation_info": {
"original_key": existing_key.id,
"disabled_at": Utc::now().to_rfc3339()
}
})))
}