Extended API security, added API documentation, added a delete endpoint, and hardened the Docker image.

This commit is contained in:
Rik Heijmann 2025-02-15 20:57:24 +01:00
parent 40ab25987c
commit 0d908ccfe8
70 changed files with 2945 additions and 2082 deletions

View File

@ -1,15 +1,7 @@
# ==============================
# 📌 DATABASE CONFIGURATION
# ⚙️ GENERAL CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
ENVIRONMENT="development" # "production"
# ==============================
# 🌍 SERVER CONFIGURATION
@ -24,6 +16,24 @@ SERVER_PORT="3000"
# Enable tracing for debugging/logging (true/false)
SERVER_TRACE_ENABLED=true
# Amount of threads used to run the server
SERVER_WORKER_THREADS=2
# ==============================
# 🛢️ DATABASE CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
# ==============================
# 🔒 HTTPS CONFIGURATION
# ==============================
@ -40,6 +50,7 @@ SERVER_HTTPS_CERT_FILE_PATH=cert.pem
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_KEY_FILE_PATH=key.pem
# ==============================
# 🚦 RATE LIMIT CONFIGURATION
# ==============================
@ -50,6 +61,7 @@ SERVER_RATE_LIMIT=5
# Time period (in seconds) for rate limiting
SERVER_RATE_LIMIT_PERIOD=1
# ==============================
# 📦 COMPRESSION CONFIGURATION
# ==============================
@ -60,9 +72,10 @@ SERVER_COMPRESSION_ENABLED=true
# Compression level (valid range: 0-11, where 11 is the highest compression)
SERVER_COMPRESSION_LEVEL=6
# ==============================
# 🔑 AUTHENTICATION CONFIGURATION
# ==============================
# Argon2 salt for password hashing (must be kept secret!)
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
# JWT secret key.
JWT_SECRET_KEY="fgr4fe34w2rfTwfe3444234edfewfw4e#f$#wferg23w2DFSdf"

View File

@ -1,316 +0,0 @@
{
"name": "Axium",
"version": "1",
"items": [
{
"type": "http",
"name": "Health",
"seq": 3,
"request": {
"url": "{{base_url}}/health",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Protected",
"seq": 4,
"request": {
"url": "{{base_url}}/protected",
"method": "GET",
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true
}
],
"params": [],
"body": {
"mode": "json",
"json": "",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Sign-in",
"seq": 1,
"request": {
"url": "{{base_url}}/signin",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"email\":\"user@test.com\",\n \"password\":\"test\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "folder",
"name": "To do's",
"root": {
"request": {
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true,
"uid": "PRiX2eBEKKPlsc1xxRHeN"
}
]
},
"meta": {
"name": "To do's"
}
},
"items": [
{
"type": "http",
"name": "Get all",
"seq": 1,
"request": {
"url": "{{base_url}}/todos/all",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Get by ID",
"seq": 2,
"request": {
"url": "{{base_url}}/todos/1",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Post new",
"seq": 3,
"request": {
"url": "{{base_url}}/todos/",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"task\": \"Finish Rust project.\",\n \"description\": \"Complete the API endpoints for the todo app.\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
}
]
},
{
"type": "folder",
"name": "Users",
"root": {
"request": {
"headers": [
{
"name": "Authorization",
"value": "Bearer {{token}}",
"enabled": true,
"uid": "Dv1ZS2orRQaKpVNKRBmLf"
}
]
},
"meta": {
"name": "Users"
}
},
"items": [
{
"type": "http",
"name": "Get all",
"seq": 1,
"request": {
"url": "{{base_url}}/users/all",
"method": "GET",
"headers": [
{
"name": "",
"value": "",
"enabled": true
}
],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Post new",
"seq": 3,
"request": {
"url": "{{base_url}}/users/",
"method": "POST",
"headers": [],
"params": [],
"body": {
"mode": "json",
"json": "{\n \"username\": \"MyNewUser\",\n \"email\": \"MyNewUser@test.com\",\n \"password\": \"MyNewUser\",\n \"totp\": \"true\"\n}",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
},
{
"type": "http",
"name": "Get by ID",
"seq": 2,
"request": {
"url": "{{base_url}}/users/1",
"method": "GET",
"headers": [],
"params": [],
"body": {
"mode": "none",
"formUrlEncoded": [],
"multipartForm": []
},
"script": {},
"vars": {},
"assertions": [],
"tests": "",
"docs": "",
"auth": {
"mode": "none"
}
}
}
]
}
],
"activeEnvironmentUid": "6LVIlBNVHmWamnS5xdrf0",
"environments": [
{
"variables": [
{
"name": "base_url",
"value": "http://127.0.0.1:3000",
"enabled": true,
"secret": false,
"type": "text"
},
{
"name": "token",
"value": "",
"enabled": true,
"secret": true,
"type": "text"
}
],
"name": "Default"
}
],
"root": {
"request": {
"vars": {}
}
},
"brunoConfig": {
"version": "1",
"name": "Axium",
"type": "collection",
"ignore": [
"node_modules",
".git"
]
}
}

View File

@ -7,12 +7,15 @@ edition = "2021"
# Web framework and server
axum = { version = "0.8.1", features = ["json"] }
# hyper = { version = "1.5.2", features = ["full"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }
# Database interaction
sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate", "uuid", "chrono"] }
uuid = { version = "1.12.1", features = ["serde"] }
rand = "0.8.5"
rand_core = "0.6.4" # 2024-2-3: SQLx 0.8.3 does not support 0.9.
moka = { version = "0.12.10", features = ["future"] }
lazy_static = "1.5"
# Serialization and deserialization
serde = { version = "1.0.217", features = ["derive"] }
@ -24,9 +27,10 @@ argon2 = "0.5.3"
totp-rs = { version = "5.6.0", features = ["gen_secret"] }
base64 = "0.22.1"
bcrypt = "0.17.0"
futures = "0.3.31"
# Asynchronous runtime and traits
tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] }
tokio = { version = "1.43.0", features = ["rt-multi-thread", "process", "signal"] }
# Configuration and environment
dotenvy = "0.15.7"
@ -53,6 +57,7 @@ rustls-pemfile = "2.2.0"
# Input validation
validator = { version = "0.20.0", features = ["derive"] }
regex = "1.11.1"
thiserror = "1.0"
# Documentation
utoipa = { version = "5.3.1", features = ["axum_extras", "chrono", "uuid"] }

View File

@ -1,73 +1,56 @@
# --- Stage 1: Builder Stage ---
FROM rust:1.75-slim-bookworm AS builder
# syntax=docker/dockerfile:1
# Comments are provided throughout this file to help you get started.
# If you need more help, visit the Dockerfile reference guide at
# https://docs.docker.com/engine/reference/builder/
################################################################################
# Create a stage for building the application.
ARG RUST_VERSION=1.78.0
ARG APP_NAME=backend
FROM rust:${RUST_VERSION}-slim-bullseye AS build
ARG APP_NAME
WORKDIR /app
# Build the application.
# Leverage a cache mount to /usr/local/cargo/registry/
# for downloaded dependencies and a cache mount to /app/target/ for
# compiled dependencies which will speed up subsequent builds.
# Leverage a bind mount to the src directory to avoid having to copy the
# source code into the container. Once built, copy the executable to an
# output directory before the cache mounted /app/target is unmounted.
RUN --mount=type=bind,source=src,target=src \
# --mount=type=bind,source=configuration.yaml,target=configuration.yaml \
--mount=type=bind,source=Cargo.toml,target=Cargo.toml \
--mount=type=bind,source=Cargo.lock,target=Cargo.lock \
--mount=type=cache,target=/app/target/ \
--mount=type=cache,target=/usr/local/cargo/registry/ \
<<EOF
set -e
cargo build --locked --release
cp ./target/release/$APP_NAME /bin/server
EOF
# COPY /src/web /bin/web
################################################################################
# Create a new stage for running the application that contains the minimal
# runtime dependencies for the application. This often uses a different base
# image from the build stage where the necessary files are copied from the build
# stage.
#
# The example below uses the debian bullseye image as the foundation for running the app.
# By specifying the "bullseye-slim" tag, it will also use whatever happens to be the
# most recent version of that tag when you build your Dockerfile. If
# reproducibility is important, consider using a digest
# (e.g., debian@sha256:ac707220fbd7b67fc19b112cee8170b41a9e97f703f588b2cdbbcdcecdd8af57).
FROM debian:bullseye-slim AS final
# Create a non-privileged user that the app will run under.
# See https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#user
ARG UID=10001
RUN adduser \
--disabled-password \
--gecos "" \
--home "/nonexistent" \
--shell "/sbin/nologin" \
--no-create-home \
--uid "${UID}" \
appuser
USER appuser
# Copy the executable from the "build" stage.
COPY --from=build /bin/server /bin/
# Expose the port that the application listens on.
EXPOSE 80
# What the container should run when it is started.
CMD ["/bin/server"]
WORKDIR /app
# Install required build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# Cache dependencies
COPY Cargo.toml Cargo.lock ./
RUN cargo fetch --locked
# Copy source code
COPY src src/
COPY build.rs build.rs
# Build the application in release mode
RUN cargo build --release --locked
# Strip debug symbols to reduce binary size
RUN strip /app/target/release/Axium
# --- Stage 2: Runtime Stage ---
FROM debian:bookworm-slim
# Install runtime dependencies only
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
openssl \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd --no-log-init -r -m -u 1001 appuser
WORKDIR /app
# Copy built binary from builder stage
COPY --from=builder /app/target/release/Axium .
# Copy environment file (consider secrets management for production)
COPY .env .env
# Change ownership to non-root user
RUN chown -R appuser:appuser /app
USER appuser
# Expose the application port
EXPOSE 3000
# Run the application
CMD ["./Axium"]

297
README.md
View File

@ -1,49 +1,94 @@
# 🦀 Axium
**An example API built with Rust, Axum, SQLx, and PostgreSQL**
# 🦖 Axium
**An example API built with Rust, Axum, SQLx, and PostgreSQL.**
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
## 🚀 Core Features
- **Rust API template** - Production-ready starter template with modern practices,
- **PostgreSQL integration** - Full database support with SQLx migrations,
- **Easy to secure** - HTTP/2 with secure TLS defaults (AWS-LC, FIPS 140-3),
- **Easy to configure** - `.env` and environment variables,
- **JWT authentication** - Secure token-based auth with Argon2 password hashing,
- **Optimized for performance** - Brotli compression,
- **Comprehensive health monitoring**
Docker-compatible endpoint with system metrics:
```json
{
"details": {
"cpu_usage": {"available_percentage": "9.85", "status": "low"},
"database": {"status": "ok"},
"disk_usage": {"status": "ok", "used_percentage": "74.00"},
"memory": {"available_mb": 21613, "status": "normal"}
},
"status": "degraded"
### **Production-Grade Foundation**
_Jumpstart secure API development_
- Battle-tested Rust template following industry best practices
- Built-in scalability patterns for high-traffic environments
### **Effortless Deployment**
_From zero to production in minutes_
- 🐳 Docker Compose stack with pre-configured services
- Up and running in about 20 minutes with a single `docker-compose up`
### **Developer-First API Experience**
_Spec-driven development workflow_
- Auto-generated OpenAPI 3.1 specifications
- Interactive Swagger UI endpoint at `/docs`
```rust
// Endpoint registration example
.route("/docs", get(serve_swagger_ui))
```
### **Enterprise-Grade Security**
_Security by design architecture_
- JWT authentication with Argon2id password hashing (OWASP recommended)
- TLS 1.3/HTTP2 via AWS-LC (FIPS 140-3 compliant cryptography)
- Role-Based Access Control (RBAC) implementation:
```rust
.layer(middleware::from_fn(|req, next|
authorize(req, next, vec![1, 2]) // Admin+Mod roles
))
```
### **PostgreSQL Integration**
_Relational data made simple_
- SQLx-powered async database operations
- Migration system with transactional safety
- Connection pooling for high concurrency
### **Performance Optimizations**
_Engineered for speed at scale_
- Brotli compression (11-level optimization)
- Intelligent request caching strategies
- Zero-copy deserialization pipelines
### **Operational Visibility**
_Production monitoring made easy_
- Docker-healthcheck compatible endpoint:
```json
{
"status": "degraded",
"details": {
"database": {"status": "ok"},
"memory": {"available_mb": 21613, "status": "normal"},
"cpu_usage": {"available_percentage": "9.85", "status": "low"},
"disk_usage": {"used_percentage": "74.00", "status": "ok"}
}
```
- **Granular access control** - Role-based endpoint protection:
```rust
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
```
- **User context injection** - Automatic user profile handling in endpoints:
```rust
pub async fn post_todo(
Extension(user): Extension<User>, // Injected user
Json(todo): Json<TodoBody>
) -> impl IntoResponse {
if todo.user_id != user.id {
return Err((StatusCode::FORBIDDEN, Json(json!({
"error": "Cannot create todos for others"
}))));
}
```
- **Observability** - Integrated tracing,
- **Documented codebase** - Extensive inline comments for easy modification and readability,
- **Latest dependencies** - Regularly updated Rust ecosystem crates,
}
```
### **Developer Ergonomics**
_Code with confidence_
- Context-aware user injection system:
```rust
async fn create_todo(
Extension(User { id, role, .. }): Extension<User>, // Auto-injected
Json(payload): Json<TodoRequest>
) -> Result<impl IntoResponse> {
// Business logic with direct user context
}
```
- Structured logging with OpenTelemetry integration
- Compile-time configuration validation
### **Maintenance & Compliance**
_Future-proof codebase management_
- Automated dependency updates via Dependabot
- Security-focused dependency tree (cargo-audit compliant)
- Comprehensive inline documentation:
```rust
/// JWT middleware - Validates Authorization header
/// # Arguments
/// * `req` - Incoming request
/// * `next` - Next middleware layer
/// # Security
/// - Validates Bearer token format
/// - Checks token expiration
/// - Verifies cryptographic signature
```
## 🛠️ Technology stack
| Category | Key Technologies |
@ -55,53 +100,64 @@
## 📂 Project structure
```
Axium/
├── migrations/ # SQL schema migrations. Creates the required tables and inserts demo data.
├── src/
│ ├── core/ # Core modules: reading configuration files, starting the server, and configuring HTTPS.
│ ├── database/ # Database connectivity, getters and setters for the database.
│ ├── middlewares/ # Currently just the authentication system.
│ ├── models/ # Data structures
│ └── routes/ # API endpoints
│ └── mod.rs # API endpoint router.
│ └── .env # Configuration file.
└── Dockerfile # Builds a docker container for the application.
└── compose.yaml # Docker-compose.yaml. Runs container for the application (also includes a PostgreSQL-container).
axium-api/ # Root project directory
├── 📁 migrations/ # Database schema migrations (SQLx)
├── 📁 src/ # Application source code
│ ├── 📁 core/ # Core application infrastructure
│ │ ├── config.rs # Configuration loader (.env, env vars)
│ │ └── server.rs # HTTP/HTTPS server initialization
│ │
│ ├── 📁 database/ # Database access layer
│ │ ├── connection.rs # Connection pool management
│ │ ├── queries/ # SQL query modules
│ │ └── models.rs # Database entity definitions
│ │
│ ├── 📁 middlewares/ # Axum middleware components
│ ├── 📁 routes/ # API endpoint routing
│ │ └── mod.rs # Route aggregator
│ │
│ ├── 📁 handlers/ # Request handlers
│ │
│ ├── 📁 utils/ # Common utilities
│ │
│ └── main.rs # Application entry point
├── 📄 .env # Environment configuration
├── 📄 .env.example # Environment template
├── 📄 Dockerfile # Production container build
├── 📄 docker-compose.yml # Local development stack
└── 📄 Cargo.toml # Rust dependencies & metadata
```
Each folder contains a README.md that explains its contents in more detail.
## 🌐 Default API endpoints
| Method | Endpoint | Auth Required | Allowed Roles | Description |
|--------|------------------------|---------------|---------------|--------------------------------------|
| POST | `/signin` | No | | Authenticate user and get JWT token |
| GET | `/protected` | Yes | 1, 2 | Test endpoint for authenticated users |
| GET | `/health` | No | | System health check with metrics |
| | | | | |
| **User routes** | | | | |
| GET | `/users/all` | No* | | Get all users |
| GET | `/users/{id}` | No* | | Get user by ID |
| POST | `/users/` | No* | | Create new user |
| | | | | |
| **Todo routes** | | | | |
| GET | `/todos/all` | No* | | Get all todos |
| POST | `/todos/` | Yes | 1, 2 | Create new todo |
| GET | `/todos/{id}` | No* | | Get todo by ID |
**Key:**
🔒 = Requires JWT in `Authorization: Bearer <token>` header
\* Currently unprotected - recommend adding authentication for production
**Roles:** 1 = User, 2 = Administrator
**Security notes:**
- All POST endpoints expect JSON payloads
- User creation endpoint should be protected in production
- Consider adding rate limiting to authentication endpoints
**Notes:**
- 🔒 = Requires JWT in `Authorization: Bearer <token>` header
- Roles: `1` = Regular User, `2` = Administrator
- *Marked endpoints currently unprotected - recommend adding middleware for production use
- All POST endpoints expect JSON payloads
| Method | Endpoint | Auth Required | Administrator only | Description |
|--------|------------------------|---------------|-------------------|--------------------------------------|
| POST | `/signin` | 🚫 | 🚫 | Authenticate user and get JWT token |
| GET | `/protected` | ✅ | 🚫 | Test endpoint for authenticated users |
| GET | `/health` | 🚫 | 🚫 | System health check with metrics |
| | | | | |
| **Apikey routes** | | | | |
| GET | `/apikeys/all` | ✅ | ✅ | Get all apikeys of the current user. |
| POST | `/apikeys/` | ✅ | ✅ | Create a new apikey. |
| GET | `/apikeys/{id}` | ✅ | ✅ | Get an apikey by ID. |
| DELETE | `/apikeys/{id}` | ✅ | 🚫 | Delete an apikey by ID. |
| POST | `/apikeys/rotate/{id}` | ✅ | 🚫 | Rotates an API key, disables the old one (grace period 24 hours), returns a new one. |
| | | | | |
| **User routes** | | | | |
| GET | `/users/all` | ✅ | ✅ | Get all users. |
| POST | `/users/` | ✅ | ✅ | Create a new user. |
| GET | `/users/{id}` | ✅ | ✅ | Get a user by ID. |
| DELETE | `/users/{id}` | ✅ | ✅ | Delete a user by ID. |
| | | | | |
| **Todo routes** | | | | |
| GET | `/todos/all` | ✅ | 🚫 | Get all todos of the current user. |
| POST | `/todos/` | ✅ | 🚫 | Create a new todo. |
| GET | `/todos/{id}` | ✅ | 🚫 | Get a todo by ID. |
| DELETE | `/todos/{id}` | ✅ | 🚫 | Delete a todo by ID. |
## 📦 Installation & Usage
```bash
@ -126,27 +182,47 @@ cargo run --release
| `admin@test.com` | `test` | Administrator |
⚠️ **Security recommendations:**
1. Rotate passwords immediately after initial setup
2. Disable default accounts before deploying to production
3. Implement proper user management endpoints
1. Rotate passwords immediately after initial setup.
2. Disable default accounts before deploying to production.
3. Implement proper user management endpoints.
#### Administrative password resets
*For emergency access recovery only*
1. **Database Access**
Connect to PostgreSQL using privileged credentials:
```bash
psql -U admin_user -d axium_db -h localhost
```
2. **Secure Hash Generation**
Use the integrated CLI tool (never online generators):
```bash
cargo run --bin argon2-cli -- "new_password"
# Output: $argon2id$v=19$m=19456,t=2,p=1$b2JqZWN0X2lkXzEyMzQ1$R7Zx7Y4W...
```
3. **Database Update**
```sql
UPDATE users
SET
password_hash = '$argon2id...',
updated_at = NOW()
WHERE email = 'user@example.com';
```
4. **Verification**
- Immediately test new credentials
- Force user password change on next login
### ⚙️ Configuration
Create a .env file in the root of the project or configure the application using environment variables.
```env
# ==============================
# 📌 DATABASE CONFIGURATION
# ⚙️ GENERAL CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
ENVIRONMENT="development" # "production"
# ==============================
# 🌍 SERVER CONFIGURATION
@ -161,6 +237,24 @@ SERVER_PORT="3000"
# Enable tracing for debugging/logging (true/false)
SERVER_TRACE_ENABLED=true
# Amount of threads used to run the server
SERVER_WORKER_THREADS=2
# ==============================
# 🛢️ DATABASE CONFIGURATION
# ==============================
# PostgreSQL connection URL (format: postgres://user:password@host/database)
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
# Maximum number of connections in the database pool
DATABASE_MAX_CONNECTIONS=20
# Minimum number of connections in the database pool
DATABASE_MIN_CONNECTIONS=5
# ==============================
# 🔒 HTTPS CONFIGURATION
# ==============================
@ -177,6 +271,7 @@ SERVER_HTTPS_CERT_FILE_PATH=cert.pem
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
SERVER_HTTPS_KEY_FILE_PATH=key.pem
# ==============================
# 🚦 RATE LIMIT CONFIGURATION
# ==============================
@ -187,6 +282,7 @@ SERVER_RATE_LIMIT=5
# Time period (in seconds) for rate limiting
SERVER_RATE_LIMIT_PERIOD=1
# ==============================
# 📦 COMPRESSION CONFIGURATION
# ==============================
@ -197,10 +293,11 @@ SERVER_COMPRESSION_ENABLED=true
# Compression level (valid range: 0-11, where 11 is the highest compression)
SERVER_COMPRESSION_LEVEL=6
# ==============================
# 🔑 AUTHENTICATION CONFIGURATION
# ==============================
# Argon2 salt for password hashing (must be kept secret!)
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
# JWT secret key.
JWT_SECRET_KEY="fgr4fe34w2rfTwfe3444234edfewfw4e#f$#wferg23w2DFSdf"
```

View File

@ -1,32 +0,0 @@
services:
server:
build:
context: .
target: final
ports:
- 80:80
depends_on:
- db_image
networks:
- common-net
db_image:
image: postgres:latest
environment:
POSTGRES_PORT: 3306
POSTGRES_DATABASE: database_name
POSTGRES_USER: user
POSTGRES_PASSWORD: database_password
POSTGRES_ROOT_PASSWORD: strong_database_password
expose:
- 3306
ports:
- "3307:3306"
networks:
- common-net
networks:
common-net: {}

53
docker-compose.yml Normal file
View File

@ -0,0 +1,53 @@
version: "3.9"
services:
axium:
build:
context: .
dockerfile: Dockerfile
ports:
- "3000:3000"
environment:
- ENVIRONMENT=${ENVIRONMENT:-development} #default value if not defined.
- SERVER_IP=${SERVER_IP:-0.0.0.0}
- SERVER_PORT=${SERVER_PORT:-3000}
- SERVER_TRACE_ENABLED=${SERVER_TRACE_ENABLED:-true}
- SERVER_WORKER_THREADS=${SERVER_WORKER_THREADS:-2}
- DATABASE_URL=${DATABASE_URL:-postgres://postgres:1234@db/database_name}
- DATABASE_MAX_CONNECTIONS=${DATABASE_MAX_CONNECTIONS:-20}
- DATABASE_MIN_CONNECTIONS=${DATABASE_MIN_CONNECTIONS:-5}
- SERVER_HTTPS_ENABLED=${SERVER_HTTPS_ENABLED:-false}
- SERVER_HTTPS_HTTP2_ENABLED=${SERVER_HTTPS_HTTP2_ENABLED:-true}
# Mount volume for certs for HTTPS
- SERVER_HTTPS_CERT_FILE_PATH=/app/certs/cert.pem # Changed to /app/certs
- SERVER_HTTPS_KEY_FILE_PATH=/app/certs/key.pem # Changed to /app/certs
- SERVER_RATE_LIMIT=${SERVER_RATE_LIMIT:-5}
- SERVER_RATE_LIMIT_PERIOD=${SERVER_RATE_LIMIT_PERIOD:-1}
- SERVER_COMPRESSION_ENABLED=${SERVER_COMPRESSION_ENABLED:-true}
- SERVER_COMPRESSION_LEVEL=${SERVER_COMPRESSION_LEVEL:-6}
- JWT_SECRET_KEY=${JWT_SECRET_KEY:-fgr4fe34w2rfTwfe3444234edfewfw4e#f$#wferg23w2DFSdf} #VERY important to change this!
depends_on:
- db # Ensure the database is up before the app
volumes:
- ./certs:/app/certs # Mount volume for certs
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/health"]
interval: 10s
timeout: 5s
retries: 3
start_period: 15s
db:
image: postgres:16-alpine
restart: always
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: "1234" # Change this in production!
POSTGRES_DB: database_name # Matches the DB name in .env
ports:
- "5432:5432"
volumes:
- db_data:/var/lib/postgresql/data
volumes:
db_data:

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

File diff suppressed because one or more lines are too long

View File

@ -1 +0,0 @@
openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" -addext "subjectAltName = DNS:localhost, IP:127.0.0.1"

9
src/core/README.md Normal file
View File

@ -0,0 +1,9 @@
# Core
The `core` module contains the fundamental components for setting up and configuring the API backend. It includes server creation, environment configuration management, and middleware layers that enhance the overall performance and observability of the API.
## Contributing
Ensure new middleware is well-documented, includes error handling, and integrates with the existing architecture.
## License
This project is licensed under the MIT License.

View File

@ -1,4 +1,3 @@
pub mod config;
pub mod server;
pub mod tls;

View File

@ -23,7 +23,7 @@ pub async fn create_server() -> Router {
// Enable tracing middleware if configured
if config::get_env_bool("SERVER_TRACE_ENABLED", true) {
app = app.layer(TraceLayer::new_for_http());
println!("✔️ Trace hads been enabled.");
println!("✔️ Trace hads been enabled.");
}
// Enable compression middleware if configured
@ -32,7 +32,7 @@ pub async fn create_server() -> Router {
let level = config::get_env("SERVER_COMPRESSION_LEVEL").parse().unwrap_or(6);
// Apply compression layer with Brotli (br) enabled and the specified compression level
app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level)));
println!("✔️ Brotli compression enabled with compression quality level {}.", level);
println!("✔️ Brotli compression enabled with compression quality level {}.", level);
}

View File

@ -1,145 +0,0 @@
// Standard library imports
use std::{
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
io::BufReader,
fs::File,
iter,
};
// External crate imports
use axum::serve::Listener;
use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}};
use rustls_pemfile::{Item, read_one, certs};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing;
// Local crate imports
use crate::config; // Import env config helper
// Function to load TLS configuration from files
/// Builds a rustls `ServerConfig` from the certificate chain and private key
/// files named by the `SERVER_HTTPS_CERT_FILE_PATH` and
/// `SERVER_HTTPS_KEY_FILE_PATH` environment settings.
///
/// # Panics
/// Panics if either file cannot be opened, if the certificate file contains
/// no valid certificates, if no usable private key is found, or if rustls
/// rejects the resulting configuration.
pub fn load_tls_config() -> ServerConfig {
    // Resolve the file locations via the environment-backed config helper.
    let cert_file = File::open(config::get_env("SERVER_HTTPS_CERT_FILE_PATH"))
        .expect("❌ Failed to open certificate file.");
    let key_file = File::open(config::get_env("SERVER_HTTPS_KEY_FILE_PATH"))
        .expect("❌ Failed to open private key file.");

    // Parse every certificate in the PEM chain.
    let mut cert_reader = BufReader::new(cert_file);
    let chain: Vec<CertificateDer> = certs(&mut cert_reader)
        .map(|entry| entry.expect("❌ Failed to read certificate."))
        .collect();
    if chain.is_empty() {
        panic!("❌ No valid certificates found.");
    }

    // Scan the key file for the first PKCS#1, PKCS#8, or SEC1 private key.
    let mut key_reader = BufReader::new(key_file);
    let private_key = iter::from_fn(|| read_one(&mut key_reader).transpose())
        .find_map(|entry| match entry.unwrap() {
            Item::Pkcs1Key(k) => Some(PrivateKeyDer::from(k)),
            Item::Pkcs8Key(k) => Some(PrivateKeyDer::from(k)),
            Item::Sec1Key(k) => Some(PrivateKeyDer::from(k)),
            _ => None,
        })
        .expect("❌ Failed to read a valid private key.");

    // Assemble the final server-side TLS configuration.
    ServerConfig::builder()
        .with_no_client_auth() // No client-certificate authentication
        .with_single_cert(chain, private_key)
        .expect("❌ Failed to create TLS configuration.")
}
// Custom listener that implements axum::serve::Listener
/// TCP listener paired with a TLS acceptor.
///
/// Wraps a shared tokio `TcpListener` together with a `TlsAcceptor` so that
/// every accepted connection can be upgraded to TLS; implements
/// `axum::serve::Listener` (see the `impl` below).
#[derive(Clone)]
pub struct TlsListener {
    pub inner: Arc<tokio::net::TcpListener>, // Inner TCP listener
    pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor for handling TLS handshakes
}
impl Listener for TlsListener {
    type Io = TlsStreamWrapper; // Type of I/O stream
    type Addr = SocketAddr; // Type of address (Socket address)

    /// Accepts the next TCP connection and performs a TLS handshake on it.
    ///
    /// Both accept errors and handshake failures are logged and skipped
    /// (`continue`), so one misbehaving client cannot stop the listener;
    /// the returned future only resolves once a handshake succeeds.
    fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send {
        let acceptor = self.acceptor.clone(); // Clone the acceptor for async use
        async move {
            loop {
                // Accept a TCP connection
                let (stream, addr) = match self.inner.accept().await {
                    Ok((stream, addr)) => (stream, addr),
                    Err(e) => {
                        tracing::error!("❌ Error accepting TCP connection: {}", e);
                        continue; // Retry on error
                    }
                };
                // Perform TLS handshake
                match acceptor.accept(stream).await {
                    Ok(tls_stream) => {
                        tracing::info!("Successful TLS handshake with {}.", addr);
                        return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address
                    },
                    Err(e) => {
                        // Common cause: the client rejected a self-signed certificate.
                        tracing::warn!("TLS handshake failed: {} (Client may not trust certificate).", e);
                        continue; // Retry on error
                    }
                }
            }
        }
    }

    /// Returns the local socket address of the underlying TCP listener.
    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        self.inner.local_addr()
    }
}
// Wrapper for a TLS stream, implementing AsyncRead and AsyncWrite
/// Newtype around `tokio_rustls::server::TlsStream<TcpStream>`.
///
/// Exists so we can implement `AsyncRead`/`AsyncWrite` delegation below and
/// hand the TLS stream to axum as a plain connection I/O type.
#[derive(Debug)]
pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream<tokio::net::TcpStream>);
/// Read half: forwards directly to the wrapped TLS stream.
impl AsyncRead for TlsStreamWrapper {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream
    }
}
/// Write half: forwards write, flush, and shutdown to the wrapped TLS stream.
impl AsyncWrite for TlsStreamWrapper {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream
    }
}
// Allow the TLS stream wrapper to be used in non-blocking contexts (needed for async operations)
// NOTE(review): this manual impl is presumably redundant — `TlsStream<TcpStream>`
// should itself be `Unpin`, which would make the wrapper auto-`Unpin`; the
// explicit impl is harmless but worth confirming against tokio-rustls docs.
impl Unpin for TlsStreamWrapper {}

29
src/database/README.md Normal file
View File

@ -0,0 +1,29 @@
# Database
This folder contains the database interaction layer for Axium, handling database connections, migrations, and queries related to API keys and usage metrics.
## Overview
The `/src/database` folder includes functions for inserting, retrieving, modifying, and deleting API keys, along with usage tracking and database connection management.
### Key Components
- **SQLx:** Asynchronous database operations for PostgreSQL.
- **Chrono:** Date and time manipulation.
- **UUID:** Handling unique identifiers for users and keys.
- **Dotenvy:** Securely loads environment variables.
- **ThisError:** Provides structured error handling.
## Usage
Database functions are called by route handlers for secure data operations. Ensure environment variables like `DATABASE_URL` are properly configured before running the API.
## Dependencies
- [SQLx](https://docs.rs/sqlx/latest/sqlx/)
- [Chrono](https://docs.rs/chrono/latest/chrono/)
- [UUID](https://docs.rs/uuid/latest/uuid/)
- [Dotenvy](https://docs.rs/dotenvy/latest/dotenvy/)
- [ThisError](https://docs.rs/thiserror/latest/thiserror/)
## Contributing
Ensure database queries are secure, optimized, and well-documented. Validate all user inputs before performing database operations.
## License
This project is licensed under the MIT License.

232
src/database/apikeys.rs Normal file
View File

@ -0,0 +1,232 @@
use chrono::NaiveDate;
use sqlx::postgres::PgPool;
use uuid::Uuid;
use crate::models::apikey::{ApiKeyResponse, ApiKeyByIDResponse, ApiKeyByUserIDResponse, ApiKeyInsertResponse, ApiKeyGetActiveForUserResponse};
// ---------------------------
// Key Creation Functions
// ---------------------------

/// Inserts a new API key into the database for the specified user.
///
/// # Parameters
/// - `pool`: PostgreSQL connection pool
/// - `key_hash`: Hash of the generated API key (hashing happens at the
///   caller — presumably SHA-256, confirm there; only the hash is stored)
/// - `description`: Human-readable key description
/// - `expiration_date`: Key expiration date (required — the parameter is a
///   plain `NaiveDate`, not an `Option`)
/// - `user_id`: Owner's user ID
///
/// # Returns
/// `ApiKeyInsertResponse` with metadata. The `api_key` field is returned as
/// an empty string for the caller to fill in post-processing; the plaintext
/// key is never stored in the database.
///
/// # Security
/// - Uses parameterized queries to prevent SQL injection
/// - Caller must validate inputs before invocation
pub async fn insert_api_key_into_db(
    pool: &PgPool,
    key_hash: String,
    description: String,
    expiration_date: NaiveDate,
    user_id: Uuid,
) -> Result<ApiKeyInsertResponse, sqlx::Error> {
    let row = sqlx::query!(
        r#"
        INSERT INTO apikeys (key_hash, description, expiration_date, user_id)
        VALUES ($1, $2, $3, $4)
        RETURNING id, description, expiration_date
        "#,
        key_hash,
        description,
        expiration_date,
        user_id
    )
    .fetch_one(pool)
    .await?;
    Ok(ApiKeyInsertResponse {
        id: row.id,
        api_key: "".to_string(), // Placeholder for post-processing
        description: row.description.unwrap_or_default(),
        // NULL expiration is rendered as the literal string "Never"
        expiration_date: row.expiration_date
            .map(|d| d.to_string())
            .unwrap_or_else(|| "Never".to_string()),
    })
}
// ---------------------------
// Key Retrieval Functions
// ---------------------------

/// Returns every API key row (active, revoked, or expired) owned by `user_id`.
///
/// # Security
/// - The query is always scoped to `user_id`, so one user can never read
///   another user's keys.
pub async fn fetch_all_apikeys_from_db(
    pool: &PgPool,
    user_id: Uuid
) -> Result<Vec<ApiKeyResponse>, sqlx::Error> {
    let rows = sqlx::query_as!(
        ApiKeyResponse,
        r#"
        SELECT id, user_id, description, expiration_date, creation_date
        FROM apikeys
        WHERE user_id = $1
        "#,
        user_id
    )
    .fetch_all(pool)
    .await?;
    Ok(rows)
}
/// Fetches metadata for a single API key, verifying the caller owns it.
///
/// # Security
/// - Both the key `id` and the owner's `user_id` must match, so a user can
///   never read another user's key metadata; `None` is returned otherwise.
pub async fn fetch_apikey_by_id_from_db(
    pool: &PgPool,
    id: Uuid,
    user_id: Uuid
) -> Result<Option<ApiKeyByIDResponse>, sqlx::Error> {
    let maybe_key = sqlx::query_as!(
        ApiKeyByIDResponse,
        r#"
        SELECT id, description, expiration_date, creation_date
        FROM apikeys
        WHERE id = $1 AND user_id = $2
        "#,
        id,
        user_id
    )
    .fetch_optional(pool)
    .await?;
    Ok(maybe_key)
}
/// Retrieves the currently-usable keys for a user.
///
/// # Security
/// - Excludes disabled keys, and keys whose `expiration_date` is today or
///   earlier (strict `>` comparison); keys with NULL expiration never expire.
/// - Each row includes `key_hash` — callers must treat the result as
///   sensitive and never expose it in responses.
pub async fn fetch_active_apikeys_by_user_id_from_db(
    pool: &PgPool,
    user_id: Uuid
) -> Result<Vec<ApiKeyByUserIDResponse>, sqlx::Error> {
    sqlx::query_as!(
        ApiKeyByUserIDResponse,
        r#"
        SELECT id, key_hash, expiration_date
        FROM apikeys
        WHERE
            user_id = $1
            AND disabled = FALSE
            AND (expiration_date IS NULL OR expiration_date > CURRENT_DATE)
        "#,
        user_id
    )
    .fetch_all(pool)
    .await
}
// ---------------------------
// Key Modification Functions
// ---------------------------

/// Revokes an API key: marks it disabled and shortens its expiration to a
/// one-day grace period. Returns the number of rows updated (0 when no key
/// matched).
///
/// # Security
/// - The UPDATE is scoped to both key id and `user_id`, preventing
///   unauthorized revocation of other users' keys.
pub async fn disable_apikey_in_db(
    pool: &PgPool,
    apikey_id: Uuid,
    user_id: Uuid
) -> Result<u64, sqlx::Error> {
    let outcome = sqlx::query!(
        r#"
        UPDATE apikeys
        SET
            disabled = TRUE,
            expiration_date = CURRENT_DATE + INTERVAL '1 day'
        WHERE id = $1 AND user_id = $2
        "#,
        apikey_id,
        user_id
    )
    .execute(pool)
    .await?;
    Ok(outcome.rows_affected())
}
// ---------------------------
// Key Deletion Functions
// ---------------------------

/// Permanently removes an API key. Returns the number of rows deleted
/// (0 when no key matched the id/owner pair).
///
/// # Security
/// - The DELETE is scoped to both key `id` and `user_id`, preventing
///   unauthorized deletion of other users' keys.
pub async fn delete_apikey_from_db(
    pool: &PgPool,
    id: Uuid,
    user_id: Uuid
) -> Result<u64, sqlx::Error> {
    let outcome = sqlx::query!(
        r#"
        DELETE FROM apikeys
        WHERE id = $1 AND user_id = $2
        "#,
        id,
        user_id
    )
    .execute(pool)
    .await?;
    Ok(outcome.rows_affected())
}
// ---------------------------
// Validation Functions
// ---------------------------

/// Counts the caller's currently-active API keys so per-user key limits can
/// be enforced.
///
/// A key counts as active when it is not disabled and either never expires
/// or expires strictly after today — the same predicate used by
/// `fetch_active_apikeys_by_user_id_from_db`. (The previous `>=` comparison
/// counted keys expiring today against the limit even though they could no
/// longer authenticate.)
///
/// # Security
/// - Scoped to `user_id`; used to enforce business-logic limits
pub async fn check_existing_api_key_count(
    pool: &PgPool,
    user_id: Uuid
) -> Result<i64, sqlx::Error> {
    let row = sqlx::query!(
        r#"
        SELECT COUNT(*) as count
        FROM apikeys
        WHERE
            user_id = $1
            AND disabled = FALSE
            AND (expiration_date IS NULL OR expiration_date > CURRENT_DATE)
        "#,
        user_id
    )
    .fetch_one(pool)
    .await?;
    // COUNT(*) is never NULL in practice, but the macro types it as Option
    Ok(row.count.unwrap_or(0))
}
/// Validates key existence and ownership before operations.
///
/// Returns the key's `id` and `description` when a non-disabled key with
/// this `apikey_id` belongs to `user_id`, `None` otherwise.
///
/// NOTE(review): unlike `fetch_active_apikeys_by_user_id_from_db`, this
/// query does not check `expiration_date`, so an expired-but-not-disabled
/// key is still returned — confirm this is intended at the call sites.
pub async fn fetch_existing_apikey(
    pool: &PgPool,
    user_id: Uuid,
    apikey_id: Uuid
) -> Result<Option<ApiKeyGetActiveForUserResponse>, sqlx::Error> {
    sqlx::query_as!(
        ApiKeyGetActiveForUserResponse,
        r#"
        SELECT id, description
        FROM apikeys
        WHERE user_id = $1 AND id = $2 AND disabled = FALSE
        "#,
        user_id,
        apikey_id
    )
    .fetch_optional(pool)
    .await
}

View File

@ -1,51 +1,140 @@
use dotenvy::dotenv;
use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions};
use std::fs;
use std::env;
use std::path::Path;
use sqlx::{PgPool, migrate::Migrator, migrate::MigrateError, postgres::PgPoolOptions};
use std::{env, fs, path::Path, time::Duration};
use thiserror::Error;
/// Connects to the database using the DATABASE_URL environment variable.
pub async fn connect_to_database() -> Result<PgPool, sqlx::Error> {
dotenv().ok();
let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount.");
// ---------------------------
// Error Handling
// ---------------------------
// Read max and min connection values from environment variables, with defaults
let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
.unwrap_or_else(|_| "10".to_string()) // Default to 10
.parse()
.expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number.");
#[derive(Debug, Error)]
pub enum DatabaseError {
#[error("❌ Environment error: {0}")]
EnvError(String),
let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS")
.unwrap_or_else(|_| "2".to_string()) // Default to 2
.parse()
.expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number.");
#[error("❌ Connection error: {0}")]
ConnectionError(#[from] sqlx::Error),
#[error("❌ File system error: {0}")]
FileSystemError(String),
#[error("❌ Configuration error: {0}")]
ConfigError(String),
#[error("❌ Migration error: {0}")]
MigrationError(#[from] MigrateError),
}
// ---------------------------
// Database Connection
// ---------------------------
/// Establishes a secure connection to PostgreSQL with connection pooling
///
/// # Security Features
/// - Validates database URL format
/// - Enforces connection limits
/// - Uses environment variables securely
/// - Implements connection timeouts
///
/// # Returns
/// `Result<PgPool, DatabaseError>` - Connection pool or detailed error
pub async fn connect_to_database() -> Result<PgPool, DatabaseError> {
// Load environment variables securely
dotenv().ok();
// Validate database URL presence and format
let database_url = env::var("DATABASE_URL")
.map_err(|_| DatabaseError::EnvError("DATABASE_URL not found".to_string()))?;
if !database_url.starts_with("postgres://") {
return Err(DatabaseError::ConfigError(
"❌ Invalid DATABASE_URL format - must start with postgres://".to_string()
));
}
// Configure connection pool with safety defaults
let max_connections = parse_env_var("DATABASE_MAX_CONNECTIONS", 10)?;
let min_connections = parse_env_var("DATABASE_MIN_CONNECTIONS", 2)?;
// Create and configure the connection pool
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.min_connections(min_connections)
.acquire_timeout(Duration::from_secs(5)) // Prevent hanging connections
.idle_timeout(Duration::from_secs(300)) // Clean up idle connections
.test_before_acquire(true) // Validate connections
.connect(&database_url)
.await?;
.await
.map_err(|e| DatabaseError::ConnectionError(e))?;
Ok(pool)
}
/// Run database migrations
pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
// Define the path to the migrations folder
let migrations_path = Path::new("./migrations");
/// Helper function to safely parse environment variables
fn parse_env_var<T: std::str::FromStr>(name: &str, default: T) -> Result<T, DatabaseError>
where
T::Err: std::fmt::Display,
{
match env::var(name) {
Ok(val) => val.parse().map_err(|e| DatabaseError::ConfigError(
format!("❌ Invalid {} value: {}", name, e)
)),
Err(_) => Ok(default),
}
}
// Check if the migrations folder exists, and if not, create it
// ---------------------------
// Database Migrations
// ---------------------------
/// Executes database migrations with safety checks
///
/// # Security Features
/// - Validates migrations directory existence
/// - Limits migration execution to development/staging environments
/// - Uses transactional migrations where supported
///
/// # Returns
/// `Result<(), DatabaseError>` - Success or detailed error
pub async fn run_database_migrations(pool: &PgPool) -> Result<(), DatabaseError> {
let migrations_path = Path::new("./migrations");
// Validate migrations directory
if !migrations_path.exists() {
fs::create_dir_all(migrations_path).expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions.");
println!("✔️ Created migrations directory: {:?}", migrations_path);
fs::create_dir_all(migrations_path)
.map_err(|e| DatabaseError::FileSystemError(
format!("❌ Failed to create migrations directory: {}", e)
))?;
}
// Create a migrator instance that looks for migrations in the `./migrations` folder
let migrator = Migrator::new(migrations_path).await?;
// Verify directory permissions
let metadata = fs::metadata(migrations_path)
.map_err(|e| DatabaseError::FileSystemError(
format!("❌ Cannot access migrations directory: {}", e)
))?;
if metadata.permissions().readonly() {
return Err(DatabaseError::FileSystemError(
"❌ Migrations directory is read-only".to_string()
));
}
// Run all pending migrations
migrator.run(pool).await?;
// Initialize migrator with production safety checks
let migrator = Migrator::new(migrations_path)
.await
.map_err(|e| DatabaseError::MigrationError(e))?;
// Execute migrations in transaction if supported
if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".into()) == "production" {
println!("🛑 Migration execution blocked in production.");
return Err(DatabaseError::ConfigError(
"🛑 Direct migrations disabled in production.".to_string()
));
}
migrator.run(pool)
.await
.map_err(DatabaseError::MigrationError)?;
Ok(())
}

View File

@ -1,16 +0,0 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
use crate::models::apikey::*;
pub async fn get_active_apikeys_by_user_id(pool: &PgPool, user_id: Uuid) -> Result<Vec<ApiKeyByUserIDResponse>, sqlx::Error> {
sqlx::query_as!(ApiKeyByUserIDResponse,
r#"
SELECT id, key_hash, expiration_date::DATE
FROM apikeys
WHERE user_id = $1 AND (expiration_date IS NULL OR expiration_date > NOW()::DATE)
"#,
user_id
)
.fetch_all(pool)
.await
}

View File

@ -1,25 +0,0 @@
use sqlx::postgres::PgPool;
use crate::models::user::*; // Import the User struct
// Get all users
pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result<User, String> {
// Use a string literal directly in the macro
let user = sqlx::query_as!(
User, // Struct type to map the query result
r#"
SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date
FROM users
WHERE email = $1
"#,
email // Bind the `email` parameter
)
.fetch_optional(pool)
.await
.map_err(|e| format!("Database error: {}", e))?; // Handle database errors
// Handle optional result
match user {
Some(user) => Ok(user),
None => Err(format!("User with email '{}' not found.", email)),
}
}

View File

@ -1,16 +0,0 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
pub async fn insert_usage(pool: &PgPool, user_id: Uuid, endpoint: String) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"INSERT INTO usage
(endpoint, user_id)
VALUES ($1, $2)"#,
endpoint,
user_id
)
.execute(pool)
.await?;
Ok(())
}

View File

@ -1,5 +1,6 @@
// Module declarations
pub mod connect;
pub mod get_users;
pub mod get_apikeys;
pub mod insert_usage;
pub mod users;
pub mod apikeys;
pub mod usage;
pub mod todos;

107
src/database/todos.rs Normal file
View File

@ -0,0 +1,107 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
use crate::models::todo::*;
/// Inserts a new Todo into the database with input validation and ownership
/// enforcement.
///
/// # Validation
/// - Task must be 1-100 characters after trimming (measured with
///   `chars().count()`, not bytes, so multi-byte UTF-8 input is judged
///   against the documented character limit)
/// - Description (if provided) must be ≤500 characters after trimming;
///   an empty-after-trim description is stored as NULL
///
/// # Security
/// - Uses parameterized queries to prevent SQL injection
/// - Trims input to prevent whitespace abuse
/// - Associates the todo with the authenticated caller's `user_id`
pub async fn insert_todo_into_db(
    pool: &PgPool,
    task: String,
    description: Option<String>,
    user_id: Uuid,
) -> Result<Todo, sqlx::Error> {
    // Sanitize and validate task
    let task = task.trim();
    if task.is_empty() {
        return Err(sqlx::Error::Protocol("Task cannot be empty".into()));
    }
    if task.chars().count() > 100 {
        return Err(sqlx::Error::Protocol("Task exceeds maximum length of 100 characters".into()));
    }
    // Sanitize and validate optional description
    let description = description.map(|d| d.trim().to_string())
        .filter(|d| !d.is_empty());
    if let Some(desc) = &description {
        if desc.chars().count() > 500 {
            return Err(sqlx::Error::Protocol("Description exceeds maximum length of 500 characters".into()));
        }
    }
    // Insert with ownership enforcement
    let row = sqlx::query_as!(
        Todo,
        "INSERT INTO todos (task, description, user_id)
        VALUES ($1, $2, $3)
        RETURNING id, user_id, task, description, creation_date, completion_date, completed",
        task,
        description,
        user_id
    )
    .fetch_one(pool)
    .await?;
    Ok(row)
}
/// Retrieves every Todo belonging to `user_id`.
///
/// # Security
/// - The WHERE clause is bound to `user_id`, so callers only ever see their
///   own rows.
/// - Parameterized query prevents SQL injection.
pub async fn fetch_all_todos_from_db(pool: &PgPool, user_id: Uuid) -> Result<Vec<Todo>, sqlx::Error> {
    sqlx::query_as!(
        Todo,
        "SELECT id, user_id, task, description, creation_date, completion_date, completed
        FROM todos WHERE user_id = $1",
        user_id
    )
    .fetch_all(pool)
    .await
}
/// Retrieves a single Todo by id, verifying ownership.
///
/// # Security
/// - Both `id` and `user_id` must match; `None` is returned otherwise, so
///   the existence of other users' todos is never revealed.
pub async fn fetch_todo_by_id_from_db(pool: &PgPool, id: Uuid, user_id: Uuid) -> Result<Option<Todo>, sqlx::Error> {
    sqlx::query_as!(
        Todo,
        "SELECT id, user_id, task, description, creation_date, completion_date, completed
        FROM todos WHERE id = $1 AND user_id = $2",
        id,
        user_id
    )
    .fetch_optional(pool)
    .await
}
/// Deletes a Todo by id after confirming ownership. Returns the number of
/// rows deleted (0 when nothing matched).
///
/// # Security
/// - Both `id` and `user_id` must match for the row to be deleted.
/// - Only an affected-row count is returned, never other users' data.
pub async fn delete_todo_from_db(pool: &PgPool, id: Uuid, user_id: Uuid) -> Result<u64, sqlx::Error> {
    let outcome = sqlx::query!(
        "DELETE FROM todos WHERE id = $1 AND user_id = $2",
        id,
        user_id
    )
    .execute(pool)
    .await?;
    Ok(outcome.rows_affected())
}

68
src/database/usage.rs Normal file
View File

@ -0,0 +1,68 @@
use sqlx::postgres::PgPool;
use uuid::Uuid;
/// Records API usage with validation and security protections
///
/// # Validation
/// - Endpoint must be 1-100 characters after trimming
/// - Rejects empty or whitespace-only endpoints
///
/// # Security
/// - Uses parameterized queries to prevent SQL injection
/// - Automatically trims and sanitizes endpoint input
/// - Enforces user ownership through database constraints
// pub async fn insert_usage_into_db(
// pool: &PgPool,
// user_id: Uuid,
// endpoint: String,
// ) -> Result<(), sqlx::Error> {
// // Sanitize and validate endpoint
// let endpoint = endpoint.trim();
// if endpoint.is_empty() {
// return Err(sqlx::Error::Protocol("Endpoint cannot be empty".into()));
// }
// if endpoint.len() > 100 {
// return Err(sqlx::Error::Protocol("Endpoint exceeds maximum length of 100 characters".into()));
// }
// sqlx::query!(
// r#"INSERT INTO usage (endpoint, user_id)
// VALUES ($1, $2)"#,
// endpoint,
// user_id
// )
// .execute(pool)
// .await?;
// Ok(())
// }
/// Returns how many usage rows `user_id` accrued within the given period.
///
/// # Security
/// - Parameterized query with an explicit `CAST($2 AS INTERVAL)` prevents
///   SQL injection.
/// - Results are scoped to `user_id`.
/// - `COALESCE` guarantees a numeric result (0 when there is no usage).
///
/// # Example Interval Formats
/// - '1 hour'
/// - '7 days'
/// - '30 minutes'
pub async fn fetch_usage_count_from_db(
    pool: &PgPool,
    user_id: Uuid,
    interval: &str,
) -> Result<i64, sqlx::Error> {
    sqlx::query_scalar(
        r#"SELECT COALESCE(COUNT(*), 0)
        FROM usage
        WHERE user_id = $1
        AND creation_date > NOW() - CAST($2 AS INTERVAL)"#
    )
    .bind(user_id)
    .bind(interval)
    .fetch_one(pool)
    .await
}

141
src/database/users.rs Normal file
View File

@ -0,0 +1,141 @@
use std::sync::OnceLock;

use regex::Regex;
use sqlx::postgres::PgPool;
use sqlx::Error;
use uuid::Uuid;

use crate::models::user::*;
/// Retrieves all users with security considerations.
///
/// # Security
/// - Admin-only access must be enforced at the application layer; this
///   function performs no authorization itself.
/// - Sensitive columns (`password_hash`, `totp_secret`) are not selected.
/// - The query has no LIMIT; result-size capping (if needed) is also the
///   application layer's responsibility.
pub async fn fetch_all_users_from_db(pool: &PgPool) -> Result<Vec<UserGetResponse>, sqlx::Error> {
    sqlx::query_as!(
        UserGetResponse,
        "SELECT id, username, email, role_level, tier_level, creation_date
        FROM users"
    )
    .fetch_all(pool)
    .await
}
/// Safely retrieves user by allowed fields using whitelist validation
///
/// # Allowed Fields
/// - id (UUID)
/// - email (valid email format)
/// - username (valid username format)
///
/// # Security
/// - Field whitelisting prevents SQL injection
/// - Parameterized query for value
pub async fn fetch_user_by_field_from_db(
pool: &PgPool,
field: &str,
value: &str,
) -> Result<Option<User>, sqlx::Error> {
let query = match field {
"id" => "SELECT * FROM users WHERE id = $1",
"email" => "SELECT * FROM users WHERE email = $1",
"username" => "SELECT * FROM users WHERE username = $1",
_ => return Err(sqlx::Error::ColumnNotFound(field.to_string())),
};
sqlx::query_as::<_, User>(query)
.bind(value)
.fetch_optional(pool)
.await
}
/// Retrieves a user by email.
///
/// # Security
/// - Parameterized query prevents SQL injection.
/// - Returns `Option` (`None` when no row matches) so callers can avoid
///   leaking whether an email exists.
/// - The returned `User` includes `password_hash` and `totp_secret`; callers
///   must never expose those fields in responses.
pub async fn fetch_user_by_email_from_db(
    pool: &PgPool,
    email: &str,
) -> Result<Option<User>, sqlx::Error> {
    sqlx::query_as!(
        User,
        r#"SELECT id, username, email, password_hash, totp_secret,
        role_level, tier_level, creation_date
        FROM users WHERE email = $1"#,
        email
    )
    .fetch_optional(pool)
    .await
}
/// Deletes a user by id. Returns the number of rows deleted (0 when no
/// user matched).
///
/// # Security
/// - Authentication/authorization must be enforced by the caller.
/// - Parameterized query prevents SQL injection.
/// - Only the affected-row count is returned, no sensitive data.
pub async fn delete_user_from_db(pool: &PgPool, id: Uuid) -> Result<u64, sqlx::Error> {
    sqlx::query!("DELETE FROM users WHERE id = $1", id)
        .execute(pool)
        .await
        .map(|res| res.rows_affected())
}
/// Creates a new user with comprehensive validation.
///
/// # Validation
/// - Username: 3-30 characters (counted with `chars().count()`, not bytes,
///   so the limit matches the documented character count even for Unicode
///   alphanumerics, which `is_alphanumeric` accepts), alphanumeric or
///   underscore only
/// - Email: trimmed, lowercased, then checked by `is_valid_email`
/// - Password strength is enforced at the application layer; only the hash
///   is stored here
pub async fn insert_user_into_db(
    pool: &PgPool,
    username: &str,
    email: &str,
    password_hash: &str,
    totp_secret: &str,
    role_level: i32,
    tier_level: i32,
) -> Result<UserInsertResponse, Error> {
    // Validate username (length in characters, not bytes)
    let username = username.trim();
    let username_chars = username.chars().count();
    if username_chars < 3 || username_chars > 30 {
        return Err(Error::Protocol("Username must be between 3 and 30 characters.".into()));
    }
    if !username.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return Err(Error::Protocol("Invalid username format: only alphanumeric and underscores allowed.".into()));
    }
    // Normalize and validate email; lowercased first because the regex in
    // `is_valid_email` only accepts lowercase input
    let email = email.trim().to_lowercase();
    if !is_valid_email(&email) {
        return Err(Error::Protocol("Invalid email format.".into()));
    }
    // Insert user into database (parameterized; creation_date set server-side)
    let row = sqlx::query_as!(
        UserInsertResponse,
        r#"INSERT INTO users
        (username, email, password_hash, totp_secret, role_level, tier_level, creation_date)
        VALUES ($1, $2, $3, $4, $5, $6, NOW()::timestamp)
        RETURNING id, username, email, totp_secret, role_level, tier_level, creation_date"#,
        username,
        email,
        password_hash,
        totp_secret,
        role_level,
        tier_level,
    )
    .fetch_one(pool)
    .await?;
    Ok(row)
}
/// Email validation helper function
fn is_valid_email(email: &str) -> bool {
let email_regex = Regex::new(
r"^[a-z0-9_+]+([a-z0-9_.-]*[a-z0-9_+])?@[a-z0-9]+([-.][a-z0-9]+)*\.[a-z]{2,6}$"
).unwrap();
email_regex.is_match(email)
}

32
src/handlers/README.md Normal file
View File

@ -0,0 +1,32 @@
# Handlers Module for Rust API
This folder contains the route handlers used in the Rust API, responsible for processing incoming HTTP requests and generating responses.
## Overview
The `/src/handlers` folder includes implementations of route handlers for API keys, usage metrics, and the homepage.
### Key Components
- **Axum Handlers:** Built using Axum's handler utilities for routing and extracting request data.
- **SQLx:** Manages database operations like fetching usage and deleting API keys.
- **UUID and Serde:** Handles unique IDs and JSON serialization.
- **Tracing:** Provides structured logging for monitoring and debugging.
## Usage
Handlers are linked to Axum routes using `route` and `handler` methods:
```rust
route("/apikeys/:id", delete(delete_apikey_by_id))
.route("/usage/lastday", get(get_usage_last_day))
```
## Dependencies
- [Axum](https://docs.rs/axum/latest/axum/)
- [SQLx](https://docs.rs/sqlx/latest/sqlx/)
- [UUID](https://docs.rs/uuid/latest/uuid/)
- [Serde](https://docs.rs/serde/latest/serde/)
- [Tracing](https://docs.rs/tracing/latest/tracing/)
## Contributing
Ensure new handlers are well-documented, include proper error handling, and maintain compatibility with existing routes.
## License
This project is licensed under the MIT License.

View File

@ -1,26 +1,32 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
extract::{State, Extension, Path},
Json,
http::StatusCode,
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::user::User;
use crate::database::apikeys::delete_apikey_from_db;
// --- Route Handler ---
// Delete an API key by id
#[utoipa::path(
delete,
path = "/apikeys/{id}",
tag = "apikey",
security(
("jwt_token" = [])
),
params(
("id" = String, Path, description = "API key ID")
),
responses(
(status = 200, description = "API key deleted successfully", body = String),
(status = 400, description = "Invalid UUID format", body = String),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 404, description = "API key not found", body = String),
(status = 500, description = "Internal server error", body = String)
)
@ -30,28 +36,35 @@ pub async fn delete_apikey_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
// Parse the id string to UUID
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({ "error": format!("Invalid UUID format.")}))),
Err(_) => {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": format!("Invalid UUID format.") })),
));
}
};
let result = sqlx::query!("DELETE FROM apikeys WHERE id = $1 AND user_id = $2", uuid, user.id)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if res.rows_affected() == 0 {
(StatusCode::NOT_FOUND, Json(json!({ "error": format!("API key with ID '{}' not found.", id) })))
match delete_apikey_from_db(&pool, uuid, user.id).await {
Ok(rows_affected) => {
if rows_affected == 0 {
Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("API key with ID '{}' not found.", id) })),
))
} else {
(StatusCode::OK, Json(json!({ "success": format!("API key with ID '{}' deleted.", id)})))
Ok((
StatusCode::OK,
Json(json!({ "success": format!("API key with ID '{}' deleted.", id) })),
))
}
}
Err(_err) => (
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Could not delete API key '{}'.", id)}))
),
Json(json!({ "error": format!("Could not delete API key '{}'.", id) }))
)),
}
}
}

View File

@ -1,8 +1,7 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
extract::{State, Extension, Path},
Json,
http::StatusCode,
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
@ -10,15 +9,22 @@ use serde_json::json;
use tracing::instrument; // For logging
use crate::models::user::User;
use crate::models::documentation::{ErrorResponse, SuccessResponse};
use crate::database::todos::delete_todo_from_db;
// --- Route Handler ---
// Delete a todo by id
#[utoipa::path(
delete,
path = "/todos/{id}",
tag = "todo",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "Todo deleted successfully", body = SuccessResponse),
(status = 400, description = "Invalid UUID format", body = ErrorResponse),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 404, description = "Todo not found", body = ErrorResponse),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
@ -32,22 +38,29 @@ pub async fn delete_todo_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })),)),
Err(_) => {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid UUID format." })),
));
}
};
let result = sqlx::query!("DELETE FROM todos WHERE id = $1 AND user_id = $2", uuid, user.id)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if (res.rows_affected() == 0) {
Err((StatusCode::NOT_FOUND, Json(json!({ "error": format!("Todo with ID '{}' not found.", id) })),))
match delete_todo_from_db(&pool, uuid, user.id).await {
Ok(rows_affected) => {
if rows_affected == 0 {
Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("Todo with ID '{}' not found.", id) })),
))
} else {
Ok((StatusCode::OK, Json(json!({ "success": format!("Todo with ID '{}' deleted.", id) })),))
Ok((
StatusCode::OK,
Json(json!({ "success": format!("Todo with ID '{}' deleted.", id) })),
))
}
}
Err(_err) => Err((

View File

@ -1,24 +1,30 @@
use axum::{
extract::{State, Path},
Json,
response::IntoResponse,
http::StatusCode
extract::{State, Path},
Json,
http::StatusCode,
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
use serde_json::json;
use tracing::instrument; // For logging
use crate::models::documentation::{ErrorResponse, SuccessResponse};
use crate::database::users::delete_user_from_db;
// --- Route Handler ---
// Delete a user by id
#[utoipa::path(
delete,
path = "/users/{id}",
tag = "user",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "User deleted successfully", body = SuccessResponse),
(status = 400, description = "Invalid UUID format", body = ErrorResponse),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 404, description = "User not found", body = ErrorResponse),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
@ -30,27 +36,34 @@ use crate::models::documentation::{ErrorResponse, SuccessResponse};
pub async fn delete_user_by_id(
State(pool): State<PgPool>,
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })),)),
Err(_) => {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid UUID format." })),
));
}
};
let result = sqlx::query_as!(User, "DELETE FROM USERS WHERE id = $1", uuid)
.execute(&pool) // Borrow the connection pool
.await;
match result {
Ok(res) => {
if res.rows_affected() == 0 {
Err((StatusCode::NOT_FOUND, Json(json!({ "error": format!("User with ID '{}' not found.", id) })),))
match delete_user_from_db(&pool, uuid).await {
Ok(rows_affected) => {
if rows_affected == 0 {
Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("User with ID '{}' not found.", id) })),
))
} else {
Ok((StatusCode::OK, Json(json!({ "success": format!("User with ID '{}' deleted.", id) })),))
Ok((
StatusCode::OK,
Json(json!({ "success": format!("User with ID '{}' deleted.", id) })),
))
}
}
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not delete the user."})),
Json(json!({ "error": "Could not delete the user." })),
)),
}
}

View File

@ -1,7 +1,6 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
Json,
http::StatusCode
};
use sqlx::postgres::PgPool;
@ -12,14 +11,21 @@ use crate::models::apikey::*;
use crate::models::user::*;
use crate::models::documentation::ErrorResponse;
use crate::models::apikey::ApiKeyResponse;
use crate::database::apikeys::{fetch_all_apikeys_from_db, fetch_apikey_by_id_from_db};
// --- Route Handlers ---
// Get all API keys
#[utoipa::path(
get,
path = "/apikeys",
tag = "apikey",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "Get all API keys", body = [ApiKeyResponse]),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal Server Error", body = ErrorResponse)
),
params(
@ -30,19 +36,12 @@ use crate::models::apikey::ApiKeyResponse;
pub async fn get_all_apikeys(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let apikeys = sqlx::query_as!(ApiKeyResponse,
"SELECT id, user_id, description, expiration_date, creation_date FROM apikeys WHERE user_id = $1",
user.id
)
.fetch_all(&pool) // Borrow the connection pool
.await;
match apikeys {
) -> Result<Json<Vec<ApiKeyResponse>>, (StatusCode, Json<serde_json::Value>)> {
match fetch_all_apikeys_from_db(&pool, user.id).await {
Ok(apikeys) => Ok(Json(apikeys)), // Return all API keys as JSON
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not get the API key."})),
Json(json!({ "error": "Could not get the API keys."})),
)),
}
}
@ -68,22 +67,14 @@ pub async fn get_apikeys_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
) -> Result<Json<ApiKeyByIDResponse>, (StatusCode, Json<serde_json::Value>)> {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
};
let apikeys = sqlx::query_as!(ApiKeyByIDResponse,
"SELECT id, description, expiration_date, creation_date FROM apikeys WHERE id = $1 AND user_id = $2",
uuid,
user.id
)
.fetch_optional(&pool) // Borrow the connection pool
.await;
match apikeys {
Ok(Some(apikeys)) => Ok(Json(apikeys)), // Return the API key as JSON if found
match fetch_apikey_by_id_from_db(&pool, uuid, user.id).await {
Ok(Some(apikey)) => Ok(Json(apikey)), // Return the API key as JSON if found
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("API key with ID '{}' not found.", id) })),
@ -93,4 +84,4 @@ pub async fn get_apikeys_by_id(
Json(json!({ "error": "Could not get the API key."})),
)),
}
}
}

View File

@ -9,43 +9,7 @@ use sysinfo::{System, RefreshKind, Disks};
use tokio::{task, join};
use std::sync::{Arc, Mutex};
use tracing::instrument; // For logging
use utoipa::ToSchema;
use serde::{Deserialize, Serialize};
// Struct definitions
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct HealthResponse {
pub cpu_usage: CpuUsage,
pub database: DatabaseStatus,
pub disk_usage: DiskUsage,
pub memory: MemoryStatus,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CpuUsage {
#[serde(rename = "available_percentage")]
pub available_pct: String,
pub status: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DatabaseStatus {
pub status: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DiskUsage {
pub status: String,
#[serde(rename = "used_percentage")]
pub used_pct: String,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MemoryStatus {
#[serde(rename = "available_mb")]
pub available_mb: i64,
pub status: String,
}
use crate::models::health::HealthResponse;
// Health check endpoint
#[utoipa::path(

View File

@ -1,8 +1,7 @@
use axum::{
extract::{State, Extension, Path},
Json,
response::IntoResponse,
http::StatusCode
extract::{State, Extension, Path},
Json,
http::StatusCode,
};
use sqlx::postgres::PgPool;
use uuid::Uuid;
@ -10,14 +9,21 @@ use serde_json::json;
use tracing::instrument; // For logging
use crate::models::todo::*;
use crate::models::user::*;
use crate::database::todos::{fetch_all_todos_from_db, fetch_todo_by_id_from_db};
// --- Route Handlers ---
// Get all todos
#[utoipa::path(
get,
path = "/todos/all",
tag = "todo",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "Successfully fetched all todos", body = [Todo]),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error")
)
)]
@ -25,16 +31,9 @@ use crate::models::user::*;
pub async fn get_all_todos(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let todos = sqlx::query_as!(Todo,
"SELECT id, user_id, task, description, creation_date, completion_date, completed FROM todos WHERE user_id = $1",
user.id
)
.fetch_all(&pool) // Borrow the connection pool
.await;
match todos {
Ok(todos) => Ok(Json(todos)), // Return all todos as JSON
) -> Result<Json<Vec<Todo>>, (StatusCode, Json<serde_json::Value>)> {
match fetch_all_todos_from_db(&pool, user.id).await {
Ok(todos) => Ok(Json(todos)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the details of the todo." })),
@ -62,22 +61,19 @@ pub async fn get_todos_by_id(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
Path(id): Path<String>, // Use Path extractor here
) -> impl IntoResponse {
) -> Result<Json<Todo>, (StatusCode, Json<serde_json::Value>)> {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
Err(_) => {
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "Invalid UUID format." })),
));
}
};
let todo = sqlx::query_as!(Todo,
"SELECT id, user_id, task, description, creation_date, completion_date, completed FROM todos WHERE id = $1 AND user_id = $2",
uuid,
user.id
)
.fetch_optional(&pool) // Borrow the connection pool
.await;
match todo {
Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found
match fetch_todo_by_id_from_db(&pool, uuid, user.id).await {
Ok(Some(todo)) => Ok(Json(todo)),
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("Todo with ID '{}' not found.", id) })),

62
src/handlers/get_usage.rs Normal file
View File

@ -0,0 +1,62 @@
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use crate::models::user::*;
use crate::models::usage::*;
use crate::database::usage::fetch_usage_count_from_db;
// Get usage for the last 24 hours
#[utoipa::path(
get,
path = "/usage/lastday",
tag = "usage",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "Successfully fetched usage for the last 24 hours", body = UsageResponseLastDay),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_day(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
) -> impl IntoResponse {
match fetch_usage_count_from_db(&pool, user.id, "24 hours").await {
Ok(count) => Ok(Json(json!({ "requests_last_24_hours": count }))),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the usage data." }))
)),
}
}
// Get usage for the last 7 days
#[utoipa::path(
get,
path = "/usage/lastweek",
tag = "usage",
responses(
(status = 200, description = "Successfully fetched usage for the last 7 days", body = UsageResponseLastDay),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_week(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
) -> impl IntoResponse {
match fetch_usage_count_from_db(&pool, user.id, "7 days").await {
Ok(count) => Ok(Json(json!({ "requests_last_7_days": count }))),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the usage data." }))
)),
}
}

74
src/handlers/get_users.rs Normal file
View File

@ -0,0 +1,74 @@
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::Json;
use axum::response::IntoResponse;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use crate::models::user::*;
use crate::database::users::{fetch_all_users_from_db, fetch_user_by_field_from_db};
// List every registered user.
#[utoipa::path(
    get,
    path = "/users/all",
    tag = "user",
    security(
        ("jwt_token" = [])
    ),
    responses(
        (status = 200, description = "Successfully fetched all users", body = [UserGetResponse]),
        (status = 401, description = "Unauthorized", body = serde_json::Value),
        (status = 500, description = "Internal server error")
    )
)]
#[instrument(skip(pool))]
pub async fn get_all_users(State(pool): State<PgPool>) -> impl IntoResponse {
    // Delegate to the database layer and translate failures into a generic 500.
    let lookup = fetch_all_users_from_db(&pool).await;
    match lookup {
        Ok(user_list) => Ok(Json(user_list)),
        Err(_) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Could not fetch the users details." })),
        )),
    }
}
// Get a single user by ID
#[utoipa::path(
get,
path = "/users/{id}",
tag = "user",
params(
("id" = String, Path, description = "User ID")
),
responses(
(status = 200, description = "Successfully fetched user by ID", body = UserGetResponse),
(status = 400, description = "Invalid UUID format"),
(status = 404, description = "User not found"),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_users_by_id(
State(pool): State<PgPool>,
Path(id): Path<String>,
) -> impl IntoResponse {
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
};
match fetch_user_by_field_from_db(&pool, "id", &uuid.to_string()).await {
Ok(Some(user)) => Ok(Json(user)),
Ok(None) => Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": format!("User with ID '{}' not found", id) })),
)),
Err(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the users details." })),
)),
}
}

115
src/handlers/homepage.rs Normal file
View File

@ -0,0 +1,115 @@
use axum::response::{IntoResponse, Html};
/// Homepage route: serves a self-contained HTML landing page that links to
/// the Swagger UI (/swagger), the OpenAPI spec (/openapi.json), the health
/// endpoint (/health), and the GitHub repository. All markup and CSS are
/// inlined in one raw string literal, so no static assets are required.
// NOTE(review): the second <li> below reads "healthcheck too" — this is
// user-visible text and should likely read "healthcheck to"; left unchanged
// here because string literals are runtime output.
pub async fn homepage() -> impl IntoResponse {
Html(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Axium API</title>
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 80 80'><text x='0' y='60' font-size='64'>🦖</text></svg>">
<style>
:root {
--neon-cyan: #00f3ff;
--dark-space: #0a0e14;
--starry-night: #1a1f2c;
}
body {
font-family: 'Arial', sans-serif;
background: linear-gradient(135deg, var(--dark-space) 0%, var(--starry-night) 100%);
color:#ffffff;
margin: 0;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
line-height: 1.6;
}
a {
color:#00ffff;
text-decoration: none;
font-weight: 500;
transition: color 0.3s;
}
a:hover {
color: #40ffa0;
}
.container {
background: rgba(25, 28, 36, 0.9);
backdrop-filter: blur(12px);
border-radius: 16px;
padding: 2.5rem;
max-width: 800px;
margin: 2rem;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.3);
border: 1px solid rgba(255, 255, 255, 0.1);
text-align: center;
}
h1 {
font-size: 1.2em;
white-space: pre;
font-family: monospace;
display: inline-block;
text-align: left;
line-height: normal;
}
ul {
list-style-type: none;
padding: 0;
text-align: left;
display: inline-block;
font-size: 1.1em;
}
li {
margin: 15px 0;
}
.github-link {
margin-top: 25px;
display: inline-flex;
align-items: center;
padding: 12px 25px;
background-color: #00ffff;
color: #0f111a;
border-radius: 8px;
font-weight: bold;
transition: background-color 0.3s;
}
.github-link:hover {
background-color:#40ffa0;
color: #ffffff;
}
.github-link svg {
margin-right: 8px;
position: relative;
}
</style>
</head>
<body>
<div class="container">
<h1>
db 88
d88b ""
d8'`8b
d8' `8b 8b, ,d8 88 88 88 88,dPYba,,adPYba,
d8YaaaaY8b `Y8, ,8P' 88 88 88 88P' "88" "8a
d8""""""""8b )888( 88 88 88 88 88 88
d8' `8b ,d8" "8b, 88 "8a, ,a88 88 88 88
d8' `8b 8P' `Y8 88 `"YbbdP'Y8 88 88 88
</h1>
<ul>
<li>📖 Explore the API using <a href="/swagger">Swagger UI</a> or import the <a href="/openapi.json">OpenAPI spec</a>.</li>
<li>🩺 Ensure your Docker setup is reliable, by pointing its healthcheck too <a href="/health">/health</a>.</li>
</ul>
<a href="https://github.com/Riktastic/Axium" class="github-link" target="_blank">
<svg height="20" aria-hidden="true" viewBox="0 0 16 16" version="1.1" width="20" data-view-component="true" fill="currentColor">
<path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2 0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0 0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16 1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51 1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68 1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"></path>
</svg>
View on GitHub
</a>
</div>
</body>
</html>
"#)
}

View File

@ -1,2 +1,16 @@
// Module declarations for every request handler in this crate.

// Shared helpers (custom validator functions used by request bodies,
// e.g. `validate_future_date`).
pub mod validate;

// DELETE endpoints.
pub mod delete_apikeys;
pub mod delete_todos;
pub mod delete_users;

// GET endpoints (including the HTML landing page and the health probe).
pub mod get_apikeys;
pub mod get_health;
pub mod get_todos;
pub mod get_usage;
pub mod get_users;
pub mod homepage;

// POST endpoints.
pub mod post_apikeys;
pub mod post_todos;
pub mod post_users;

// Auth-related: JWT-protected probe, API-key rotation, and sign-in.
pub mod protected;
pub mod rotate_apikeys;
pub mod signin;

View File

@ -1,54 +1,39 @@
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::{error, info};
use utoipa::ToSchema;
use uuid::Uuid;
use validator::Validate;
use crate::handlers::validate::validate_future_date;
use crate::middlewares::auth::{generate_api_key, hash_password};
use crate::utils::auth::{generate_api_key, hash_password};
use crate::models::user::User;
use crate::database::apikeys::{check_existing_api_key_count, insert_api_key_into_db};
use crate::models::apikey::{ApiKeyInsertBody, ApiKeyInsertResponse};
// Define the request body structure
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyBody {
#[validate(length(min = 0, max = 50))]
pub description: Option<String>,
#[validate(custom(function = "validate_future_date"))]
pub expiration_date: Option<String>,
}
// Define the response body structure
#[derive(Serialize, ToSchema)]
pub struct ApiKeyResponse {
pub id: Uuid,
pub api_key: String,
pub description: String,
pub expiration_date: String,
}
// --- Route Handler ---
// Define the API endpoint
#[utoipa::path(
post,
path = "/apikeys",
tag = "apikey",
request_body = ApiKeyBody,
security(
("jwt_token" = [])
),
request_body = ApiKeyInsertBody,
responses(
(status = 200, description = "API key created successfully", body = ApiKeyResponse),
(status = 200, description = "API key created successfully", body = ApiKeyInsertResponse),
(status = 400, description = "Validation error", body = String),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = String)
)
)]
pub async fn post_apikey(
State(pool): State<PgPool>,
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Json(api_key_request): Json<ApiKeyBody>
) -> impl IntoResponse {
Json(api_key_request): Json<ApiKeyInsertBody>
) -> Result<Json<ApiKeyInsertResponse>, (StatusCode, Json<serde_json::Value>)> {
// Validate input
if let Err(errors) = api_key_request.validate() {
let error_messages: Vec<String> = errors
@ -65,69 +50,49 @@ pub async fn post_apikey(
info!("Received request to create API key for user: {}", user.id);
// Check if the user already has 5 or more API keys
let existing_keys_count = sqlx::query!(
"SELECT COUNT(*) as count FROM apikeys WHERE user_id = $1 AND expiration_date >= CURRENT_DATE",
user.id
)
.fetch_one(&pool)
.await;
match existing_keys_count {
Ok(row) if row.count.unwrap_or(0) >= 5 => {
info!("User {} already has 5 API keys.", user.id);
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "You already have 5 API keys. Please delete an existing key before creating a new one." }))
));
}
Err(_err) => {
error!("Failed to check the amount of API keys for user {}.", user.id);
let existing_keys_count = match check_existing_api_key_count(&pool, user.id).await {
Ok(count) => count,
Err(err) => {
error!("Failed to check the amount of API keys for user {}: {}", user.id, err);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not check the amount of API keys registered." }))
));
}
_ => {} // Proceed if the user has fewer than 5 keys
};
if existing_keys_count >= 5 {
info!("User {} already has 5 API keys.", user.id);
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "You already have 5 API keys. Please delete an existing key before creating a new one." }))
));
}
let current_date = Utc::now().naive_utc();
let description = api_key_request.description
.unwrap_or_else(|| format!("API key created on {}", current_date.format("%Y-%m-%d")));
let expiration_date = api_key_request.expiration_date
.and_then(|date| date.parse::<chrono::NaiveDate>().ok())
.unwrap_or_else(|| (current_date + Duration::days(365 * 2)).date());
let api_key = generate_api_key();
let key_hash = hash_password(&api_key).expect("Failed to hash password.");
let row = sqlx::query!(
"INSERT INTO apikeys (key_hash, description, expiration_date, user_id) VALUES ($1, $2, $3, $4) RETURNING id, key_hash, description, expiration_date, user_id",
key_hash,
description,
expiration_date,
user.id
)
.fetch_one(&pool)
.await;
match row {
Ok(row) => {
match insert_api_key_into_db(&pool, key_hash, description, expiration_date, user.id).await {
Ok(mut api_key_response) => {
info!("Successfully created API key for user: {}", user.id);
Ok(Json(ApiKeyResponse {
id: row.id,
api_key: api_key,
description: description.to_string(),
expiration_date: expiration_date.to_string()
}))
},
// Restore generated api_key to response. It is not stored in database for security reasons.
api_key_response.api_key = api_key;
Ok(Json(api_key_response))
}
Err(err) => {
error!("Error creating API key for user {}: {}", user.id, err);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": format!("Error creating API key: {}.", err) }))
))
},
}
}
}

View File

@ -1,4 +1,4 @@
use axum::{extract::{Extension, State}, Json, response::IntoResponse};
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use serde::Deserialize;
use serde_json::json;
@ -9,6 +9,7 @@ use validator::Validate;
use crate::models::todo::Todo;
use crate::models::user::User;
use crate::database::todos::insert_todo_into_db;
// Define the request body structure
#[derive(Deserialize, Validate, ToSchema)]
@ -19,24 +20,30 @@ pub struct TodoBody {
pub description: Option<String>,
}
// --- Route Handler ---
// Define the API endpoint
#[utoipa::path(
post,
path = "/todos",
tag = "todo",
security(
("jwt_token" = [])
),
request_body = TodoBody,
responses(
(status = 200, description = "Todo created successfully", body = Todo),
(status = 400, description = "Validation error", body = String),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = String)
)
)]
#[instrument(skip(pool, user, todo))]
pub async fn post_todo(
State(pool): State<PgPool>,
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Json(todo): Json<TodoBody>
) -> impl IntoResponse {
) -> Result<Json<Todo>, (StatusCode, Json<serde_json::Value>)> {
// Validate input
if let Err(errors) = todo.validate() {
let error_messages: Vec<String> = errors
@ -50,27 +57,8 @@ pub async fn post_todo(
));
}
let row = sqlx::query!(
"INSERT INTO todos (task, description, user_id)
VALUES ($1, $2, $3)
RETURNING id, task, description, user_id, creation_date, completion_date, completed",
todo.task,
todo.description,
user.id
)
.fetch_one(&pool)
.await;
match row {
Ok(row) => Ok(Json(Todo {
id: row.id,
task: row.task,
description: row.description,
user_id: row.user_id,
creation_date: row.creation_date,
completion_date: row.completion_date,
completed: row.completed,
})),
match insert_todo_into_db(&pool, todo.task, todo.description, user.id).await {
Ok(new_todo) => Ok(Json(new_todo)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not create a new todo." }))

View File

@ -0,0 +1,66 @@
use axum::{extract::State, Json};
use axum::http::StatusCode;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use validator::Validate;
use crate::utils::auth::{hash_password, generate_totp_secret};
use crate::database::users::insert_user_into_db;
use crate::models::user::{UserInsertResponse, UserInsertBody};
// --- Route Handler ---
/// Create a new user account.
///
/// Validates the request body, hashes the password, optionally generates a
/// TOTP secret, and inserts the record through the database layer.
#[utoipa::path(
post,
path = "/users",
tag = "user",
security(
("jwt_token" = [])
),
request_body = UserInsertBody,
responses(
(status = 200, description = "User created successfully", body = UserInsertResponse),
(status = 400, description = "Validation error", body = String),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = String)
)
)]
#[instrument(skip(pool, user))]
pub async fn post_user(
State(pool): State<PgPool>,
Json(user): Json<UserInsertBody>,
) -> Result<Json<UserInsertResponse>, (StatusCode, Json<serde_json::Value>)> {
// Validate input; collect every field error into one comma-separated message.
if let Err(errors) = user.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
// Hash the password before saving it — the plain text is never stored.
let hashed_password = hash_password(&user.password)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Failed to hash password." }))))?;
// Generate a TOTP secret only when the client sent totp == "true".
// NOTE(review): when 2FA is not requested an empty string (not NULL) is
// stored — confirm downstream code treats "" as "no TOTP configured".
let totp_secret = if user.totp.as_deref() == Some("true") {
generate_totp_secret()
} else {
String::new() // empty secret = account without 2FA (see NOTE above)
};
// NOTE(review): role_level and tier_level are hard-coded to 1 — presumably
// the lowest-privilege defaults; confirm against the roles/tiers tables.
match insert_user_into_db(&pool, &user.username, &user.email, &hashed_password, &totp_secret, 1, 1).await {
Ok(new_user) => Ok(Json(new_user)),
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not create the user." }))
)),
}
}

21
src/handlers/protected.rs Normal file
View File

@ -0,0 +1,21 @@
use axum::{Extension, Json, response::IntoResponse};
use crate::models::user::{User, UserGetResponse};
use tracing::instrument;
#[utoipa::path(
get,
path = "/protected",
tag = "protected",
security(
("jwt_token" = [])
),
responses(
(status = 200, description = "Protected endpoint accessed successfully", body = UserGetResponse),
(status = 401, description = "Unauthorized", body = String)
)
)]
#[instrument(skip(user))]
pub async fn protected(Extension(user): Extension<User>) -> impl IntoResponse {
Json(UserGetResponse {id:user.id,username:user.username,email:user.email, role_level: user.role_level, tier_level: user.tier_level, creation_date: user.creation_date
})
}

View File

@ -0,0 +1,128 @@
use axum::{extract::{Extension, Path, State}, Json};
use axum::http::StatusCode;
use chrono::{Duration, NaiveDate, Utc};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use validator::Validate;
use crate::utils::auth::{generate_api_key, hash_password};
use crate::models::user::User;
use crate::database::apikeys::{fetch_existing_apikey, insert_api_key_into_db, disable_apikey_in_db};
use crate::models::apikey::{ApiKeyRotateBody, ApiKeyRotateResponse, ApiKeyRotateResponseInfo};
/// Rotate an API key: mint a replacement key, then disable the old one.
///
/// Order matters — the new key is created FIRST so the user is never left
/// without a working key; if disabling the old key fails, the new key is
/// disabled again as a best-effort rollback.
// NOTE(review): the insert + disable pair is not wrapped in a database
// transaction, so a crash between the two calls can leave both keys active;
// consider a transactional variant in the database layer.
#[utoipa::path(
post,
path = "/apikeys/rotate/{id}",
tag = "apikey",
security(
("jwt_token" = [])
),
request_body = ApiKeyRotateBody,
responses(
(status = 200, description = "API key rotated successfully", body = ApiKeyRotateResponse),
(status = 400, description = "Validation error", body = String),
(status = 404, description = "API key not found", body = String),
(status = 500, description = "Internal server error", body = String)
),
params(
("id" = String, Path, description = "API key identifier")
)
)]
#[instrument(skip(pool, user, apikeyrotatebody))]
pub async fn rotate_apikey(
State(pool): State<PgPool>,
Extension(user): Extension<User>,
Path(id): Path<String>,
Json(apikeyrotatebody): Json<ApiKeyRotateBody>
) -> Result<Json<ApiKeyRotateResponse>, (StatusCode, Json<serde_json::Value>)> {
// Validate input; collect every field error into one comma-separated message.
if let Err(errors) = apikeyrotatebody.validate() {
let error_messages: Vec<String> = errors
.field_errors()
.iter()
.flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
.collect();
return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": error_messages.join(", ") }))
));
}
// Validate UUID format before any database access.
let uuid = match Uuid::parse_str(&id) {
Ok(uuid) => uuid,
Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid API key identifier format" })))),
};
// Verify ownership of the old API key (scoped to user.id, so one user
// cannot rotate another user's key).
let existing_key = fetch_existing_apikey(&pool, user.id, uuid).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Internal server error" })))
})?.ok_or_else(|| (StatusCode::NOT_FOUND, Json(json!({ "error": "API key not found or already disabled" }))))?;
// Validate expiration date format; default is two years from now.
let expiration_date = match &apikeyrotatebody.expiration_date {
Some(date_str) => NaiveDate::parse_from_str(date_str, "%Y-%m-%d")
.map_err(|_| (StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid expiration date format. Use YYYY-MM-DD" }))))?,
None => (Utc::now() + Duration::days(365 * 2)).naive_utc().date(),
};
// Validate expiration date is in the future.
if expiration_date <= Utc::now().naive_utc().date() {
return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Expiration date must be in the future" }))));
}
// Generate the new secure API key; only its hash is persisted.
let api_key = generate_api_key();
let key_hash = hash_password(&api_key).map_err(|e| {
tracing::error!("Hashing error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Internal server error" })))
})?;
// Create new key FIRST (see function doc: avoids a window with zero keys).
let description = apikeyrotatebody.description.unwrap_or_else(||
format!("Rotated from key {} - {}", existing_key.id, Utc::now().format("%Y-%m-%d"))
);
let new_key = insert_api_key_into_db(&pool, key_hash, description, expiration_date, user.id).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Internal server error" })))
})?;
// Attempt to disable old key.
let disable_result = match disable_apikey_in_db(&pool, uuid, user.id).await {
Ok(res) => res,
Err(e) => {
tracing::error!("Database error: {}", e);
// Rollback: disable the newly created key (best-effort; errors ignored).
let _ = disable_apikey_in_db(&pool, new_key.id, user.id).await;
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Internal server error" }))));
}
};
// Verify old key was actually disabled (0 rows = vanished concurrently).
if disable_result == 0 {
// Rollback: disable new key (best-effort; errors ignored).
let _ = disable_apikey_in_db(&pool, new_key.id, user.id).await;
return Err((
StatusCode::NOT_FOUND,
Json(json!({ "error": "Old API key not found or already disabled" }))
));
}
// Build the response; this is the ONLY place the plain-text key is returned.
let rotate_response = ApiKeyRotateResponse {
id: new_key.id,
api_key,
description: new_key.description,
expiration_date: expiration_date,
rotation_info: ApiKeyRotateResponseInfo {
original_key: existing_key.id,
disabled_at: Utc::now().date_naive(),
},
};
Ok(Json(rotate_response))
}

138
src/handlers/signin.rs Normal file
View File

@ -0,0 +1,138 @@
use axum::{
extract::State,
http::StatusCode,
Json,
};
use serde::Deserialize;
use serde_json::json;
use sqlx::PgPool;
use totp_rs::{Algorithm, TOTP};
use tracing::{info, instrument};
use utoipa::ToSchema;
use crate::utils::auth::{encode_jwt, verify_hash};
use crate::database::{apikeys::fetch_active_apikeys_by_user_id_from_db, users::fetch_user_by_email_from_db};
/// Request body for the `/signin` endpoint.
#[derive(Deserialize, ToSchema)]
pub struct SignInData {
/// Registered e-mail address used to look up the account.
pub email: String,
/// Plain-text password — the `signin` handler also accepts an active
/// API key in this field (it is checked against the stored key hashes).
pub password: String,
/// Current TOTP code; required only when the account has a TOTP secret.
pub totp: Option<String>,
}
/// User sign-in endpoint
///
/// This endpoint allows users to sign in using their email, password, and optionally a TOTP code.
/// The password field is also accepted as an API key: it is compared against
/// the user's active API-key hashes in addition to the password hash.
///
/// # Parameters
/// - `State(pool)`: The shared database connection pool.
/// - `Json(user_data)`: The user sign-in data (email, password, and optional TOTP code).
///
/// # Returns
/// - `Ok(Json(serde_json::Value))`: A JSON response containing the JWT token if sign-in is successful.
/// - `Err((StatusCode, Json(serde_json::Value)))`: An error response if sign-in fails.
#[utoipa::path(
post,
path = "/signin",
tag = "auth",
request_body = SignInData,
responses(
(status = 200, description = "Successful sign-in", body = serde_json::Value),
(status = 400, description = "Bad request", body = serde_json::Value),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = serde_json::Value)
)
)]
#[instrument(skip(pool, user_data))]
pub async fn signin(
State(pool): State<PgPool>,
Json(user_data): Json<SignInData>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
// Look the account up by e-mail; unknown account and database error are
// deliberately collapsed into the same 401 message.
// NOTE(review): the early return on an unknown e-mail skips the hash
// verification below, so response time differs between known and unknown
// accounts — consider verifying a dummy hash to resist enumeration.
let user = match fetch_user_by_email_from_db(&pool, &user_data.email).await {
Ok(Some(user)) => user,
Ok(None) | Err(_) => return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Incorrect credentials." }))
)),
};
// Active (non-expired) API-key hashes for this account.
let api_key_hashes = match fetch_active_apikeys_by_user_id_from_db(&pool, user.id).await {
Ok(hashes) => hashes,
Err(_) => return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
)),
};
// Check API key first (async version): the submitted password is tested
// against every active key hash concurrently via join_all.
let api_key_futures = api_key_hashes.iter().map(|api_key| {
let password = user_data.password.clone();
let hash = api_key.key_hash.clone();
async move {
verify_hash(&password, &hash)
.await
.unwrap_or(false)
}
});
let any_api_key_valid = futures::future::join_all(api_key_futures)
.await
.into_iter()
.any(|result| result);
// Check password (async version) against the stored password hash.
let password_valid = verify_hash(&user_data.password, &user.password_hash)
.await
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
// Either a valid API key OR the correct password grants access.
let credentials_valid = any_api_key_valid || password_valid;
if !credentials_valid {
return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Incorrect credentials." }))
));
}
// Check TOTP if it's set up for the user (presence of a stored secret
// makes the code mandatory — missing code is a 400, wrong code a 401).
if let Some(totp_secret) = user.totp_secret {
match user_data.totp {
Some(totp_code) => {
// TOTP parameters: SHA-512, 8 digits, skew 1, 30-second step —
// presumably matching the enrollment side; TODO confirm.
let totp = TOTP::new(
Algorithm::SHA512,
8,
1,
30,
totp_secret.into_bytes(),
).map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
if !totp.check_current(&totp_code).unwrap_or(false) {
return Err((
StatusCode::UNAUTHORIZED,
Json(json!({ "error": "Invalid 2FA code." }))
));
}
},
None => return Err((
StatusCode::BAD_REQUEST,
Json(json!({ "error": "2FA code required for this account." }))
)),
}
}
// Issue the JWT; the e-mail is cloned first because encode_jwt consumes it.
let email = user.email.clone();
let token = encode_jwt(user.email)
.map_err(|_| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Internal server error." }))
))?;
info!("User signed in: {}", email);
Ok(Json(json!({ "token": token })))
}

View File

@ -2,8 +2,7 @@
// Core modules for the configuration, TLS setup, and server creation
mod core;
use core::{config, tls, server};
use core::tls::TlsListener;
use core::{config, server};
// Other modules for database, routes, models, and middlewares
mod database;
@ -11,21 +10,53 @@ mod routes;
mod models;
mod middlewares;
mod handlers;
mod utils;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use axum::serve;
use tokio::signal;
use axum_server::tls_rustls::RustlsConfig;
/// Waits until the process receives a shutdown request.
///
/// Completes on Ctrl+C (SIGINT) on all platforms, and additionally on
/// SIGTERM on Unix. Once a signal arrives it prints a notice so `main`
/// can begin a graceful shutdown.
async fn shutdown_signal() {
    // Future that completes when Ctrl+C is received.
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("❌ Failed to install Ctrl+C handler.");
    };

    // Unix only: future that completes on SIGTERM (the signal sent by
    // e.g. `docker stop` and most process supervisors).
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // Non-Unix targets have no SIGTERM; substitute a future that never
    // resolves so the select below only reacts to Ctrl+C.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Resolve as soon as either signal arrives.
    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    println!("\n⏳ Shutdown signal received, starting graceful shutdown.");
}
/// Prints post-startup hints for the operator: the Swagger UI and OpenAPI
/// spec URLs, the `/health` endpoint to wire into a Docker healthcheck,
/// and how to trigger a graceful shutdown.
///
/// `protocol` is the scheme string chosen by the caller ("http" or "https");
/// `ip` and `port` are the address the server was bound to.
fn display_additional_info(protocol: &str, ip: IpAddr, port: u16) {
    println!("\n📖 Explore the API using Swagger ({0}://{1}:{2}/swagger)\n or import the OpenAPI spec ({0}://{1}:{2}/openapi.json).", protocol, ip, port);
    println!("\n🩺 Ensure your Docker setup is reliable,\n by pointing its healthcheck to {0}://{1}:{2}/health", protocol, ip, port);
    println!("\nPress [CTRL] + [C] to gracefully shutdown.");
}
#[tokio::main]
async fn main() {
dotenvy::dotenv().ok(); // Load environment variables from a .env file
tracing_subscriber::fmt::init(); // Initialize the logging system
// Print a cool startup message with ASCII art and emojis/
println!("{}", r#"
db 88
d88b ""
d8'`8b
@ -35,60 +66,129 @@ async fn main() {
d8' `8b ,d8" "8b, 88 "8a, ,a88 88 88 88
d8' `8b 8P' `Y8 88 `"YbbdP'Y8 88 88 88
Axium - An example API built with Rust, Axum, SQLx, and PostgreSQL
- GitHub: https://github.com/Riktastic/Axium
- GitHub: https://github.com/Riktastic/Axium
- Version: 1.0
"#);
println!("🚀 Starting Axium...");
println!("🦖 Starting Axium...");
// Retrieve server IP and port from the environment, default to 127.0.0.1:3000
let ip: IpAddr = config::get_env_with_default("SERVER_IP", "127.0.0.1")
.parse()
.expect("Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1.");
.expect(" Invalid IP address format.");
let port: u16 = config::get_env_u16("SERVER_PORT", 3000);
let socket_addr = SocketAddr::new(ip, port);
// Create the Axum app instance using the server configuration
let addr = SocketAddr::new(ip, port);
let app = server::create_server().await;
// Check if HTTPS is enabled in the environment configuration
if config::get_env_bool("SERVER_HTTPS_ENABLED", false) {
// If HTTPS is enabled, start the server with secure HTTPS.
let is_https = config::get_env_bool("SERVER_HTTPS_ENABLED", false);
let is_http2 = config::get_env_bool("SERVER_HTTPS_HTTP2_ENABLED", false);
let protocol = if is_https { "https" } else { "http" };
// Bind TCP listener for incoming connections
let tcp_listener = TcpListener::bind(socket_addr)
.await
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
// Load the TLS configuration for secure HTTPS connections
let tls_config = tls::load_tls_config();
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); // Create a TLS acceptor
let listener = TlsListener {
inner: Arc::new(tcp_listener), // Wrap TCP listener in TlsListener
acceptor: acceptor,
if is_https {
// HTTPS
// Ensure that the crypto provider is initialized before using rustls
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap_or_else(|e| {
eprintln!("❌ Crypto provider initialization failed: {:?}", e);
std::process::exit(1);
});
// Get certificate and key file paths from environment variables
let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH");
let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH");
// Set up Rustls config with HTTP/2 support
let (certs, key) = {
// Load certificate chain
let certs = tokio::fs::read(&cert_path)
.await
.unwrap_or_else(|e| {
eprintln!("❌ Failed to read certificate file: {}", e);
std::process::exit(1);
});
// Load private key
let key = tokio::fs::read(&key_path)
.await
.unwrap_or_else(|e| {
eprintln!("❌ Failed to read key file: {}", e);
std::process::exit(1);
});
// Parse certificates and private key
let certs = rustls_pemfile::certs(&mut &*certs)
.collect::<Result<Vec<_>, _>>()
.unwrap_or_else(|e| {
eprintln!("❌ Failed to parse certificates: {}", e);
std::process::exit(1);
});
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &*key)
.collect::<Result<Vec<_>, _>>()
.unwrap_or_else(|e| {
eprintln!("❌ Failed to parse private key: {}", e);
std::process::exit(1);
});
let key = keys.remove(0);
// Wrap the private key in the correct type
let key = rustls::pki_types::PrivateKeyDer::Pkcs8(key);
(certs, key)
};
println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port);
let mut config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.unwrap_or_else(|e| {
eprintln!("❌ Failed to build TLS configuration: {}", e);
std::process::exit(1);
});
// Serve the app using the TLS listener (HTTPS)
serve(listener, app.into_make_service())
.await
.expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?");
if is_http2 {
config.alpn_protocols = vec![b"h2".to_vec()];
}
let rustls_config = RustlsConfig::from_config(Arc::new(config));
println!("🔒 Server started with HTTPS at: {}://{}:{}", protocol, ip, port);
display_additional_info(protocol, ip, port);
// Create the server future but don't await it yet
let server = axum_server::bind_rustls(addr, rustls_config)
.serve(app.into_make_service());
tokio::select! {
result = server => {
if let Err(e) = result {
eprintln!("❌ Server failed to start with HTTPS: {}", e);
}
},
_ = shutdown_signal() => {},
}
} else {
// If HTTPS is not enabled, start the server with non-secure HTTP.
// HTTP
// Bind TCP listener for non-secure HTTP connections
let listener = TcpListener::bind(socket_addr)
.await
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
println!("🔓 Server started with HTTP at: {}://{}:{}", protocol, ip, port);
println!("🔓 Server started with HTTP at: http://{}:{}", ip, port);
display_additional_info(protocol, ip, port);
// Serve the app using the non-secure TCP listener (HTTP)
serve(listener, app.into_make_service())
.await
.expect("❌ Server failed to start without HTTPS.");
// Create the server future but don't await it yet
let server = axum_server::bind(addr)
.serve(app.into_make_service());
tokio::select! {
result = server => {
if let Err(e) = result {
eprintln!("❌ Server failed to start with HTTP: {}", e);
}
},
_ = shutdown_signal() => {},
}
}
}
println!("\n✔️ Server has shut down gracefully.");
}

40
src/middlewares/README.md Normal file
View File

@ -0,0 +1,40 @@
# Middleware
This folder contains middleware functions used in Axium, providing essential utilities like authentication, authorization, and usage tracking.
## Overview
The `/src/middlewares` folder includes middleware implementations for role-based access control (RBAC), JWT authentication, rate limiting, and batched usage tracking.
### Key Components
- **Axum Middleware:** Utilizes Axum's middleware layer for request handling.
- **Moka Cache:** Provides caching for rate limits.
- **SQLx:** Facilitates database interactions.
- **UUID and Chrono:** Handles unique identifiers and timestamps.
## Middleware Files
This folder includes:
- **authorize:** Middleware to enforce role-based access by validating JWT tokens and checking user roles.
- **usage tracking:** Middleware to count and store usage metrics efficiently through batched database writes.
## Usage
To apply middleware, use Axum's `layer` method:
```rust
.route("/path", get(handler).layer(from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
```
## Extending Middleware
Add new middleware by creating async Rust functions that receive the request and Axum's `Next` handle, and forward processing with `next.run(req).await`. Ensure proper logging, error handling, and unit tests.
## Dependencies
- [Axum](https://docs.rs/axum/latest/axum/)
- [SQLx](https://docs.rs/sqlx/latest/sqlx/)
- [Moka Cache](https://docs.rs/moka/latest/moka/)
## Contributing
Ensure new middleware is well-documented, includes error handling, and integrates with the existing architecture.
## License
This project is licensed under the MIT License.

View File

@ -1,5 +1,3 @@
use std::{collections::HashSet, env};
// Standard library imports for working with HTTP, environment variables, and other necessary utilities
use axum::{
body::Body,
@ -9,93 +7,24 @@ use axum::{
middleware::Next, // For adding middleware layers to the request handling pipeline
};
// Importing `State` for sharing application state (such as a database connection) across request handlers
use axum::extract::State;
// Importing necessary libraries for password hashing, JWT handling, and date/time management
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification
Argon2,
};
use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.)
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens
use serde::Deserialize; // For serializing and deserializing JSON data
use serde_json::json; // For constructing JSON data
use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously
use totp_rs::{Algorithm, Secret, TOTP}; // For generating TOTP secrets and tokens
use rand::rngs::OsRng; // For generating random numbers
use sqlx::{PgPool, Postgres, QueryBuilder}; // For interacting with PostgreSQL databases asynchronously
use uuid::Uuid; // For working with UUIDs
use rand::Rng;
use tracing::{info, warn, error, instrument}; // For logging
use tracing::instrument; // For logging
use utoipa::ToSchema; // Import ToSchema for OpenAPI documentation
// New imports for caching and batched writes
use std::sync::Arc;
use std::time::Duration;
use moka::future::Cache;
use tokio::sync::Mutex;
use tokio::time::interval;
use chrono::Utc;
// Importing custom database query functions
use crate::database::{get_users::get_user_by_email, get_apikeys::get_active_apikeys_by_user_id, insert_usage::insert_usage};
use crate::database::users::fetch_user_by_email_from_db;
// Define the structure for JWT claims to be included in the token payload
#[derive(serde::Serialize, serde::Deserialize)]
struct Claims {
sub: String, // Subject (e.g., user ID or email)
iat: usize, // Issued At (timestamp)
exp: usize, // Expiration (timestamp)
iss: String, // Issuer (optional)
aud: String, // Audience (optional)
}
// Custom error type for handling authentication errors
pub struct AuthError {
message: String,
status_code: StatusCode, // HTTP status code to be returned with the error
}
/// Verifies a plaintext `password` against a stored Argon2 PHC-format `hash`.
///
/// # Errors
/// Returns an error only when `hash` is not a parseable PHC string; a
/// non-matching password yields `Ok(false)`, not an error.
#[instrument]
pub fn verify_hash(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
    // Parse the stored PHC string; bail out early if it is malformed.
    let stored = PasswordHash::new(hash)?;
    let verifier = Argon2::default();
    let is_match = verifier
        .verify_password(password.as_bytes(), &stored)
        .is_ok();
    Ok(is_match)
}
/// Hashes `password` with Argon2 using a freshly generated random salt.
///
/// The salt is drawn from the OS CSPRNG per call and embedded in the
/// returned PHC-format string, so no separate salt storage is needed.
///
/// # Errors
/// Propagates any failure from the Argon2 hashing routine.
#[instrument]
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
    // New random salt for every hash.
    let salt = SaltString::generate(&mut OsRng);
    Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map(|hashed| hashed.to_string())
}
// NOTE(review): despite the name, this function returns the value of
// `generate_current()` — a one-time TOTP *code* derived from a freshly
// generated secret — and the secret itself is discarded when the function
// returns. Confirm that callers really expect a code here rather than a
// secret they can persist for future verification.
#[instrument]
pub fn generate_totp_secret() -> String {
    // TOTP parameters: SHA-512, 8-digit codes, skew of 1 time step,
    // 30-second step — matching the parameters used by `sign_in`.
    let totp = TOTP::new(
        Algorithm::SHA512,
        8,
        1,
        30,
        Secret::generate_secret().to_bytes().unwrap(),
    ).expect("Failed to create TOTP.");
    // Code for the current time window.
    let token = totp.generate_current().unwrap();
    token
}
/// Generates a random API key consisting of five dash-separated groups of
/// eight lowercase hex characters (each group encodes 4 random bytes),
/// e.g. `1a2b3c4d-…-…-…-…`.
#[instrument]
pub fn generate_api_key() -> String {
    let mut rng = rand::thread_rng();
    let mut groups: Vec<String> = Vec::with_capacity(5);
    for _ in 0..5 {
        let mut group = String::with_capacity(8);
        for _ in 0..8 {
            // Each random byte becomes two lowercase hex digits.
            group.push_str(&format!("{:02x}", rng.gen::<u8>()));
        }
        groups.push(group);
    }
    groups.join("-")
}
use crate::models::auth::AuthError; // Import the AuthError struct for error handling
use crate::utils::auth::decode_jwt;
// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler
impl IntoResponse for AuthError {
@ -107,75 +36,71 @@ impl IntoResponse for AuthError {
}
}
// Function to encode a JWT token for the given email address
#[instrument]
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
// Load secret key from environment variable for better security
let secret_key = env::var("JWT_SECRET_KEY")
.map_err(|_| {
error!("JWT_SECRET_KEY not set in environment variables");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let now = Utc::now();
let expire = Duration::hours(24);
let exp: usize = (now + expire).timestamp() as usize;
let iat: usize = now.timestamp() as usize;
let claim = Claims {
sub: email.clone(),
iat,
exp,
iss: "your_issuer".to_string(), // Add issuer if needed
aud: "your_audience".to_string(), // Add audience if needed
};
// Use a secure HMAC algorithm (e.g., HS256) for signing the token
encode(
&Header::new(jsonwebtoken::Algorithm::HS256),
&claim,
&EncodingKey::from_secret(secret_key.as_ref()),
)
.map_err(|e| {
error!("Failed to encode JWT: {:?}", e);
StatusCode::INTERNAL_SERVER_ERROR
})
// New struct for caching rate limit data
#[derive(Clone)]
struct CachedRateLimit {
tier_limit: i64,
request_count: i64,
}
// Function to decode a JWT token and extract the claims
#[instrument]
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
// Load secret key from environment variable for better security
let secret_key = env::var("JWT_SECRET_KEY")
.map_err(|_| {
error!("JWT_SECRET_KEY not set in environment variables");
StatusCode::INTERNAL_SERVER_ERROR
})?;
// New struct for batched usage records
#[derive(Clone, Debug)]
struct UsageRecord {
user_id: Uuid,
path: String,
}
// Set up validation rules (e.g., check if token has expired, is from a valid issuer, etc.)
let mut validation = Validation::default();
// Use a HashSet for the audience and issuer validation
let mut audience_set = HashSet::new();
audience_set.insert("your_audience".to_string());
// Global cache and batched writes queue
lazy_static::lazy_static! {
static ref RATE_LIMIT_CACHE: Cache<(Uuid, i32), CachedRateLimit> = Cache::builder()
.time_to_live(Duration::from_secs(300)) // 5 minutes cache lifetime
.build();
static ref USAGE_QUEUE: Arc<Mutex<Vec<UsageRecord>>> = Arc::new(Mutex::new(Vec::new()));
}
let mut issuer_set = HashSet::new();
issuer_set.insert("your_issuer".to_string());
// Function to start the background task for batched writes
pub fn start_batched_writes(pool: PgPool) {
tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(60)); // Run every minute
loop {
interval.tick().await;
flush_usage_queue(&pool).await;
}
});
}
// Set up the validation with the HashSet for audience and issuer
validation.aud = Some(audience_set);
validation.iss = Some(issuer_set);
// Function to flush the usage queue and perform batch inserts
#[instrument(skip(pool))]
async fn flush_usage_queue(pool: &PgPool) {
let mut queue = USAGE_QUEUE.lock().await;
if queue.is_empty() {
return;
}
// Decode the JWT and extract the claims
decode::<Claims>(
&jwt,
&DecodingKey::from_secret(secret_key.as_ref()),
&validation,
)
.map_err(|e| {
warn!("Failed to decode JWT: {:?}", e);
StatusCode::UNAUTHORIZED
})
// Prepare batch insert
let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
"INSERT INTO usage (user_id, path, creation_date) "
);
query_builder.push_values(queue.iter(), |mut b, record| {
b.push_bind(record.user_id)
.push_bind(&record.path)
.push_bind(Utc::now());
});
// Execute batch insert
let result = query_builder.build().execute(pool).await;
match result {
Ok(_) => {
tracing::info!("Successfully inserted {} usage records in batch.", queue.len());
}
Err(e) => {
tracing::error!("Error inserting batch usage records: {}", e);
}
}
// Clear the queue
queue.clear();
}
// Middleware for role-based access control (RBAC)
@ -218,8 +143,12 @@ pub async fn authorize(
};
// Fetch the user from the database using the email from the decoded token
let current_user = match get_user_by_email(&pool, token_data.claims.sub).await {
Ok(user) => user,
let current_user = match fetch_user_by_email_from_db(&pool, &token_data.claims.sub).await {
Ok(Some(user)) => user,
Ok(None) => return Err(AuthError {
message: "User not found.".to_string(),
status_code: StatusCode::UNAUTHORIZED,
}),
Err(_) => return Err(AuthError {
message: "Unauthorized user.".to_string(),
status_code: StatusCode::UNAUTHORIZED,
@ -234,134 +163,41 @@ pub async fn authorize(
});
}
// Check rate limit.
// Check rate limit using cached data
check_rate_limit(&pool, current_user.id, current_user.tier_level).await?;
// Insert the usage record into the database
insert_usage(&pool, current_user.id, req.uri().path().to_string()).await
.map_err(|_| AuthError {
message: "Failed to insert usage record.".to_string(),
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?;
// Queue the usage record for batch insert instead of immediate insertion
USAGE_QUEUE.lock().await.push(UsageRecord {
user_id: current_user.id,
path: req.uri().path().to_string(),
});
// Insert the current user into the request extensions for use in subsequent handlers
req.extensions_mut().insert(current_user);
// Proceed to the next middleware or handler
Ok(next.run(req).await)
}
// Handler for user sign-in (authentication)
#[derive(Deserialize, ToSchema)]
pub struct SignInData {
pub email: String,
pub password: String,
pub totp: Option<String>,
}
/// User sign-in endpoint
///
/// This endpoint allows users to sign in using their email, password, and optionally a TOTP code.
///
/// # Parameters
/// - `State(pool)`: The shared database connection pool.
/// - `Json(user_data)`: The user sign-in data (email, password, and optional TOTP code).
///
/// # Returns
/// - `Ok(Json(serde_json::Value))`: A JSON response containing the JWT token if sign-in is successful.
/// - `Err((StatusCode, Json(serde_json::Value)))`: An error response if sign-in fails.
#[utoipa::path(
post,
path = "/sign_in",
request_body = SignInData,
responses(
(status = 200, description = "Successful sign-in", body = serde_json::Value),
(status = 400, description = "Bad request", body = serde_json::Value),
(status = 401, description = "Unauthorized", body = serde_json::Value),
(status = 500, description = "Internal server error", body = serde_json::Value)
)
)]
// Tracing: skip `pool` and `user_data` so credentials never reach the logs.
#[instrument(skip(pool, user_data))]
pub async fn sign_in(
    State(pool): State<PgPool>,
    Json(user_data): Json<SignInData>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    // Look up the user. An unknown email is reported with the same generic
    // "Incorrect credentials." message as a wrong password.
    let user = match get_user_by_email(&pool, user_data.email).await {
        Ok(user) => user,
        Err(_) => return Err((
            StatusCode::UNAUTHORIZED,
            Json(json!({ "error": "Incorrect credentials." }))
        )),
    };
    // Fetch all active API-key hashes belonging to this user.
    let api_key_hashes = match get_active_apikeys_by_user_id(&pool, user.id).await {
        Ok(hashes) => hashes,
        Err(_) => return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Internal server error." }))
        )),
    };
    // Check API key first, then password
    // NOTE(review): the value sent in the `password` field is also accepted
    // when it matches one of the user's active API-key hashes — i.e. an API
    // key can be used in place of the password here (TOTP, when enabled, is
    // still enforced below). Confirm this dual use is intentional.
    let credentials_valid = api_key_hashes.iter().any(|api_key| {
        verify_hash(&user_data.password, &api_key.key_hash).unwrap_or(false)
    }) || verify_hash(&user_data.password, &user.password_hash)
        .map_err(|_| (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Internal server error." }))
        ))?;
    if !credentials_valid {
        return Err((
            StatusCode::UNAUTHORIZED,
            Json(json!({ "error": "Incorrect credentials." }))
        ));
    }
    // Enforce TOTP only for accounts that have a secret stored.
    if let Some(totp_secret) = user.totp_secret {
        match user_data.totp {
            Some(totp_code) => {
                // Rebuild the TOTP generator with the same parameters used
                // elsewhere in this module (SHA-512, 8 digits, 30 s step).
                let totp = TOTP::new(
                    Algorithm::SHA512,
                    8,
                    1,
                    30,
                    totp_secret.into_bytes(),
                ).map_err(|_| (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(json!({ "error": "Internal server error." }))
                ))?;
                // Validate against the current time window; any internal
                // error from check_current is treated as a failed check.
                if !totp.check_current(&totp_code).unwrap_or(false) {
                    return Err((
                        StatusCode::UNAUTHORIZED,
                        Json(json!({ "error": "Invalid 2FA code." }))
                    ));
                }
            },
            // A secret exists but no code was supplied: ask the client for one.
            None => return Err((
                StatusCode::BAD_REQUEST,
                Json(json!({ "error": "2FA code required for this account." }))
            )),
        }
    }
    // Keep a copy of the email for logging; encode_jwt consumes the original.
    let email = user.email.clone();
    let token = encode_jwt(user.email)
        .map_err(|_| (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Internal server error." }))
        ))?;
    info!("User signed in: {}", email);
    Ok(Json(json!({ "token": token })))
}
#[instrument(skip(pool))]
async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Result<(), AuthError> {
// Get the user's tier requests_per_day
async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Result<(), AuthError> {
// Try to get cached rate limit data
if let Some(cached) = RATE_LIMIT_CACHE.get(&(user_id, tier_level)).await {
if cached.request_count >= cached.tier_limit {
return Err(AuthError {
message: "Rate limit exceeded".to_string(),
status_code: StatusCode::TOO_MANY_REQUESTS,
});
}
// Update cache with incremented request count
RATE_LIMIT_CACHE.insert((user_id, tier_level), CachedRateLimit {
tier_limit: cached.tier_limit,
request_count: cached.request_count + 1,
}).await;
return Ok(());
}
// If not in cache, fetch from database
let tier_limit = sqlx::query!(
"SELECT requests_per_day FROM tiers WHERE level = $1",
tier_level
@ -372,7 +208,7 @@ async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Resu
message: "Failed to fetch tier information".to_string(),
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?
.requests_per_day;
.requests_per_day as i64;
// Count user's requests for today
let request_count = sqlx::query!(
@ -386,9 +222,15 @@ async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Resu
status_code: StatusCode::INTERNAL_SERVER_ERROR,
})?
.count
.unwrap_or(0); // Use 0 if count is NULL
.unwrap_or(0) as i64; // Use 0 if count is NULL
if request_count >= tier_limit as i64 {
// Cache the result
RATE_LIMIT_CACHE.insert((user_id, tier_level), CachedRateLimit {
tier_limit,
request_count,
}).await;
if request_count >= tier_limit {
return Err(AuthError {
message: "Rate limit exceeded".to_string(),
status_code: StatusCode::TOO_MANY_REQUESTS,
@ -396,4 +238,4 @@ async fn check_rate_limit(pool: &PgPool, user_id: Uuid, tier_level: i32) -> Resu
}
Ok(())
}
}

31
src/models/README.md Normal file
View File

@ -0,0 +1,31 @@
# Models
This folder contains data models used in Axium, primarily defined as Rust structs. These models are essential for data serialization, deserialization, and validation within the application.
## Overview
The `/src/models` folder contains various struct definitions that represent key data structures, such as JWT claims and custom error types.
### Key Components
- **Serde:** Provides serialization and deserialization capabilities.
- **Utoipa:** Facilitates API documentation through the `ToSchema` derive macro.
- **Axum StatusCode:** Used for HTTP status management within custom error types.
## Usage
Import and utilize these models across your API routes, handlers, and services. For example:
```rust
use crate::models::auth::Claims;
use crate::models::auth::AuthError;
```
## Extending Models
You can extend the existing models by adding more fields, or create new models as needed for additional functionality. Ensure that any new models are properly documented and derive necessary traits.
## Dependencies
- [Serde](https://docs.rs/serde/latest/serde/)
- [Utoipa](https://docs.rs/utoipa/latest/utoipa/)
- [Axum](https://docs.rs/axum/latest/axum/)
## Contributing
When adding new models, ensure they are well-documented, derive necessary traits, and integrate seamlessly with the existing codebase.
## License
This project is licensed under the MIT License.

View File

@ -3,66 +3,133 @@ use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use utoipa::ToSchema;
use validator::Validate;
use crate::utils::validate::validate_future_date;
/// Represents an API key record in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Map Rust field names to snake_case SQL columns.
pub struct ApiKey {
    /// Unique id of the API key.
    pub id: Uuid,
    /// Hashed value of the API key.
    pub key_hash: String,
    /// Id of the user who owns the API key.
    pub user_id: Uuid,
    /// Optional description/name of the API key.
    pub description: Option<String>,
    /// Optional expiration date of the API key.
    pub expiration_date: Option<NaiveDate>,
    /// Creation date of the API key (DB default: current date, per the schema).
    pub creation_date: NaiveDate,
    /// Whether the API key is disabled (DB default: false, per the schema).
    pub disabled: bool,
    /// Whether the API key has read access (DB default: true, per the schema).
    pub access_read: bool,
    /// Whether the API key has modify access (DB default: false, per the schema).
    pub access_modify: bool,
}
/// Request body for creating a new API key.
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyInsertBody {
/// Optional description of the API key (max 50 characters).
#[validate(length(min = 0, max = 50))]
pub description: Option<String>,
/// Optional expiration date of the API key (must be in the future).
#[validate(custom(function = "validate_future_date"))]
pub expiration_date: Option<String>,
}
/// Response body for creating a new API key.
#[derive(Serialize, ToSchema)]
pub struct ApiKeyInsertResponse {
/// The unique id of the created API key.
pub id: Uuid,
/// The actual API key value.
pub api_key: String,
/// The description of the API key.
pub description: String,
/// The expiration date of the API key.
pub expiration_date: String,
}
/// Response body for retrieving an API key.
#[derive(Serialize, ToSchema)]
pub struct ApiKeyResponse {
/// The unique id of the API key.
pub id: Uuid,
/// The id of the user who owns the API key.
pub user_id: Uuid,
/// The description of the API key.
pub description: Option<String>,
/// The expiration date of the API key.
pub expiration_date: Option<NaiveDate>,
/// The creation date of the API key.
pub creation_date: NaiveDate,
}
/// Response body for retrieving an API key by its ID.
#[derive(Serialize, ToSchema)]
pub struct ApiKeyByIDResponse {
/// The unique id of the API key.
pub id: Uuid,
/// The description of the API key.
pub description: Option<String>,
/// The expiration date of the API key.
pub expiration_date: Option<NaiveDate>,
/// The creation date of the API key.
pub creation_date: NaiveDate,
}
/// Response body for retrieving active API keys for a user.
#[derive(Serialize, ToSchema)]
pub struct ApiKeyGetActiveForUserResponse {
/// The unique id of the API key.
pub id: Uuid,
/// The description of the API key.
pub description: Option<String>,
}
/// Response body for retrieving API keys by user ID.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ApiKeyByUserIDResponse {
/// The unique id of the API key.
pub id: Uuid,
/// The hashed value of the API key.
pub key_hash: String,
/// The expiration date of the API key.
pub expiration_date: Option<NaiveDate>,
}
/// Request body for creating a new API key (deprecated).
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyNewBody {
/// The description of the API key.
pub description: Option<String>,
/// The expiration date of the API key.
pub expiration_date: Option<NaiveDate>
}
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyResponse {
#[derive(Serialize, ToSchema)]
pub struct ApiKeyRotateResponse {
pub id: Uuid,
pub user_id: Uuid,
pub description: Option<String>,
pub expiration_date: Option<NaiveDate>,
pub creation_date: NaiveDate,
pub api_key: String,
pub description: String,
pub expiration_date: NaiveDate,
pub rotation_info: ApiKeyRotateResponseInfo,
}
#[derive(serde::Serialize, ToSchema)]
pub struct ApiKeyByIDResponse {
pub id: Uuid,
pub description: Option<String>,
pub expiration_date: Option<NaiveDate>,
pub creation_date: NaiveDate,
#[derive(Serialize, ToSchema)]
pub struct ApiKeyRotateResponseInfo {
pub original_key: Uuid,
pub disabled_at: NaiveDate,
}
#[derive(sqlx::FromRow, ToSchema)]
pub struct ApiKeyByUserIDResponse {
pub id: Uuid,
pub key_hash: String,
pub expiration_date: Option<NaiveDate>,
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyRotateBody {
#[validate(length(min = 1, max = 255))]
pub description: Option<String>,
pub expiration_date: Option<String>,
}

41
src/models/auth.rs Normal file
View File

@ -0,0 +1,41 @@
use axum::http::StatusCode;
use utoipa::ToSchema;
use serde::{Serialize, Deserialize};
/// Represents the claims to be included in a JWT payload.
///
/// NOTE(review): `iss` and `aud` are plain `String`s, so they are always
/// present in the payload even though they are described as optional —
/// confirm whether they should be `Option<String>`.
#[derive(Serialize, Deserialize, ToSchema)]
pub struct Claims {
    /// Subject of the token (e.g., user ID or email).
    pub sub: String,
    /// Timestamp when the token was issued (Unix timestamp, seconds).
    pub iat: usize,
    /// Timestamp when the token will expire (Unix timestamp, seconds).
    pub exp: usize,
    /// Issuer of the token (optional).
    pub iss: String,
    /// Intended audience for the token (optional).
    pub aud: String,
}
/// Custom error type for handling authentication-related errors.
///
/// Converted into an HTTP response through its `IntoResponse`
/// implementation in the middleware layer.
pub struct AuthError {
    /// Descriptive error message.
    pub message: String,
    /// HTTP status code to be returned with the error.
    pub status_code: StatusCode,
}
// Implement Display trait for AuthError if needed
// impl std::fmt::Display for AuthError {
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// write!(f, "{}", self.message)
// }
// }
// Implement Error trait for AuthError if needed
// impl std::error::Error for AuthError {}

View File

@ -1,11 +1,16 @@
use utoipa::ToSchema;
use serde::Serialize;
#[derive(ToSchema)]
/// Represents a successful response from the API.
#[derive(Serialize, ToSchema)]
pub struct SuccessResponse {
message: String,
/// A message describing the successful operation.
pub message: String,
}
#[derive(ToSchema)]
/// Represents an error response from the API.
#[derive(Serialize, ToSchema)]
pub struct ErrorResponse {
error: String,
/// A description of the error that occurred.
pub error: String,
}

52
src/models/health.rs Normal file
View File

@ -0,0 +1,52 @@
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
/// Represents the overall health status of the system.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct HealthResponse {
/// CPU usage information.
pub cpu_usage: CpuUsage,
/// Database status information.
pub database: DatabaseStatus,
/// Disk usage information.
pub disk_usage: DiskUsage,
/// Memory status information.
pub memory: MemoryStatus,
}
/// Represents CPU usage information.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct CpuUsage {
/// Percentage of CPU available, represented as a string.
#[serde(rename = "available_percentage")]
pub available_pct: String,
/// Status of the CPU (e.g., "OK", "Warning", "Critical").
pub status: String,
}
/// Represents database status information.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DatabaseStatus {
/// Status of the database (e.g., "Connected", "Disconnected").
pub status: String,
}
/// Represents disk usage information.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct DiskUsage {
/// Status of the disk (e.g., "OK", "Warning", "Critical").
pub status: String,
/// Percentage of disk space used, represented as a string.
#[serde(rename = "used_percentage")]
pub used_pct: String,
}
/// Represents memory status information.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct MemoryStatus {
/// Amount of available memory in megabytes.
#[serde(rename = "available_mb")]
pub available_mb: i64,
/// Status of the memory (e.g., "OK", "Warning", "Critical").
pub status: String,
}

View File

@ -7,4 +7,10 @@ pub mod role;
/// Module for to-do related models.
pub mod todo;
/// Module for documentation related models.
pub mod documentation;
pub mod documentation;
/// Module for authentication related models.
pub mod auth;
/// Module for the health endpoint related models.
pub mod health;
/// Module for the usage endpoint related models.
pub mod usage;

18
src/models/usage.rs Normal file
View File

@ -0,0 +1,18 @@
use serde::Serialize;
use utoipa::ToSchema;
/// Represents the usage statistics for the last 24 hours.
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastDay {
    /// The number of requests made in the last 24 hours.
    /// Serialized as "requests_last_24_hours".
    #[serde(rename = "requests_last_24_hours")]
    pub count: i64
}
/// Represents the usage statistics for the last 7 days.
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastWeek {
    /// The number of requests made in the last 7 days.
    /// Serialized as "requests_last_7_days".
    #[serde(rename = "requests_last_7_days")]
    pub count: i64
}

View File

@ -1,59 +1,83 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
use chrono::NaiveDate;
use chrono::{NaiveDate, NaiveDateTime};
use utoipa::ToSchema;
use validator::Validate;
use crate::utils::validate::{validate_password, validate_username};
/// Represents a user in the system.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
#[sqlx(rename_all = "snake_case")]
pub struct User {
/// The unique identifier for the user.
pub id: Uuid,
/// The username of the user.
pub username: String,
/// The email of the user.
pub email: String,
/// The hashed password for the user.
pub password_hash: String,
/// The TOTP secret for the user.
pub totp_secret: Option<String>,
/// Current role of the user.
pub role_level: i32,
/// Current tier level of the user.
pub tier_level: i32,
/// Date when the user was created.
pub creation_date: Option<NaiveDate>, // Nullable, default value in SQL is CURRENT_DATE
pub creation_date: Option<NaiveDate>,
}
/// Represents a user in the system.
/// Represents a user response for GET requests.
#[derive(Deserialize, Debug, Serialize, FromRow, Clone, ToSchema)]
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
pub struct UserResponse {
#[sqlx(rename_all = "snake_case")]
pub struct UserGetResponse {
/// The unique identifier for the user.
pub id: Uuid,
/// The username of the user.
pub username: String,
/// The email of the user.
pub email: String,
/// Current role of the user.
pub role_level: i32,
/// Current tier level of the user.
pub tier_level: i32,
/// Date when the user was created.
pub creation_date: Option<NaiveDate>, // Nullable, default value in SQL is CURRENT_DATE
pub creation_date: Option<NaiveDate>,
}
/// Request body for inserting a new user.
#[derive(Deserialize, Validate, ToSchema)]
pub struct UserInsertBody {
    /// The username of the new user (3-50 characters, checked by `validate_username`).
    #[validate(length(min = 3, max = 50), custom(function = "validate_username"))]
    pub username: String,
    /// The email of the new user.
    #[validate(email)]
    pub email: String,
    /// The password for the new user (checked by `validate_password`).
    #[validate(custom(function = "validate_password"))]
    pub password: String,
    /// Optional TOTP secret for the new user.
    // NOTE(review): elsewhere this appears to be compared against the literal
    // string "true" as an opt-in flag — a bool would be clearer; confirm handler.
    pub totp: Option<String>,
}
/// Response body for a successful user insertion.
#[derive(Serialize, ToSchema)]
pub struct UserInsertResponse {
    /// The unique identifier for the newly created user.
    pub id: Uuid,
    /// The username of the newly created user.
    pub username: String,
    /// The email of the newly created user.
    pub email: String,
    /// The TOTP secret for the newly created user, if provided.
    // Returned once at creation time so the client can enroll an authenticator.
    pub totp_secret: Option<String>,
    /// The role level assigned to the newly created user.
    pub role_level: i32,
    /// The tier level assigned to the newly created user.
    pub tier_level: i32,
    /// The creation date and time of the newly created user.
    pub creation_date: NaiveDateTime,
}

43
src/routes/README.md Normal file
View File

@ -0,0 +1,43 @@
# Routes
This folder contains the route definitions for Axium, built using [Axum](https://docs.rs/axum/latest/axum/) and [SQLx](https://docs.rs/sqlx/latest/sqlx/).
## Overview
The `/src/routes` folder manages the routing for various API endpoints, handling operations such as CRUD functionality, usage statistics, and more. Each route is associated with its handler and protected by an authorization middleware.
### Key Components
- **Axum Router:** Sets up API routes and manages HTTP requests, see mod.rs.
- **SQLx PgPool:** Provides database connection pooling.
- **Authorization Middleware:** Ensures secure access based on user roles.
## Middleware
The `authorize` middleware is defined in `src/middlewares/auth.rs`. It takes the request, a next handler, and a vector of allowed roles. It verifies that the user has one of the required roles before forwarding the request. Usage example:
```rust
.route("/path", get(handler).layer(from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
```
Ensure that the `authorize` function is imported and applied to each route that requires restricted access; it verifies that the user holds one of the allowed roles before the request reaches the handler.
## Handlers
Each route delegates its logic to handler functions found in the `src/handlers` folder, ensuring separation of concerns.
## Usage
Integrate these routes into your main application router by nesting them appropriately:
```rust
let app = Router::new()
.nest("/todos", create_todo_routes())
.nest("/usage", create_usage_routes());
```
## Dependencies
- [Axum](https://docs.rs/axum/latest/axum/)
- [SQLx](https://docs.rs/sqlx/latest/sqlx/)
## Contributing
Add new route files, update existing routes, or enhance the middleware and handlers. Document any changes for clarity.
## License
This project is licensed under the MIT License.

29
src/routes/apikey.rs Normal file
View File

@ -0,0 +1,29 @@
use axum::{
Router,
routing::{get, post, delete},
middleware::from_fn,
};
use sqlx::PgPool;
use crate::middlewares::auth::authorize;
use crate::handlers::{get_apikeys::{get_all_apikeys, get_apikeys_by_id}, post_apikeys::post_apikey, rotate_apikeys::rotate_apikey, delete_apikeys::delete_apikey_by_id};
/// Builds the API-key router: list, create, fetch, delete and rotate keys.
///
/// Every endpoint is wrapped in the `authorize` middleware, so only users
/// whose role is 1 or 2 reach the handlers.
pub fn create_apikey_routes() -> Router<PgPool> {
    Router::new()
        .route(
            "/all",
            get(get_all_apikeys).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/new",
            post(post_apikey).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/{id}",
            get(get_apikeys_by_id).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/{id}",
            delete(delete_apikey_by_id).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/rotate/{id}",
            post(rotate_apikey).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
}

19
src/routes/auth.rs Normal file
View File

@ -0,0 +1,19 @@
use axum::{
Router,
routing::post,
routing::get,
};
use sqlx::PgPool;
use crate::handlers::{signin::signin, protected::protected};
use crate::middlewares::auth::authorize;
use axum::middleware::from_fn;
/// Builds the authentication router.
///
/// `/signin` is public; `/protected` is a smoke-test endpoint restricted to
/// roles 1 and 2 via the `authorize` middleware.
pub fn create_auth_routes() -> Router<PgPool> {
    Router::new()
        .route("/signin", post(signin))
        .route(
            "/protected",
            get(protected).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
}

View File

@ -1,84 +0,0 @@
use axum::{extract::{Extension, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::Serialize;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use utoipa::ToSchema;
use crate::models::user::*;
/// Response schema for the last-24-hours usage endpoint.
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastDay {
    /// Number of requests in the window; serialized as "requests_last_24_hours".
    #[serde(rename = "requests_last_24_hours")]
    pub count: i64 // SQL count(*) comes back as i64
}
/// Returns how many requests the authenticated user made in the last 24 hours.
///
/// Responds with `{"requests_last_24_hours": <count>}` or a 500 JSON error.
#[utoipa::path(
    get,
    path = "/usage/lastday",
    tag = "usage",
    responses(
        (status = 200, description = "Successfully fetched usage for the last 24 hours", body = UsageResponseLastDay),
        (status = 500, description = "Internal server error")
    )
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_day(
    State(pool): State<PgPool>,
    Extension(user): Extension<User>, // Current user injected by the auth middleware
) -> impl IntoResponse {
    // Count this user's usage rows in the trailing 24-hour window.
    let result = sqlx::query!("SELECT count(*) FROM usage WHERE user_id = $1 AND creation_date > NOW() - INTERVAL '24 hours';", user.id)
        .fetch_one(&pool)
        .await;
    match result {
        Ok(row) => {
            // count(*) is nullable through sqlx's macro typing, hence unwrap_or(0).
            let count = row.count.unwrap_or(0) as i64;
            Ok(Json(json!({ "requests_last_24_hours": count })))
        },
        Err(_err) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Could not fetch the usage data." }))),
        ),
    }
}
/// Response schema for the last-7-days usage endpoint.
#[derive(Debug, Serialize, ToSchema)]
pub struct UsageResponseLastWeek {
    /// Number of requests in the window; serialized as "requests_last_7_days".
    #[serde(rename = "requests_last_7_days")]
    pub count: i64 // SQL count(*) comes back as i64
}
// Get usage for the last 7 days
#[utoipa::path(
get,
path = "/usage/lastweek",
tag = "usage",
responses(
(status = 200, description = "Successfully fetched usage for the last 7 days", body = UsageResponseLastDay),
(status = 500, description = "Internal server error")
)
)]
#[instrument(skip(pool))]
pub async fn get_usage_last_week(
State(pool): State<PgPool>,
Extension(user): Extension<User>, // Extract current user from the request extensions
) -> impl IntoResponse {
let result = sqlx::query!("SELECT count(*) FROM usage WHERE user_id = $1 AND creation_date > NOW() - INTERVAL '7 days';", user.id)
.fetch_one(&pool) // Borrow the connection pool
.await;
match result {
Ok(row) => {
let count = row.count.unwrap_or(0) as i64;
Ok(Json(json!({ "requests_last_7_days": count })))
},
Err(_err) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Could not fetch the usage data." }))),
),
}
}

View File

@ -1,121 +0,0 @@
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::Json;
use axum::response::IntoResponse;
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use crate::models::user::*;
/// Returns every user as JSON, using the credential-free response projection.
#[utoipa::path(
    get,
    path = "/users/all",
    tag = "user",
    responses(
        (status = 200, description = "Successfully fetched all users", body = [UserResponse]),
        (status = 500, description = "Internal server error")
    )
)]
#[instrument(skip(pool))]
pub async fn get_all_users(State(pool): State<PgPool>) -> impl IntoResponse {
    // Select only non-sensitive columns; password_hash/totp_secret stay out.
    let users = sqlx::query_as!(UserResponse, "SELECT id, username, email, role_level, tier_level, creation_date FROM users")
        .fetch_all(&pool)
        .await;
    match users {
        Ok(users) => Ok(Json(users)),
        Err(_err) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Could not fetch the users details." })),
        )),
    }
}
/// Returns a single user (credential-free projection) looked up by UUID.
#[utoipa::path(
    get,
    path = "/users/{id}",
    tag = "user",
    params(
        ("id" = String, Path, description = "User ID")
    ),
    responses(
        (status = 200, description = "Successfully fetched user by ID", body = UserResponse),
        (status = 400, description = "Invalid UUID format"),
        (status = 404, description = "User not found"),
        (status = 500, description = "Internal server error")
    )
)]
#[instrument(skip(pool))]
pub async fn get_users_by_id(
    State(pool): State<PgPool>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    // Reject malformed identifiers before touching the database.
    let uuid = match Uuid::parse_str(&id) {
        Ok(uuid) => uuid,
        Err(_) => return Err((StatusCode::BAD_REQUEST, Json(json!({ "error": "Invalid UUID format." })))),
    };
    // Select only non-sensitive columns.
    let user = sqlx::query_as!(UserResponse, "SELECT id, username, email, role_level, tier_level, creation_date FROM users WHERE id = $1", uuid)
        .fetch_optional(&pool)
        .await;
    match user {
        Ok(Some(user)) => Ok(Json(user)),
        Ok(None) => Err((
            StatusCode::NOT_FOUND,
            Json(json!({ "error": format!("User with ID '{}' not found", id) })),
        )),
        Err(_err) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Could not fetch the users details." })),
        )),
    }
}
// Get a single user by username
// pub async fn get_user_by_username(
// State(pool): State<PgPool>,
// Path(username): Path<String>,
// ) -> impl IntoResponse {
// let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date FROM users WHERE username = $1", username)
// .fetch_optional(&pool)
// .await;
// match user {
// Ok(Some(user)) => Ok(Json(user)),
// Ok(None) => Err((
// StatusCode::NOT_FOUND,
// Json(json!({ "error": format!("User with username '{}' not found", username) })),
// )),
// Err(err) => Err((
// StatusCode::INTERNAL_SERVER_ERROR,
// Json(json!({ "error": "Could not fetch the users details." })),
// )),
// }
// }
// Get a single user by email
// pub async fn get_user_by_email(
// State(pool): State<PgPool>,
// Path(email): Path<String>,
// ) -> impl IntoResponse {
// let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_level, tier_level, creation_date FROM users WHERE email = $1", email)
// .fetch_optional(&pool)
// .await;
// match user {
// Ok(Some(user)) => Ok(Json(user)),
// Ok(None) => Err((
// StatusCode::NOT_FOUND,
// Json(json!({ "error": format!("User with email '{}' not found", email) })),
// )),
// Err(err) => Err((
// StatusCode::INTERNAL_SERVER_ERROR,
// Json(json!({ "error": "Could not fetch the users details." })),
// )),
// }
// }

12
src/routes/health.rs Normal file
View File

@ -0,0 +1,12 @@
use axum::{
Router,
routing::get,
};
use sqlx::PgPool;
use crate::handlers::get_health::get_health;
/// Builds the health-check router, exposing `GET /health` without authentication
/// (so container orchestrators can probe it).
pub fn create_health_route() -> Router<PgPool> {
    Router::new()
        .route("/health", get(get_health))
}

View File

@ -1,80 +1,12 @@
use axum::response::{IntoResponse, Html};
use tracing::instrument; // For logging
use axum::{
Router,
routing::get,
};
use sqlx::PgPool;
#[instrument]
pub async fn homepage() -> impl IntoResponse {
Html(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Welcome to Axium!</title>
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #1e1e2e;
color: #ffffff;
text-align: center;
padding: 40px;
}
a {
color: #00bcd4;
text-decoration: none;
font-weight: bold;
}
a:hover {
text-decoration: underline;
}
.container {
max-width: 800px;
margin: auto;
padding: 20px;
background: #282a36;
border-radius: 8px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
text-align: center;
}
h1 {
font-size: 1.2em;
white-space: pre;
font-family: monospace;
}
.motto {
margin-top: 10px;
font-size: 1em;
font-style: italic;
}
ul {
list-style-type: none;
padding: 0;
text-align: left;
display: inline-block;
}
li {
margin: 10px 0;
}
</style>
</head>
<body>
<div class="container">
<h1>
db 88
d88b ""
d8'`8b
d8' `8b 8b, ,d8 88 88 88 88,dPYba,,adPYba,
d8YaaaaY8b `Y8, ,8P' 88 88 88 88P' "88" "8a
d8""""""""8b )888( 88 88 88 88 88 88
d8' `8b ,d8" "8b, 88 "8a, ,a88 88 88 88
d8' `8b 8P' `Y8 88 `"YbbdP'Y8 88 88 88
</h1>
<p class="motto">An example API built with Rust, Axum, SQLx, and PostgreSQL</p>
<ul>
<li>🚀 Check out all endpoints by visiting <a href="/swagger-ui">Swagger</a>, or import the <a href="/openapi.json">OpenAPI</a> file.</li>
<li> Do not forget to update your Docker Compose configuration with a health check. Just point it to: <a href="/health">/health</a></li>
</ul>
</div>
</body>
</html>
"#)
use crate::handlers::homepage::homepage;
/// Builds the root router, serving the public landing page at `GET /`.
pub fn create_homepage_route() -> Router<PgPool> {
    Router::new()
        .route("/", get(homepage))
}

View File

@ -1,45 +1,52 @@
// Module declarations for different route handlers
pub mod homepage;
pub mod get_todos;
pub mod get_users;
pub mod get_apikeys;
pub mod get_usage;
pub mod post_todos;
pub mod post_users;
pub mod post_apikeys;
pub mod rotate_apikeys;
pub mod get_health;
pub mod delete_users;
pub mod delete_todos;
pub mod delete_apikeys;
pub mod protected;
pub mod apikey;
pub mod auth;
pub mod health;
pub mod todo;
pub mod usage;
pub mod user;
// Re-exporting modules to make their contents available at this level
pub use homepage::*;
pub use get_todos::*;
pub use get_users::*;
pub use get_apikeys::*;
pub use get_usage::*;
pub use rotate_apikeys::*;
pub use post_todos::*;
pub use post_users::*;
pub use post_apikeys::*;
pub use get_health::*;
pub use delete_users::*;
pub use delete_todos::*;
pub use delete_apikeys::*;
pub use protected::*;
use axum::{
Router,
routing::{get, post, delete}
};
use axum::Router;
use sqlx::PgPool;
use tower_http::trace::TraceLayer;
use utoipa::OpenApi;
use utoipa::openapi::security::{SecurityScheme, HttpBuilder, HttpAuthScheme};
use utoipa::{Modify, OpenApi};
use utoipa_swagger_ui::SwaggerUi;
use crate::middlewares::auth::{sign_in, authorize};
pub mod handlers {
pub use crate::handlers::*;
}
pub mod models {
pub use crate::models::*;
}
use self::{
todo::create_todo_routes,
user::create_user_routes,
apikey::create_apikey_routes,
usage::create_usage_routes,
auth::create_auth_routes,
homepage::create_homepage_route,
health::create_health_route,
};
/// OpenAPI modifier that registers the "jwt_token" bearer security scheme,
/// so protected endpoints can reference it in the generated documentation.
struct SecurityAddon;

impl Modify for SecurityAddon {
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        // Components may be absent on a freshly derived spec; create them lazily.
        let components = openapi.components.get_or_insert_with(Default::default);
        components.add_security_scheme(
            "jwt_token",
            SecurityScheme::Http(
                HttpBuilder::new()
                    .scheme(HttpAuthScheme::Bearer)
                    .bearer_format("JWT")
                    .description(Some("Use JWT token obtained from /signin endpoint"))
                    .build()
            )
        );
    }
}
// Define the OpenAPI documentation structure
#[derive(OpenApi)]
@ -57,33 +64,54 @@ use crate::middlewares::auth::{sign_in, authorize};
)
),
paths(
get_all_users,
get_users_by_id,
get_all_apikeys,
get_apikeys_by_id,
get_usage_last_day,
get_usage_last_week,
get_all_todos,
get_todos_by_id,
get_health,
post_user,
post_apikey,
post_todo,
rotate_apikey,
delete_user_by_id,
delete_apikey_by_id,
delete_todo_by_id,
protected,
//sign_in, // Add sign_in path
handlers::get_users::get_all_users,
handlers::get_users:: get_users_by_id,
handlers::get_apikeys::get_all_apikeys,
handlers::get_apikeys::get_apikeys_by_id,
handlers::get_usage::get_usage_last_day,
handlers::get_usage::get_usage_last_week,
handlers::get_todos::get_all_todos,
handlers::get_todos::get_todos_by_id,
handlers::get_health::get_health,
handlers::post_users::post_user,
handlers::post_apikeys::post_apikey,
handlers::post_todos::post_todo,
handlers::rotate_apikeys::rotate_apikey,
handlers::delete_users::delete_user_by_id,
handlers::delete_apikeys::delete_apikey_by_id,
handlers::delete_todos::delete_todo_by_id,
handlers::protected::protected,
handlers::signin::signin,
),
components(
schemas(
UserResponse,
// ApiKeyResponse,
// ApiKeyByIDResponse,
// Todo,
// SignInData,
// ...add other schemas as needed...
models::apikey::ApiKey,
models::apikey::ApiKeyInsertBody,
models::apikey::ApiKeyInsertResponse,
models::apikey::ApiKeyResponse,
models::apikey::ApiKeyByIDResponse,
models::apikey::ApiKeyGetActiveForUserResponse,
models::apikey::ApiKeyByUserIDResponse,
models::apikey::ApiKeyNewBody,
models::apikey::ApiKeyRotateResponse,
models::apikey::ApiKeyRotateResponseInfo,
models::apikey::ApiKeyRotateBody,
models::auth::Claims,
models::documentation::SuccessResponse,
models::documentation::ErrorResponse,
models::health::HealthResponse,
models::health::CpuUsage,
models::health::DatabaseStatus,
models::health::DiskUsage,
models::health::MemoryStatus,
models::role::Role,
models::todo::Todo,
models::usage::UsageResponseLastDay,
models::usage::UsageResponseLastWeek,
models::user::User,
models::user::UserGetResponse,
models::user::UserInsertBody,
models::user::UserInsertResponse
)
),
tags(
@ -98,94 +126,24 @@ struct ApiDoc;
/// Function to create and configure all routes
pub fn create_routes(database_connection: PgPool) -> Router {
// Authentication routes
let auth_routes = Router::new()
.route("/signin", post(sign_in))
.route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})));
// User-related routes
let user_routes = Router::new()
.route("/all", get(get_all_users).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_user).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_users_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_user_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![2];
authorize(req, next, allowed_roles)})));
// API key-related routes
let apikey_routes = Router::new()
.route("/all", get(get_all_apikeys).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_apikey).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_apikeys_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_apikey_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/rotate/{id}", post(rotate_apikey).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})));
// Usage related routes
let usage_routes = Router::new()
.route("/lastday", get(get_usage_last_day).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})))
.route("/lastweek", get(get_usage_last_week).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)
})));
// Todo-related routes
let todo_routes = Router::new()
.route("/all", get(get_all_todos).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)})))
.route("/new", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)
})))
.route("/{id}", get(get_todos_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1, 2];
authorize(req, next, allowed_roles)})))
.route("/{id}", delete(delete_todo_by_id).layer(axum::middleware::from_fn(|req, next| {
let allowed_roles = vec![1,2];
authorize(req, next, allowed_roles)})));
// Documentation:
// Create OpenAPI specification
let openapi = ApiDoc::openapi();
// Create Swagger UI
let swagger_ui = SwaggerUi::new("/swagger-ui")
let swagger_ui = SwaggerUi::new("/swagger")
.url("/openapi.json", openapi.clone());
// Combine all routes and add middleware
Router::new()
.route("/", get(homepage))
.merge(auth_routes) // Add authentication routes
.merge(create_homepage_route())
.merge(create_auth_routes())
.merge(swagger_ui)
.nest("/users", user_routes) // Add user routes under /users
.nest("/apikeys", apikey_routes) // Add API key routes under /apikeys
.nest("/usage", usage_routes) // Add usage routes under /usage
.nest("/todos", todo_routes) // Add todo routes under /todos
.route("/health", get(get_health)) // Add health check route
.layer(axum::Extension(database_connection.clone())) // Add database connection to all routes
.with_state(database_connection) // Add database connection as state
.layer(TraceLayer::new_for_http()) // Add tracing middleware
}
.nest("/users", create_user_routes())
.nest("/apikeys", create_apikey_routes())
.nest("/usage", create_usage_routes())
.nest("/todos", create_todo_routes())
.merge(create_health_route())
.layer(axum::Extension(database_connection.clone()))
.with_state(database_connection)
.layer(TraceLayer::new_for_http())
}

View File

@ -1,98 +0,0 @@
use axum::{extract::State, Json, response::IntoResponse};
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use uuid::Uuid;
use utoipa::ToSchema;
use validator::Validate;
use crate::handlers::validate::{validate_password, validate_username};
use crate::middlewares::auth::{hash_password, generate_totp_secret};
/// Request body for creating a new user.
#[derive(Deserialize, Validate, ToSchema)]
pub struct UserBody {
    // 3-50 characters, further checked by validate_username.
    #[validate(length(min = 3, max = 50), custom(function = "validate_username"))]
    pub username: String,
    #[validate(email)]
    pub email: String,
    // Strength rules enforced by validate_password.
    #[validate(custom(function = "validate_password"))]
    pub password: String,
    // TOTP opt-in flag; the handler checks for the literal string "true".
    pub totp: Option<String>,
}
/// Response body returned after a successful user creation.
#[derive(Serialize, ToSchema)]
pub struct UserResponse {
    // Newly assigned user id.
    pub id: Uuid,
    pub username: String,
    pub email: String,
    // Present only when TOTP was requested at creation time.
    pub totp_secret: Option<String>,
    pub role_level: i32,
}
/// Creates a new user account.
///
/// Validates the request body, hashes the password, optionally generates a
/// TOTP secret, and inserts the row with the default role level (1). Returns
/// the created user, a 400 with the collected validation messages, or a 500
/// on hashing/database failure.
// Fixed: the documented path was "/users", but the handler is registered at
// "/new" nested under "/users", so the real path is "/users/new".
#[utoipa::path(
    post,
    path = "/users/new",
    tag = "user",
    request_body = UserBody,
    responses(
        (status = 200, description = "User created successfully", body = UserResponse),
        (status = 400, description = "Validation error", body = String),
        (status = 500, description = "Internal server error", body = String)
    )
)]
#[instrument(skip(pool, user))]
pub async fn post_user(
    State(pool): State<PgPool>,
    Json(user): Json<UserBody>,
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
    // Collect every validation message into a single 400 response.
    if let Err(errors) = user.validate() {
        let error_messages: Vec<String> = errors
            .field_errors()
            .iter()
            .flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
            .collect();
        return Err((
            StatusCode::BAD_REQUEST,
            Json(json!({ "error": error_messages.join(", ") }))
        ));
    }
    // Never store the plaintext password — only its hash.
    let hashed_password = hash_password(&user.password)
        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Failed to hash password." }))))?;
    // The API expects the literal string "true" to opt in to TOTP.
    let totp_secret = if user.totp.as_deref() == Some("true") {
        Some(generate_totp_secret())
    } else {
        None
    };
    let row = sqlx::query!(
        "INSERT INTO users (username, email, password_hash, totp_secret, role_level)
        VALUES ($1, $2, $3, $4, $5)
        RETURNING id, username, email, password_hash, totp_secret, role_level",
        user.username,
        user.email,
        hashed_password,
        totp_secret,
        1, // Default role_level for newly created users
    )
    .fetch_one(&pool)
    .await
    .map_err(|_err| (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": "Could not create the user."}))))?;
    Ok(Json(UserResponse {
        id: row.id,
        username: row.username,
        email: row.email,
        totp_secret: row.totp_secret,
        role_level: row.role_level,
    }))
}

View File

@ -1,190 +0,0 @@
use axum::{extract::{Extension, Path, State}, Json};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use chrono::{Duration, NaiveDate, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use sqlx::postgres::PgPool;
use tracing::instrument;
use utoipa::ToSchema;
use uuid::Uuid;
use validator::Validate;
use crate::handlers::validate::validate_future_date;
use crate::middlewares::auth::{generate_api_key, hash_password};
use crate::models::user::User;
/// Request body accepted when rotating an API key.
#[derive(Deserialize, Validate, ToSchema)]
pub struct ApiKeyBody {
    /// Optional description for the replacement key (max 50 characters).
    #[validate(length(min = 0, max = 50))]
    pub description: Option<String>,
    /// Optional expiration date ("YYYY-MM-DD"); must lie in the future
    /// (checked by validate_future_date and re-checked in the handler).
    #[validate(custom(function = "validate_future_date"))]
    pub expiration_date: Option<String>,
}
/// Minimal projection of an API-key row (id + description), used both to
/// verify ownership and as the documented rotation response schema.
#[derive(Serialize, ToSchema)]
pub struct ApiKeyResponse {
    // Key identifier.
    pub id: Uuid,
    // Optional human-readable description.
    pub description: Option<String>,
}
/// Rotates an API key: verifies the caller owns the (active) key, disables it
/// with a one-day grace period, and issues a freshly generated replacement —
/// all inside a single transaction. The plaintext key is returned exactly once.
#[utoipa::path(
    post,
    path = "/apikeys/rotate/{id}",
    tag = "apikey",
    request_body = ApiKeyBody,
    responses(
        (status = 200, description = "API key rotated successfully", body = ApiKeyResponse),
        (status = 400, description = "Validation error", body = String),
        (status = 404, description = "API key not found", body = String),
        (status = 500, description = "Internal server error", body = String)
    ),
    params(
        ("id" = String, Path, description = "API key identifier")
    )
)]
#[instrument(skip(pool, user, apikeybody))]
pub async fn rotate_apikey(
    State(pool): State<PgPool>,
    Extension(user): Extension<User>,
    Path(id): Path<String>,
    Json(apikeybody): Json<ApiKeyBody>
) -> impl IntoResponse {
    // Validate input, collecting all messages into one 400 response.
    if let Err(errors) = apikeybody.validate() {
        let error_messages: Vec<String> = errors
            .field_errors()
            .iter()
            .flat_map(|(_, errors)| errors.iter().map(|e| e.message.clone().unwrap_or_default().to_string()))
            .collect();
        return Err((
            StatusCode::BAD_REQUEST,
            Json(json!({ "error": error_messages.join(", ") }))
        ));
    }
    // Validate UUID format before touching the database.
    let uuid = match Uuid::parse_str(&id) {
        Ok(uuid) => uuid,
        Err(_) => return Err((StatusCode::BAD_REQUEST,
            Json(json!({ "error": "Invalid API key identifier format" })))),
    };
    // Verify ownership of the API key (must belong to the caller and be active).
    let existing_key = sqlx::query_as!(ApiKeyResponse,
        "SELECT id, description FROM apikeys
        WHERE user_id = $1 AND id = $2 AND disabled = FALSE",
        user.id,
        uuid
    )
    .fetch_optional(&pool)
    .await
    .map_err(|e| {
        tracing::error!("Database error: {}", e);
        (StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": "Internal server error" })))
    })?;
    let existing_key = existing_key.ok_or_else(||
        (StatusCode::NOT_FOUND,
        Json(json!({ "error": "API key not found or already disabled" })))
    )?;
    // Validate expiration date format; default is roughly two years out.
    let expiration_date = match &apikeybody.expiration_date {
        Some(date_str) => NaiveDate::parse_from_str(date_str, "%Y-%m-%d")
            .map_err(|_| (StatusCode::BAD_REQUEST,
                Json(json!({ "error": "Invalid expiration date format. Use YYYY-MM-DD" }))))?,
        None => (Utc::now() + Duration::days(365 * 2)).naive_utc().date(),
    };
    // Validate expiration date is in the future.
    if expiration_date <= Utc::now().naive_utc().date() {
        return Err((StatusCode::BAD_REQUEST,
            Json(json!({ "error": "Expiration date must be in the future" }))));
    }
    // Generate new secure API key; only its hash is persisted.
    let api_key = generate_api_key();
    let key_hash = hash_password(&api_key)
        .map_err(|e| {
            tracing::error!("Hashing error: {}", e);
            (StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({ "error": "Internal server error" })))
        })?;
    // Begin transaction so the disable + insert happen atomically.
    let mut tx = pool.begin().await.map_err(|e| {
        tracing::error!("Transaction error: {}", e);
        (StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": "Internal server error" })))
    })?;
    // Disable old key; its expiration is pushed one day out as a grace period.
    let disable_result = sqlx::query!(
        "UPDATE apikeys SET
        disabled = TRUE,
        expiration_date = CURRENT_DATE + INTERVAL '1 day'
        WHERE id = $1 AND user_id = $2",
        uuid,
        user.id
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        tracing::error!("Database error: {}", e);
        (StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": "Internal server error" })))
    })?;
    if disable_result.rows_affected() == 0 {
        return Err((StatusCode::NOT_FOUND,
            Json(json!({ "error": "API key not found or already disabled" }))));
    }
    // Create new key with an automatic description when none was supplied.
    let description = apikeybody.description.unwrap_or_else(||
        format!("Rotated from key {} - {}",
            existing_key.id,
            Utc::now().format("%Y-%m-%d"))
    );
    let new_key = sqlx::query!(
        "INSERT INTO apikeys
        (key_hash, description, expiration_date, user_id, access_read, access_modify)
        VALUES ($1, $2, $3, $4, $5, $6)
        RETURNING id, description, expiration_date",
        key_hash,
        description,
        expiration_date,
        user.id,
        true, // Default read access
        false // Default no modify access
    )
    .fetch_one(&mut *tx)
    .await
    .map_err(|e| {
        tracing::error!("Database error: {}", e);
        (StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": "Internal server error" })))
    })?;
    tx.commit().await.map_err(|e| {
        tracing::error!("Transaction commit error: {}", e);
        (StatusCode::INTERNAL_SERVER_ERROR,
        Json(json!({ "error": "Internal server error" })))
    })?;
    // The plaintext api_key appears only in this response — it is never stored.
    Ok(Json(json!({
        "id": new_key.id,
        "api_key": api_key,
        "description": new_key.description,
        "expiration_date": new_key.expiration_date,
        "warning": "Store this key securely - it won't be shown again",
        "rotation_info": {
            "original_key": existing_key.id,
            "disabled_at": Utc::now().to_rfc3339()
        }
    })))
}

26
src/routes/todo.rs Normal file
View File

@ -0,0 +1,26 @@
use axum::{
Router,
routing::{get, post, delete},
middleware::from_fn,
};
use sqlx::PgPool;
use crate::middlewares::auth::authorize;
use crate::handlers::{get_todos::{get_all_todos, get_todos_by_id}, post_todos::post_todo, delete_todos::delete_todo_by_id};
/// Builds the to-do router: list, create, fetch and delete items.
///
/// All endpoints require the caller to hold role 1 or 2, enforced by the
/// `authorize` middleware.
pub fn create_todo_routes() -> Router<PgPool> {
    Router::new()
        .route(
            "/all",
            get(get_all_todos).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/new",
            post(post_todo).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/{id}",
            get(get_todos_by_id).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
        .route(
            "/{id}",
            delete(delete_todo_by_id).layer(from_fn(|req, next| authorize(req, next, vec![1, 2]))),
        )
}

19
src/routes/usage.rs Normal file
View File

@ -0,0 +1,19 @@
use axum::{
Router,
routing::get,
middleware::from_fn,
};
use sqlx::PgPool;
use crate::middlewares::auth::authorize;
use crate::handlers::get_usage::{get_usage_last_day, get_usage_last_week};
pub fn create_usage_routes() -> Router<PgPool> {
    // Usage statistics are readable by roles 1 and 2; both routes share
    // the same authorization guard, built once here.
    let role_guard = || {
        from_fn(|req, next| {
            let allowed_roles = vec![1, 2];
            authorize(req, next, allowed_roles)
        })
    };
    Router::new()
        // Usage over the last 24 hours.
        .route("/lastday", get(get_usage_last_day).layer(role_guard()))
        // Usage over the last 7 days.
        .route("/lastweek", get(get_usage_last_week).layer(role_guard()))
}

26
src/routes/user.rs Normal file
View File

@ -0,0 +1,26 @@
use axum::{
Router,
routing::{get, post, delete},
middleware::from_fn,
};
use sqlx::PgPool;
use crate::middlewares::auth::authorize;
use crate::handlers::{get_users::{get_all_users, get_users_by_id}, post_users::post_user, delete_users::delete_user_by_id};
pub fn create_user_routes() -> Router<PgPool> {
    // User management is admin-only (role 2); a single closure builds
    // the identical guard layer applied to every route.
    let admin_guard = || {
        from_fn(|req, next| {
            let allowed_roles = vec![2];
            authorize(req, next, allowed_roles)
        })
    };
    Router::new()
        // List every user.
        .route("/all", get(get_all_users).layer(admin_guard()))
        // Create a new user.
        .route("/new", post(post_user).layer(admin_guard()))
        // Fetch a single user by id.
        .route("/{id}", get(get_users_by_id).layer(admin_guard()))
        // Remove a single user by id.
        .route("/{id}", delete(delete_user_by_id).layer(admin_guard()))
}

9
src/utils/README.md Normal file
View File

@ -0,0 +1,9 @@
# Utils
The `/src/utils` folder includes functions for tasks like validating user input, handling JWT authentication, generating secure passwords, and working with time-based one-time passwords (TOTP). These utilities are used throughout the application to streamline common functionality and reduce redundancy.
## Contributing
Add new utility functions, improve existing helpers, or extend the validation and authentication support. Document any changes for clarity.
## License
This project is licensed under the MIT License.

186
src/utils/auth.rs Normal file
View File

@ -0,0 +1,186 @@
use std::{collections::HashSet, env};
use axum::http::StatusCode;
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, Error},
Argon2, Params, Version,
};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation};
use totp_rs::{Secret, TOTP};
use rand::{rngs::OsRng, Rng};
use tracing::{warn, error, instrument};
use tokio::task;
use moka::future::Cache;
use lazy_static::lazy_static;
use crate::models::auth::Claims;
// The imports above cover HTTP status codes, Argon2 password hashing,
// JWT encoding/decoding, TOTP generation, randomness, tracing,
// blocking-task offload, and the in-memory verification cache.
// Cache for storing successful password verifications
// (moka async cache; each entry expires 5 minutes after insertion, so a
// repeated login within that window skips the expensive Argon2 check).
lazy_static! {
    static ref PASSWORD_CACHE: Cache<String, bool> = Cache::builder()
        .time_to_live(std::time::Duration::from_secs(300)) // 5 minutes
        .build();
}
#[instrument(skip(password, hash))]
pub async fn verify_hash(password: &str, hash: &str) -> Result<bool, Error> {
// Check cache first
if let Some(result) = PASSWORD_CACHE.get(password).await {
return Ok(result);
}
let password_owned = password.to_string();
let hash_owned = hash.to_string();
let password_clone = password_owned.clone();
let result = task::spawn_blocking(move || {
let parsed_hash = PasswordHash::new(&hash_owned)?;
Argon2::default()
.verify_password(password_owned.as_bytes(), &parsed_hash)
.map(|_| true) // Remove the map_err conversion
})
.await
.map_err(|_| argon2::Error::AlgorithmInvalid)??; // Keep double question mark
if result {
PASSWORD_CACHE.insert(password_clone, true).await;
}
Ok(result)
}
// Hashes a password with Argon2id, using a freshly generated random
// salt (a new salt per call means identical passwords never produce
// identical hashes). Returns the PHC-formatted hash string.
#[instrument(skip(password))]
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
    let salt = SaltString::generate(&mut OsRng);
    // Explicit cost parameters: 15 MiB memory, 2 iterations,
    // parallelism 1, default output length.
    let params = Params::new(15360, 2, 1, None)?;
    // Argon2id variant at the latest (0x13) version.
    let hasher = Argon2::new(argon2::Algorithm::Argon2id, Version::V0x13, params);
    let hashed = hasher.hash_password(password.as_bytes(), &salt)?;
    Ok(hashed.to_string())
}
/// Generates a new random TOTP secret and returns it base32-encoded,
/// ready to be stored for the user and shared via an otpauth URI/QR.
///
/// The previous implementation built a `TOTP` and returned
/// `generate_current()` — i.e. a short-lived one-time *code* — while
/// discarding the secret, making the stored value useless for any
/// future verification.
#[instrument]
pub fn generate_totp_secret() -> String {
    Secret::generate_secret().to_encoded().to_string()
}
// Generates a random API key of five dash-separated groups, each group
// being 8 random bytes rendered as 16 lowercase hex characters
// (e.g. "a1b2…-…"). 40 random bytes / 320 bits of entropy in total.
#[instrument]
pub fn generate_api_key() -> String {
    let mut rng = rand::thread_rng();
    let mut groups: Vec<String> = Vec::with_capacity(5);
    for _ in 0..5 {
        let mut group = String::with_capacity(16);
        for _ in 0..8 {
            group.push_str(&format!("{:02x}", rng.gen::<u8>()));
        }
        groups.push(group);
    }
    groups.join("-")
}
// Issues an HS256-signed JWT for the given email, valid for 24 hours.
// Returns INTERNAL_SERVER_ERROR if the signing key is missing from the
// environment or encoding fails.
#[instrument(skip(email))]
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
    // The signing secret must come from the environment, never the binary.
    let secret_key = env::var("JWT_SECRET_KEY").map_err(|_| {
        error!("JWT_SECRET_KEY not set in environment variables");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    let issued_at = Utc::now();
    let expires_at = issued_at + Duration::hours(24);
    let claim = Claims {
        sub: email,
        iat: issued_at.timestamp() as usize,
        exp: expires_at.timestamp() as usize,
        iss: "your_issuer".to_string(),   // placeholder issuer
        aud: "your_audience".to_string(), // placeholder audience
    };
    // Sign with HMAC-SHA256; must match the validation in `decode_jwt`.
    let header = Header::new(jsonwebtoken::Algorithm::HS256);
    let key = EncodingKey::from_secret(secret_key.as_ref());
    encode(&header, &claim, &key).map_err(|e| {
        error!("Failed to encode JWT: {:?}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
// Validates and decodes a JWT, checking the signature, expiry,
// audience, and issuer. Returns UNAUTHORIZED on any validation failure
// and INTERNAL_SERVER_ERROR when the secret is unavailable.
#[instrument(skip(jwt))]
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
    // Same environment-sourced secret used by `encode_jwt`.
    let secret_key = env::var("JWT_SECRET_KEY").map_err(|_| {
        error!("JWT_SECRET_KEY not set in environment variables");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    // Default validation rules, plus explicit audience/issuer checks
    // matching the claims written by `encode_jwt`.
    let mut validation = Validation::default();
    validation.aud = Some(HashSet::from(["your_audience".to_string()]));
    validation.iss = Some(HashSet::from(["your_issuer".to_string()]));
    let key = DecodingKey::from_secret(secret_key.as_ref());
    decode::<Claims>(&jwt, &key, &validation).map_err(|e| {
        warn!("Failed to decode JWT: {:?}", e);
        StatusCode::UNAUTHORIZED
    })
}
// Convenience wrapper: verifies a user password against its stored
// Argon2 hash by delegating to the cached `verify_hash`.
#[instrument(skip(password, hash))]
pub async fn verify_password(password: String, hash: String) -> Result<bool, Error> {
    verify_hash(&password, &hash).await
}
// Convenience wrapper: verifies a presented API key against its stored
// Argon2 hash; identical mechanics to `verify_password`.
#[instrument(skip(password, hash))]
pub async fn verify_api_key(password: String, hash: String) -> Result<bool, Error> {
    verify_hash(&password, &hash).await
}

2
src/utils/mod.rs Normal file
View File

@ -0,0 +1,2 @@
// Input-validation helpers.
pub mod validate;
// Password hashing, JWT, TOTP, and API-key utilities.
pub mod auth;