mirror of
https://github.com/kristoferssolo/Axium.git
synced 2025-10-21 16:00:34 +00:00
first commit
This commit is contained in:
commit
e20f21bc8b
35
.dockerignore
Normal file
35
.dockerignore
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
# Include any files or directories that you don't want to be copied to your
|
||||
# container here (e.g., local build artifacts, temporary files, etc.).
|
||||
#
|
||||
# For more help, visit the .dockerignore file reference guide at
|
||||
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
|
||||
|
||||
**/.DS_Store
|
||||
**/.classpath
|
||||
**/.dockerignore
|
||||
**/.env
|
||||
**/.git
|
||||
**/.gitignore
|
||||
**/.project
|
||||
**/.settings
|
||||
**/.toolstarget
|
||||
**/.vs
|
||||
**/.vscode
|
||||
**/*.*proj.user
|
||||
**/*.dbmdl
|
||||
**/*.jfm
|
||||
**/charts
|
||||
**/docker-compose*
|
||||
**/compose*
|
||||
**/Dockerfile*
|
||||
**/node_modules
|
||||
**/npm-debug.log
|
||||
**/secrets.dev.yaml
|
||||
**/values.dev.yaml
|
||||
/bin
|
||||
/target
|
||||
LICENSE
|
||||
README.md
|
||||
|
||||
|
||||
68
.env.example
Normal file
68
.env.example
Normal file
@ -0,0 +1,68 @@
|
||||
# ==============================
|
||||
# 📌 DATABASE CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# PostgreSQL connection URL (format: postgres://user:password@host/database)
|
||||
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
|
||||
|
||||
# Maximum number of connections in the database pool
|
||||
DATABASE_MAX_CONNECTIONS=20
|
||||
|
||||
# Minimum number of connections in the database pool
|
||||
DATABASE_MIN_CONNECTIONS=5
|
||||
|
||||
# ==============================
|
||||
# 🌍 SERVER CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# IP address the server will bind to (0.0.0.0 allows all network interfaces)
|
||||
SERVER_IP="0.0.0.0"
|
||||
|
||||
# Port the server will listen on
|
||||
SERVER_PORT="3000"
|
||||
|
||||
# Enable tracing for debugging/logging (true/false)
|
||||
SERVER_TRACE_ENABLED=true
|
||||
|
||||
# ==============================
|
||||
# 🔒 HTTPS CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Enable HTTPS (true/false)
|
||||
SERVER_HTTPS_ENABLED=false
|
||||
|
||||
# Enable HTTP/2 when using HTTPS (true/false)
|
||||
SERVER_HTTPS_HTTP2_ENABLED=true
|
||||
|
||||
# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true)
|
||||
SERVER_HTTPS_CERT_FILE_PATH=cert.pem
|
||||
|
||||
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
|
||||
SERVER_HTTPS_KEY_FILE_PATH=key.pem
|
||||
|
||||
# ==============================
|
||||
# 🚦 RATE LIMIT CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Maximum number of requests allowed per period
|
||||
SERVER_RATE_LIMIT=5
|
||||
|
||||
# Time period (in seconds) for rate limiting
|
||||
SERVER_RATE_LIMIT_PERIOD=1
|
||||
|
||||
# ==============================
|
||||
# 📦 COMPRESSION CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Enable Brotli compression (true/false)
|
||||
SERVER_COMPRESSION_ENABLED=true
|
||||
|
||||
# Compression level (valid range: 0-11, where 11 is the highest compression)
|
||||
SERVER_COMPRESSION_LEVEL=6
|
||||
|
||||
# ==============================
|
||||
# 🔑 AUTHENTICATION CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Argon2 salt for password hashing (must be kept secret!)
|
||||
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
|
||||
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
.env
|
||||
/target
|
||||
cert.pem
|
||||
key.pem
|
||||
Cargo.lock
|
||||
45
Cargo.toml
Normal file
45
Cargo.toml
Normal file
@ -0,0 +1,45 @@
|
||||
[package]
|
||||
name = "rustapi"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Web framework and server
|
||||
axum = { version = "0.8.1", features = ["json"] }
|
||||
# hyper = { version = "1.5.2", features = ["full"] }
|
||||
|
||||
# Database interaction
|
||||
sqlx = { version = "0.8.3", features = ["runtime-tokio-rustls", "postgres", "migrate"] }
|
||||
|
||||
# Serialization and deserialization
|
||||
serde = { version = "1.0.217", features = ["derive"] }
|
||||
serde_json = "1.0.137"
|
||||
|
||||
# Authentication and security
|
||||
jsonwebtoken = "9.3.0"
|
||||
argon2 = "0.5.3"
|
||||
|
||||
# Asynchronous runtime and traits
|
||||
tokio = { version = "1.43.0", features = ["rt-multi-thread", "process"] }
|
||||
|
||||
# Configuration and environment
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
# Middleware and server utilities
|
||||
tower = { version = "0.5.2", features = ["limit"] }
|
||||
tower-http = { version = "0.6.2", features = ["trace", "cors", "compression-br"] }
|
||||
|
||||
# Logging and monitoring
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = "0.3.19"
|
||||
|
||||
# System information
|
||||
sysinfo = "0.33.1"
|
||||
|
||||
# Date and time handling
|
||||
chrono = "0.4.39"
|
||||
|
||||
# SSL / TLS
|
||||
rustls = "0.23.21"
|
||||
tokio-rustls = "0.26.1"
|
||||
rustls-pemfile = "2.2.0"
|
||||
73
Dockerfile
Normal file
73
Dockerfile
Normal file
@ -0,0 +1,73 @@
|
||||
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Comments are provided throughout this file to help you get started.
|
||||
# If you need more help, visit the Dockerfile reference guide at
|
||||
# https://docs.docker.com/engine/reference/builder/
|
||||
|
||||
################################################################################
|
||||
# Create a stage for building the application.
|
||||
|
||||
ARG RUST_VERSION=1.78.0
|
||||
ARG APP_NAME=backend
|
||||
FROM rust:${RUST_VERSION}-slim-bullseye AS build
|
||||
ARG APP_NAME
|
||||
WORKDIR /app
|
||||
|
||||
# Build the application.
|
||||
# Leverage a cache mount to /usr/local/cargo/registry/
|
||||
# for downloaded dependencies and a cache mount to /app/target/ for
|
||||
# compiled dependencies which will speed up subsequent builds.
|
||||
# Leverage a bind mount to the src directory to avoid having to copy the
|
||||
# source code into the container. Once built, copy the executable to an
|
||||
# output directory before the cache mounted /app/target is unmounted.
|
||||
RUN --mount=type=bind,source=src,target=src \
|
||||
# --mount=type=bind,source=configuration.yaml,target=configuration.yaml \
|
||||
--mount=type=bind,source=Cargo.toml,target=Cargo.toml \
|
||||
--mount=type=bind,source=Cargo.lock,target=Cargo.lock \
|
||||
--mount=type=cache,target=/app/target/ \
|
||||
--mount=type=cache,target=/usr/local/cargo/registry/ \
|
||||
<<EOF
|
||||
set -e
|
||||
cargo build --locked --release
|
||||
cp ./target/release/$APP_NAME /bin/server
|
||||
EOF
|
||||
|
||||
# COPY /src/web /bin/web
|
||||
################################################################################
|
||||
# Create a new stage for running the application that contains the minimal
|
||||
# runtime dependencies for the application. This often uses a different base
|
||||
# image from the build stage where the necessary files are copied from the build
|
||||
# stage.
|
||||
#
|
||||
# The example below uses the debian bullseye image as the foundation for running the app.
|
||||
# By specifying the "bullseye-slim" tag, it will also use whatever happens to be the
|
||||
# most recent version of that tag when you build your Dockerfile. If
|
||||
# reproducability is important, consider using a digest
|
||||
# (e.g., debian@sha256:ac707220fbd7b67fc19b112cee8170b41a9e97f703f588b2cdbbcdcecdd8af57).
|
||||
FROM debian:bullseye-slim AS final
|
||||
|
||||
# Create a non-privileged user that the app will run under.
|
||||
# See https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#user
|
||||
ARG UID=10001
|
||||
RUN adduser \
|
||||
--disabled-password \
|
||||
--gecos "" \
|
||||
--home "/nonexistent" \
|
||||
--shell "/sbin/nologin" \
|
||||
--no-create-home \
|
||||
--uid "${UID}" \
|
||||
appuser
|
||||
USER appuser
|
||||
|
||||
# Copy the executable from the "build" stage.
|
||||
COPY --from=build /bin/server /bin/
|
||||
|
||||
# Expose the port that the application listens on.
|
||||
EXPOSE 80
|
||||
|
||||
# What the container should run when it is started.
|
||||
CMD ["/bin/server"]
|
||||
|
||||
|
||||
|
||||
168
README.md
Normal file
168
README.md
Normal file
@ -0,0 +1,168 @@
|
||||
```markdown
|
||||
# 🦀 RustAPI
|
||||
**An example API built with Rust, Axum, SQLx, and PostgreSQL**
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
## 🚀 Core Features
|
||||
- **Rust API Template** - Production-ready starter template with modern practices
|
||||
- **PostgreSQL Integration** - Full database support with SQLx migrations
|
||||
- **Comprehensive Health Monitoring**
|
||||
Docker-compatible endpoint with system metrics:
|
||||
```json
|
||||
{
|
||||
"details": {
|
||||
"cpu_usage": {"available_percentage": "9.85", "status": "low"},
|
||||
"database": {"status": "ok"},
|
||||
"disk_usage": {"status": "ok", "used_percentage": "74.00"},
|
||||
"memory": {"available_mb": 21613, "status": "normal"}
|
||||
},
|
||||
"status": "degraded"
|
||||
}
|
||||
```
|
||||
- **JWT Authentication** - Secure token-based auth with Argon2 password hashing
|
||||
- **Granular Access Control** - Role-based endpoint protection:
|
||||
```rust
|
||||
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
|
||||
let allowed_roles = vec![1, 2];
|
||||
authorize(req, next, allowed_roles)
|
||||
})))
|
||||
```
|
||||
- **User Context Injection** - Automatic user profile handling in endpoints:
|
||||
```rust
|
||||
pub async fn post_todo(
|
||||
Extension(user): Extension<User>, // Injected user
|
||||
Json(todo): Json<TodoBody>
|
||||
) -> impl IntoResponse {
|
||||
if todo.user_id != user.id {
|
||||
return Err((StatusCode::FORBIDDEN, Json(json!({
|
||||
"error": "Cannot create todos for others"
|
||||
}))));
|
||||
}
|
||||
```
|
||||
- **Modern protocols ** - HTTP/2 with secure TLS defaults
|
||||
- **Observability** - Integrated tracing
|
||||
- **Optimized for performance** - Brotli compression
|
||||
- **Easy configuration** - `.env` and environment variables
|
||||
- **Documented codebase** - Extensive inline comments for easy modification and readability
|
||||
- **Latest dependencies** - Regularly updated Rust ecosystem crates
|
||||
|
||||
## 📦 Installation & Usage
|
||||
```bash
|
||||
# Clone and setup
|
||||
git clone https://github.com/Riktastic/rustapi.git
|
||||
cd rustapi && cp .env.example .env
|
||||
|
||||
# Database setup
|
||||
sqlx database create && sqlx migrate run
|
||||
|
||||
# Start server
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
### 🔐 Default Accounts
|
||||
|
||||
**Warning:** These accounts should only be used for initial testing. Always change or disable them in production environments.
|
||||
|
||||
| Email | Password | Role |
|
||||
|---------------------|----------|----------------|
|
||||
| `user@test.com` | `test` | User |
|
||||
| `admin@test.com` | `test` | Administrator |
|
||||
|
||||
⚠️ **Security Recommendations:**
|
||||
1. Rotate passwords immediately after initial setup
|
||||
2. Disable default accounts before deploying to production
|
||||
3. Implement proper user management endpoints
|
||||
|
||||
|
||||
## ⚙️ Configuration
|
||||
```env
|
||||
# ==============================
|
||||
# 📌 DATABASE CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# PostgreSQL connection URL (format: postgres://user:password@host/database)
|
||||
DATABASE_URL="postgres://postgres:1234@localhost/database_name"
|
||||
|
||||
# Maximum number of connections in the database pool
|
||||
DATABASE_MAX_CONNECTIONS=20
|
||||
|
||||
# Minimum number of connections in the database pool
|
||||
DATABASE_MIN_CONNECTIONS=5
|
||||
|
||||
# ==============================
|
||||
# 🌍 SERVER CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# IP address the server will bind to (0.0.0.0 allows all network interfaces)
|
||||
SERVER_IP="0.0.0.0"
|
||||
|
||||
# Port the server will listen on
|
||||
SERVER_PORT="3000"
|
||||
|
||||
# Enable tracing for debugging/logging (true/false)
|
||||
SERVER_TRACE_ENABLED=true
|
||||
|
||||
# ==============================
|
||||
# 🔒 HTTPS CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Enable HTTPS (true/false)
|
||||
SERVER_HTTPS_ENABLED=false
|
||||
|
||||
# Enable HTTP/2 when using HTTPS (true/false)
|
||||
SERVER_HTTPS_HTTP2_ENABLED=true
|
||||
|
||||
# Path to the SSL certificate file (only used if SERVER_HTTPS_ENABLED=true)
|
||||
SERVER_HTTPS_CERT_FILE_PATH=cert.pem
|
||||
|
||||
# Path to the SSL private key file (only used if SERVER_HTTPS_ENABLED=true)
|
||||
SERVER_HTTPS_KEY_FILE_PATH=key.pem
|
||||
|
||||
# ==============================
|
||||
# 🚦 RATE LIMIT CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Maximum number of requests allowed per period
|
||||
SERVER_RATE_LIMIT=5
|
||||
|
||||
# Time period (in seconds) for rate limiting
|
||||
SERVER_RATE_LIMIT_PERIOD=1
|
||||
|
||||
# ==============================
|
||||
# 📦 COMPRESSION CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Enable Brotli compression (true/false)
|
||||
SERVER_COMPRESSION_ENABLED=true
|
||||
|
||||
# Compression level (valid range: 0-11, where 11 is the highest compression)
|
||||
SERVER_COMPRESSION_LEVEL=6
|
||||
|
||||
# ==============================
|
||||
# 🔑 AUTHENTICATION CONFIGURATION
|
||||
# ==============================
|
||||
|
||||
# Argon2 salt for password hashing (must be kept secret!)
|
||||
AUTHENTICATION_ARGON2_SALT="dMjQgtSmoQIH3Imi"
|
||||
```
|
||||
|
||||
## 📂 Project Structure
|
||||
```
|
||||
rustapi/
|
||||
├── migrations/ # SQL schema versions
|
||||
├── src/
|
||||
│ ├── core/ # Config, TLS, server setup
|
||||
│ ├── database/ # Query handling
|
||||
│ ├── middlewares/ # Auth system
|
||||
│ ├── models/ # Data structures
|
||||
│ └── routes/ # API endpoints
|
||||
└── Dockerfile # Containerization
|
||||
```
|
||||
|
||||
## 🛠️ Technology Stack
|
||||
| Category | Key Technologies |
|
||||
|-----------------------|---------------------------------|
|
||||
| Web Framework | Axum 0.8 + Tower |
|
||||
| Database | PostgreSQL + SQLx 0.8 |
|
||||
| Security | JWT + Argon2 + Rustls |
|
||||
| Monitoring | Tracing + Sysinfo |
|
||||
32
compose.yaml
Normal file
32
compose.yaml
Normal file
@ -0,0 +1,32 @@
|
||||
|
||||
services:
|
||||
server:
|
||||
build:
|
||||
context: .
|
||||
target: final
|
||||
ports:
|
||||
- 80:80
|
||||
depends_on:
|
||||
- db_image
|
||||
networks:
|
||||
- common-net
|
||||
|
||||
db_image:
|
||||
image: postgres:latest
|
||||
environment:
|
||||
POSTGRES_PORT: 3306
|
||||
POSTGRES_DATABASE: database_name
|
||||
POSTGRES_USER: user
|
||||
POSTGRES_PASSWORD: database_password
|
||||
POSTGRES_ROOT_PASSWORD: strong_database_password
|
||||
expose:
|
||||
- 3306
|
||||
ports:
|
||||
- "3307:3306"
|
||||
networks:
|
||||
- common-net
|
||||
|
||||
networks:
|
||||
common-net: {}
|
||||
|
||||
|
||||
1
generate_ssl_key.bat
Normal file
1
generate_ssl_key.bat
Normal file
@ -0,0 +1 @@
|
||||
openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj "/CN=localhost" -addext "subjectAltName = DNS:localhost, IP:127.0.0.1"
|
||||
19
migrations/20250128160043_create_roles_table.sql
Normal file
19
migrations/20250128160043_create_roles_table.sql
Normal file
@ -0,0 +1,19 @@
|
||||
-- Create the roles table
|
||||
CREATE TABLE IF NOT EXISTS roles (
|
||||
id SERIAL PRIMARY KEY,
|
||||
level INT NOT NULL,
|
||||
role VARCHAR(255) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description VARCHAR(255),
|
||||
CONSTRAINT unique_role UNIQUE (role) -- Add a unique constraint to the 'role' column
|
||||
);
|
||||
|
||||
-- Insert a role into the roles table (this assumes role_id=1 is for 'user')
|
||||
INSERT INTO roles (level, role, name, description)
|
||||
VALUES (1, 'user', 'User', 'A regular user with basic access.')
|
||||
ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists
|
||||
|
||||
-- Insert a role into the roles table (this assumes role_id=2 is for 'admin')
|
||||
INSERT INTO roles (level, role, name, description)
|
||||
VALUES (2, 'admin', 'Administrator', 'An administrator.')
|
||||
ON CONFLICT (role) DO NOTHING; -- Prevent duplicate insertions if role already exists
|
||||
21
migrations/20250128160119_create_users_table.sql
Normal file
21
migrations/20250128160119_create_users_table.sql
Normal file
@ -0,0 +1,21 @@
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
totp_secret VARCHAR(255),
|
||||
role_id INT NOT NULL DEFAULT 1 REFERENCES roles(id), -- Default role_id is set to 1
|
||||
CONSTRAINT unique_username UNIQUE (username) -- Ensure that username is unique
|
||||
);
|
||||
|
||||
-- Insert the example 'user' into the users table with a conflict check for username
|
||||
INSERT INTO users (username, email, password_hash, role_id)
|
||||
VALUES
|
||||
('user', 'user@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 1)
|
||||
ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists
|
||||
|
||||
-- Insert the example 'admin' into the users table with a conflict check for username
|
||||
INSERT INTO users (username, email, password_hash, role_id)
|
||||
VALUES
|
||||
('admin', 'admin@test.com', '$argon2i$v=19$m=16,t=2,p=1$ZE1qUWd0U21vUUlIM0ltaQ$dowBmjU4oHtoPd355dXypQ', 2)
|
||||
ON CONFLICT (username) DO NOTHING; -- Prevent duplicate insertions if username already exists
|
||||
6
migrations/20250128160204_create_todos_table.sql
Normal file
6
migrations/20250128160204_create_todos_table.sql
Normal file
@ -0,0 +1,6 @@
|
||||
CREATE TABLE todos (
|
||||
id SERIAL PRIMARY KEY, -- Auto-incrementing primary key
|
||||
task TEXT NOT NULL, -- Task description, cannot be null
|
||||
description TEXT, -- Optional detailed description
|
||||
user_id INT NOT NULL REFERENCES users(id) -- Foreign key to link to users table
|
||||
);
|
||||
55
src/core/config.rs
Normal file
55
src/core/config.rs
Normal file
@ -0,0 +1,55 @@
|
||||
// Import the standard library's environment module
|
||||
use std::env;
|
||||
|
||||
/// Retrieves the value of an environment variable as a `String`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The value of the environment variable if it exists.
|
||||
/// * Panics if the environment variable is missing.
|
||||
pub fn get_env(key: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| panic!("Missing required environment variable: {}", key))
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `String`, with a default value if not found.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The value of the environment variable if it exists.
|
||||
/// * The `default` value if the environment variable is missing.
|
||||
pub fn get_env_with_default(key: &str, default: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| default.to_string())
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `bool`, with a default value if not found.
|
||||
///
|
||||
/// The environment variable is considered `true` if its value is "true" (case-insensitive), otherwise `false`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if the environment variable is "true" (case-insensitive).
|
||||
/// * `false` otherwise, or if the variable is missing, the `default` value is returned.
|
||||
pub fn get_env_bool(key: &str, default: bool) -> bool {
|
||||
env::var(key).map(|v| v.to_lowercase() == "true").unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `u16`, with a default value if not found.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found or cannot be parsed.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The parsed `u16` value of the environment variable if it exists and is valid.
|
||||
/// * The `default` value if the variable is missing or invalid.
|
||||
pub fn get_env_u16(key: &str, default: u16) -> u16 {
|
||||
env::var(key).unwrap_or_else(|_| default.to_string()).parse().unwrap_or(default)
|
||||
}
|
||||
4
src/core/mod.rs
Normal file
4
src/core/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod config;
|
||||
pub mod server;
|
||||
pub mod tls;
|
||||
|
||||
38
src/core/server.rs
Normal file
38
src/core/server.rs
Normal file
@ -0,0 +1,38 @@
|
||||
// Axum for web server and routing
|
||||
use axum::Router;
|
||||
|
||||
// Middleware layers from tower_http
|
||||
use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression
|
||||
use tower_http::trace::TraceLayer; // For HTTP request/response tracing
|
||||
|
||||
// Local crate imports for database connection and configuration
|
||||
use crate::database::connect::connect_to_database; // Function to connect to the database
|
||||
use crate::config; // Environment configuration helper
|
||||
|
||||
/// Function to create and configure the Axum server.
|
||||
pub async fn create_server() -> Router {
|
||||
// Establish a connection to the database
|
||||
let db = connect_to_database().await.expect("Failed to connect to database.");
|
||||
|
||||
// Initialize the routes for the server
|
||||
let mut app = crate::routes::create_routes(db);
|
||||
|
||||
// Enable tracing middleware if configured
|
||||
if config::get_env_bool("SERVER_TRACE_ENABLED", true) {
|
||||
app = app.layer(TraceLayer::new_for_http());
|
||||
println!("✔️ Trace hads been enabled.");
|
||||
}
|
||||
|
||||
// Enable compression middleware if configured
|
||||
if config::get_env_bool("SERVER_COMPRESSION_ENABLED", true) {
|
||||
// Parse compression level from environment or default to level 6
|
||||
let level = config::get_env("SERVER_COMPRESSION_LEVEL").parse().unwrap_or(6);
|
||||
// Apply compression layer with Brotli (br) enabled and the specified compression level
|
||||
app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level)));
|
||||
println!("✔️ Brotli compression enabled with compression quality level {}.", level);
|
||||
|
||||
}
|
||||
|
||||
// Return the fully configured application
|
||||
app
|
||||
}
|
||||
145
src/core/tls.rs
Normal file
145
src/core/tls.rs
Normal file
@ -0,0 +1,145 @@
|
||||
// Standard library imports
|
||||
use std::{
|
||||
future::Future,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
io::BufReader,
|
||||
fs::File,
|
||||
iter,
|
||||
};
|
||||
|
||||
// External crate imports
|
||||
use axum::serve::Listener;
|
||||
use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}};
|
||||
use rustls_pemfile::{Item, read_one, certs};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing;
|
||||
|
||||
// Local crate imports
|
||||
use crate::config; // Import env config helper
|
||||
|
||||
// Function to load TLS configuration from files
|
||||
pub fn load_tls_config() -> ServerConfig {
|
||||
// Get certificate and key file paths from the environment
|
||||
let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH");
|
||||
let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH");
|
||||
|
||||
// Open the certificate and key files
|
||||
let cert_file = File::open(cert_path).expect("❌ Failed to open certificate file.");
|
||||
let key_file = File::open(key_path).expect("❌ Failed to open private key file.");
|
||||
|
||||
// Read the certificate chain and private key from the files
|
||||
let mut cert_reader = BufReader::new(cert_file);
|
||||
let mut key_reader = BufReader::new(key_file);
|
||||
|
||||
// Read and parse the certificate chain
|
||||
let cert_chain: Vec<CertificateDer> = certs(&mut cert_reader)
|
||||
.map(|cert| cert.expect("❌ Failed to read certificate."))
|
||||
.map(CertificateDer::from)
|
||||
.collect();
|
||||
|
||||
// Ensure certificates are found
|
||||
if cert_chain.is_empty() {
|
||||
panic!("❌ No valid certificates found.");
|
||||
}
|
||||
|
||||
// Read the private key from the file
|
||||
let key = iter::from_fn(|| read_one(&mut key_reader).transpose())
|
||||
.find_map(|item| match item.unwrap() {
|
||||
Item::Pkcs1Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
Item::Pkcs8Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
_ => None,
|
||||
})
|
||||
.expect("Failed to read a valid private key.");
|
||||
|
||||
// Build and return the TLS server configuration
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth() // No client authentication
|
||||
.with_single_cert(cert_chain, key) // Use the provided cert and key
|
||||
.expect("Failed to create TLS configuration")
|
||||
}
|
||||
|
||||
// Custom listener that implements axum::serve::Listener
|
||||
#[derive(Clone)]
|
||||
pub struct TlsListener {
|
||||
pub inner: Arc<tokio::net::TcpListener>, // Inner TCP listener
|
||||
pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor for handling TLS handshakes
|
||||
}
|
||||
|
||||
impl Listener for TlsListener {
|
||||
type Io = TlsStreamWrapper; // Type of I/O stream
|
||||
type Addr = SocketAddr; // Type of address (Socket address)
|
||||
|
||||
// Method to accept incoming connections and establish a TLS handshake
|
||||
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send {
|
||||
let acceptor = self.acceptor.clone(); // Clone the acceptor for async use
|
||||
|
||||
async move {
|
||||
loop {
|
||||
// Accept a TCP connection
|
||||
let (stream, addr) = match self.inner.accept().await {
|
||||
Ok((stream, addr)) => (stream, addr),
|
||||
Err(e) => {
|
||||
tracing::error!("❌ Error accepting TCP connection: {}", e);
|
||||
continue; // Retry on error
|
||||
}
|
||||
};
|
||||
|
||||
// Perform TLS handshake
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
tracing::info!("Successful TLS handshake with {}", addr);
|
||||
return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e);
|
||||
continue; // Retry on error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Method to retrieve the local address of the listener
|
||||
fn local_addr(&self) -> std::io::Result<Self::Addr> {
|
||||
self.inner.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper for a TLS stream, implementing AsyncRead and AsyncWrite
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream<tokio::net::TcpStream>);
|
||||
|
||||
impl AsyncRead for TlsStreamWrapper {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsStreamWrapper {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream
|
||||
}
|
||||
}
|
||||
|
||||
// Allow the TLS stream wrapper to be used in non-blocking contexts (needed for async operations)
|
||||
impl Unpin for TlsStreamWrapper {}
|
||||
51
src/database/connect.rs
Normal file
51
src/database/connect.rs
Normal file
@ -0,0 +1,51 @@
|
||||
use dotenvy::dotenv;
|
||||
use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions};
|
||||
use std::fs;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
/// Connects to the database using the DATABASE_URL environment variable.
|
||||
pub async fn connect_to_database() -> Result<PgPool, sqlx::Error> {
|
||||
dotenv().ok();
|
||||
let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount.");
|
||||
|
||||
// Read max and min connection values from environment variables, with defaults
|
||||
let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
|
||||
.unwrap_or_else(|_| "10".to_string()) // Default to 10
|
||||
.parse()
|
||||
.expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number.");
|
||||
|
||||
let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS")
|
||||
.unwrap_or_else(|_| "2".to_string()) // Default to 2
|
||||
.parse()
|
||||
.expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number.");
|
||||
|
||||
// Create and configure the connection pool
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.min_connections(min_connections)
|
||||
.connect(&database_url)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Run database migrations
|
||||
pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
|
||||
// Define the path to the migrations folder
|
||||
let migrations_path = Path::new("./migrations");
|
||||
|
||||
// Check if the migrations folder exists, and if not, create it
|
||||
if !migrations_path.exists() {
|
||||
fs::create_dir_all(migrations_path).expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions.");
|
||||
println!("✔️ Created migrations directory: {:?}", migrations_path);
|
||||
}
|
||||
|
||||
// Create a migrator instance that looks for migrations in the `./migrations` folder
|
||||
let migrator = Migrator::new(migrations_path).await?;
|
||||
|
||||
// Run all pending migrations
|
||||
migrator.run(pool).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
25
src/database/get_users.rs
Normal file
25
src/database/get_users.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*; // Import the User struct
|
||||
|
||||
// Get all users
|
||||
pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result<User, String> {
|
||||
// Use a string literal directly in the macro
|
||||
let user = sqlx::query_as!(
|
||||
User, // Struct type to map the query result
|
||||
r#"
|
||||
SELECT id, username, email, password_hash, totp_secret, role_id
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
"#,
|
||||
email // Bind the `email` parameter
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| format!("Database error: {}", e))?; // Handle database errors
|
||||
|
||||
// Handle optional result
|
||||
match user {
|
||||
Some(user) => Ok(user),
|
||||
None => Err(format!("User with email '{}' not found.", email)),
|
||||
}
|
||||
}
|
||||
3
src/database/mod.rs
Normal file
3
src/database/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
// Module declarations
|
||||
pub mod connect;
|
||||
pub mod get_users;
|
||||
6
src/handlers/mod.rs
Normal file
6
src/handlers/mod.rs
Normal file
@ -0,0 +1,6 @@
|
||||
// Module declarations
|
||||
// pub mod auth;
|
||||
|
||||
|
||||
// Re-exporting modules
|
||||
// pub use auth::*;
|
||||
91
src/main.rs
Normal file
91
src/main.rs
Normal file
@ -0,0 +1,91 @@
|
||||
#[allow(dead_code)]
|
||||
|
||||
// Core modules for the configuration, TLS setup, and server creation
|
||||
mod core;
|
||||
use core::{config, tls, server};
|
||||
use core::tls::TlsListener;
|
||||
|
||||
// Other modules for database, routes, models, and middlewares
|
||||
mod database;
|
||||
mod routes;
|
||||
mod models;
|
||||
mod middlewares;
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use axum::serve;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok(); // Load environment variables from a .env file
|
||||
tracing_subscriber::fmt::init(); // Initialize the logging system
|
||||
|
||||
// Print a cool startup message with ASCII art and emojis
|
||||
println!("{}", r#"
|
||||
##### ## ##
|
||||
## ## ##
|
||||
## ## ## ## ##### ####### ##### #### ####
|
||||
##### ## ## ## ## ## ## ## ## ##
|
||||
#### ## ## ## ## ## ## ## ## ##
|
||||
## ## ## ### ## ## ## ### ## ## ##
|
||||
## ## ### ## ##### ### ### ## ##### ######
|
||||
##
|
||||
|
||||
Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL
|
||||
GitHub: https://github.com/Riktastic/rustapi
|
||||
"#);
|
||||
|
||||
println!("🚀 Starting Rustapi...");
|
||||
|
||||
// Retrieve server IP and port from the environment, default to 0.0.0.0:3000
|
||||
let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0")
|
||||
.parse()
|
||||
.expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1.");
|
||||
let port: u16 = config::get_env_u16("SERVER_PORT", 3000);
|
||||
let socket_addr = SocketAddr::new(ip, port);
|
||||
|
||||
// Create the Axum app instance using the server configuration
|
||||
let app = server::create_server().await;
|
||||
|
||||
// Check if HTTPS is enabled in the environment configuration
|
||||
if config::get_env_bool("SERVER_HTTPS_ENABLED", false) {
|
||||
// If HTTPS is enabled, start the server with secure HTTPS.
|
||||
|
||||
// Bind TCP listener for incoming connections
|
||||
let tcp_listener = TcpListener::bind(socket_addr)
|
||||
.await
|
||||
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
|
||||
|
||||
// Load the TLS configuration for secure HTTPS connections
|
||||
let tls_config = tls::load_tls_config();
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); // Create a TLS acceptor
|
||||
let listener = TlsListener {
|
||||
inner: Arc::new(tcp_listener), // Wrap TCP listener in TlsListener
|
||||
acceptor: acceptor,
|
||||
};
|
||||
|
||||
println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port);
|
||||
|
||||
// Serve the app using the TLS listener (HTTPS)
|
||||
serve(listener, app.into_make_service())
|
||||
.await
|
||||
.expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?");
|
||||
|
||||
} else {
|
||||
// If HTTPS is not enabled, start the server with non-secure HTTP.
|
||||
|
||||
// Bind TCP listener for non-secure HTTP connections
|
||||
let listener = TcpListener::bind(socket_addr)
|
||||
.await
|
||||
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
|
||||
|
||||
println!("🔓 Server started with HTTP at: http://{}:{}", ip, port);
|
||||
|
||||
// Serve the app using the non-secure TCP listener (HTTP)
|
||||
serve(listener, app.into_make_service())
|
||||
.await
|
||||
.expect("❌ Server failed to start without HTTPS.");
|
||||
}
|
||||
}
|
||||
211
src/middlewares/auth.rs
Normal file
211
src/middlewares/auth.rs
Normal file
@ -0,0 +1,211 @@
|
||||
// Standard library imports for working with HTTP, environment variables, and other necessary utilities
|
||||
use axum::{
|
||||
body::Body,
|
||||
response::IntoResponse,
|
||||
extract::{Request, Json}, // Extractor for request and JSON body
|
||||
http::{self, Response, StatusCode}, // HTTP response and status codes
|
||||
middleware::Next, // For adding middleware layers to the request handling pipeline
|
||||
};
|
||||
|
||||
// Importing `State` for sharing application state (such as a database connection) across request handlers
|
||||
use axum::extract::State;
|
||||
|
||||
// Importing necessary libraries for password hashing, JWT handling, and date/time management
|
||||
use std::env; // For accessing environment variables
|
||||
use argon2::{
|
||||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification
|
||||
Argon2,
|
||||
};
|
||||
|
||||
use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.)
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens
|
||||
use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data
|
||||
use serde_json::json; // For constructing JSON data
|
||||
use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously
|
||||
|
||||
// Importing custom database query functions
|
||||
use crate::database::get_users::get_user_by_email;
|
||||
|
||||
// Define the structure for JWT claims to be included in the token payload
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub exp: usize, // Expiration timestamp (in seconds)
|
||||
pub iat: usize, // Issued-at timestamp (in seconds)
|
||||
pub email: String, // User's email
|
||||
}
|
||||
|
||||
// Custom error type for handling authentication errors
|
||||
pub struct AuthError {
|
||||
message: String,
|
||||
status_code: StatusCode, // HTTP status code to be returned with the error
|
||||
}
|
||||
|
||||
// Function to verify a password against a stored hash using the Argon2 algorithm
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
|
||||
let parsed_hash = PasswordHash::new(hash)?; // Parse the hash
|
||||
// Verify the password using Argon2
|
||||
Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
|
||||
}
|
||||
|
||||
// Function to hash a password using Argon2 and a salt retrieved from the environment variables
|
||||
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
|
||||
// Get the salt from environment variables (must be set)
|
||||
let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set");
|
||||
let salt = SaltString::from_b64(&salt).unwrap(); // Convert base64 string to SaltString
|
||||
let argon2 = Argon2::default(); // Create an Argon2 instance
|
||||
// Hash the password with the salt
|
||||
let password_hash = argon2.hash_password(password.as_bytes(), &salt)?.to_string();
|
||||
Ok(password_hash)
|
||||
}
|
||||
|
||||
// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let body = Json(json!( { "error": self.message } )); // Create a JSON response body with the error message
|
||||
|
||||
// Return a response with the appropriate status code and error message
|
||||
(self.status_code, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Function to encode a JWT token for the given email address
|
||||
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
|
||||
let jwt_token: String = "randomstring".to_string(); // Secret key for JWT (should be more secure in production)
|
||||
|
||||
let now = Utc::now(); // Get current time
|
||||
let expire = Duration::hours(24); // Set token expiration to 24 hours
|
||||
let exp: usize = (now + expire).timestamp() as usize; // Expiration timestamp
|
||||
let iat: usize = now.timestamp() as usize; // Issued-at timestamp
|
||||
|
||||
let claim = Claims { iat, exp, email }; // Create JWT claims with timestamps and user email
|
||||
let secret = jwt_token.clone(); // Secret key to sign the token
|
||||
|
||||
// Encode the claims into a JWT token
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claim,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails
|
||||
}
|
||||
|
||||
// Function to decode a JWT token and extract the claims
|
||||
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
|
||||
let secret = "randomstring".to_string(); // Secret key to verify the JWT (should be more secure in production)
|
||||
|
||||
// Decode the JWT token using the secret key and extract the claims
|
||||
decode(
|
||||
&jwt,
|
||||
&DecodingKey::from_secret(secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if decoding fails
|
||||
}
|
||||
|
||||
// Middleware for role-based access control (RBAC)
|
||||
// Ensures that only users with specific roles are authorized to access certain resources
|
||||
pub async fn authorize(
|
||||
mut req: Request<Body>,
|
||||
next: Next,
|
||||
allowed_roles: Vec<i32>, // Accept a vector of allowed roles
|
||||
) -> Result<Response<Body>, AuthError> {
|
||||
// Retrieve the database pool from request extensions (shared application state)
|
||||
let pool = req.extensions().get::<PgPool>().expect("Database pool not found in request extensions");
|
||||
|
||||
// Retrieve the Authorization header from the request
|
||||
let auth_header = req.headers().get(http::header::AUTHORIZATION);
|
||||
|
||||
// Ensure the header exists and is correctly formatted
|
||||
let auth_header = match auth_header {
|
||||
Some(header) => header.to_str().map_err(|_| AuthError {
|
||||
message: "Invalid header format".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
})?,
|
||||
None => return Err(AuthError {
|
||||
message: "Authorization header missing.".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
}),
|
||||
};
|
||||
|
||||
// Extract the token from the Authorization header (Bearer token format)
|
||||
let mut header = auth_header.split_whitespace();
|
||||
let (_, token) = (header.next(), header.next());
|
||||
|
||||
// Decode the JWT token
|
||||
let token_data = match decode_jwt(token.unwrap().to_string()) {
|
||||
Ok(data) => data,
|
||||
Err(_) => return Err(AuthError {
|
||||
message: "Unable to decode token.".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
}),
|
||||
};
|
||||
|
||||
// Fetch the user from the database using the email from the decoded token
|
||||
let current_user = match get_user_by_email(&pool, token_data.claims.email).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => return Err(AuthError {
|
||||
message: "Unauthorized user.".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
}),
|
||||
};
|
||||
|
||||
// Check if the user's role is in the list of allowed roles
|
||||
if !allowed_roles.contains(¤t_user.role_id) {
|
||||
return Err(AuthError {
|
||||
message: "Forbidden: insufficient role.".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
});
|
||||
}
|
||||
|
||||
// Insert the current user into the request extensions for use in subsequent handlers
|
||||
req.extensions_mut().insert(current_user);
|
||||
|
||||
// Proceed to the next middleware or handler
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
// Structure to hold the data from the sign-in request
|
||||
#[derive(Deserialize)]
|
||||
pub struct SignInData {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
// Handler for user sign-in (authentication)
|
||||
pub async fn sign_in(
|
||||
State(pool): State<PgPool>, // Database connection pool injected as state
|
||||
Json(user_data): Json<SignInData>, // Deserialize the JSON body into SignInData
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
|
||||
// 1. Retrieve user from the database using the provided email
|
||||
let user = match get_user_by_email(&pool, user_data.email).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({ "error": "Incorrect credentials." }))
|
||||
)),
|
||||
};
|
||||
|
||||
// 2. Verify the password using the stored hash
|
||||
if !verify_password(&user_data.password, &user.password_hash)
|
||||
.map_err(|_| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Internal server error." }))
|
||||
))?
|
||||
{
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({ "error": "Incorrect credentials." }))
|
||||
));
|
||||
}
|
||||
|
||||
// 3. Generate a JWT token for the authenticated user
|
||||
let token = encode_jwt(user.email)
|
||||
.map_err(|_| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Internal server error." }))
|
||||
))?;
|
||||
|
||||
// 4. Return the JWT token to the client
|
||||
Ok(Json(json!({ "token": token })))
|
||||
}
|
||||
2
src/middlewares/mod.rs
Normal file
2
src/middlewares/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
// Module declarations
|
||||
pub mod auth;
|
||||
5
src/models/mod.rs
Normal file
5
src/models/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
/// Module for to-do related models.
|
||||
pub mod todo;
|
||||
/// Module for user related models.
|
||||
pub mod user;
|
||||
|
||||
18
src/models/role.rs
Normal file
18
src/models/role.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a user role in the system.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct Role {
|
||||
/// ID of the role.
|
||||
pub id: i32,
|
||||
/// Level of the role.
|
||||
pub level: i32,
|
||||
/// System name of the role.
|
||||
pub role: String,
|
||||
/// The name of the role.
|
||||
pub name: String,
|
||||
/// Description of the role
|
||||
pub Description: String,
|
||||
}
|
||||
16
src/models/todo.rs
Normal file
16
src/models/todo.rs
Normal file
@ -0,0 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a to-do item.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct Todo {
|
||||
/// The unique identifier for the to-do item.
|
||||
pub id: i32,
|
||||
/// The task description.
|
||||
pub task: String,
|
||||
/// An optional detailed description of the task.
|
||||
pub description: Option<String>,
|
||||
/// The unique identifier of the user who created the to-do item.
|
||||
pub user_id: i32,
|
||||
}
|
||||
20
src/models/user.rs
Normal file
20
src/models/user.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a user in the system.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct User {
|
||||
/// The unique identifier for the user.
|
||||
pub id: i32,
|
||||
/// The username of the user.
|
||||
pub username: String,
|
||||
/// The email of the user.
|
||||
pub email: String,
|
||||
/// The hashed password for the user.
|
||||
pub password_hash: String,
|
||||
/// The TOTP secret for the user.
|
||||
pub totp_secret: Option<String>,
|
||||
/// Current role of the user..
|
||||
pub role_id: i32,
|
||||
}
|
||||
200
src/routes/get_health.rs
Normal file
200
src/routes/get_health.rs
Normal file
@ -0,0 +1,200 @@
|
||||
use axum::{response::IntoResponse, Json, extract::State};
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
use sysinfo::{System, RefreshKind, Disks};
|
||||
use tokio::{task, join};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
|
||||
// Use Arc and Mutex to allow sharing System between tasks
|
||||
let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything())));
|
||||
|
||||
// Run checks in parallel
|
||||
let (cpu_result, mem_result, disk_result, process_result, db_result, net_result) = join!(
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_cpu_usage(&mut system) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_memory(&mut system) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
move || {
|
||||
check_disk_usage() // Does not need a system reference.
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_processes(&mut system, &["postgres", "Code"]) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
check_database_connection(&database_connection), // Async function
|
||||
task::spawn_blocking(check_network_connection) // Blocking, okay in spawn_blocking
|
||||
);
|
||||
|
||||
let mut status = "healthy";
|
||||
let mut details = json!({});
|
||||
|
||||
// Process CPU result
|
||||
if let Ok(Ok(cpu_details)) = cpu_result {
|
||||
details["cpu_usage"] = json!(cpu_details);
|
||||
if cpu_details["status"] == "low" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["cpu_usage"] = json!({ "status": "error", "error": "Failed to retrieve CPU usage" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Memory result
|
||||
if let Ok(Ok(mem_details)) = mem_result {
|
||||
details["memory"] = json!(mem_details);
|
||||
if mem_details["status"] == "low" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["memory"] = json!({ "status": "error", "error": "Failed to retrieve memory information" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Disk result
|
||||
if let Ok(Ok(disk_details)) = disk_result {
|
||||
details["disk_usage"] = json!(disk_details);
|
||||
if disk_details["status"] == "critical" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["disk_usage"] = json!({ "status": "error", "error": "Failed to retrieve disk usage" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Process result
|
||||
if let Ok(Ok(process_details)) = process_result {
|
||||
details["important_processes"] = json!(process_details);
|
||||
if process_details.iter().any(|p| p["status"] == "not running") {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["important_processes"] = json!({ "status": "error", "error": "Failed to retrieve process information" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Database result
|
||||
if let Ok(db_status) = db_result {
|
||||
details["database"] = json!({ "status": if db_status { "ok" } else { "degraded" } });
|
||||
if !db_status {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["database"] = json!({ "status": "error", "error": "Failed to retrieve database status" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Network result
|
||||
if let Ok(Ok(net_status)) = net_result {
|
||||
details["network"] = json!({ "status": if net_status { "ok" } else { "degraded" } });
|
||||
if !net_status {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["network"] = json!({ "status": "error", "error": "Failed to retrieve network status" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
Json(json!({
|
||||
"status": status,
|
||||
"details": details,
|
||||
}))
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
fn check_cpu_usage(system: &mut System) -> Result<serde_json::Value, ()> {
|
||||
system.refresh_cpu_usage();
|
||||
let usage = system.global_cpu_usage();
|
||||
let available = 100.0 - usage;
|
||||
Ok(json!( {
|
||||
"usage_percentage": format!("{:.2}", usage),
|
||||
"available_percentage": format!("{:.2}", available),
|
||||
"status": if available < 10.0 { "low" } else { "normal" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_memory(system: &mut System) -> Result<serde_json::Value, ()> {
|
||||
system.refresh_memory();
|
||||
let available = system.available_memory() / 1024 / 1024; // Convert to MB
|
||||
Ok(json!( {
|
||||
"available_mb": available,
|
||||
"status": if available < 512 { "low" } else { "normal" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_disk_usage() -> Result<serde_json::Value, ()> {
|
||||
// Create a new Disks object and refresh the disk information
|
||||
let mut disks = Disks::new();
|
||||
disks.refresh(false); // Refresh disk information without performing a full refresh
|
||||
|
||||
// Iterate through the list of disks and check the usage for each one
|
||||
let usage: Vec<_> = disks.list().iter().map(|disk| {
|
||||
let total = disk.total_space() as f64;
|
||||
let available = disk.available_space() as f64;
|
||||
let used_percentage = ((total - available) / total) * 100.0;
|
||||
used_percentage
|
||||
}).collect();
|
||||
|
||||
// Get the maximum usage percentage
|
||||
let max_usage = usage.into_iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Return the result as a JSON object
|
||||
Ok(json!( {
|
||||
"used_percentage": format!("{:.2}", max_usage),
|
||||
"status": if max_usage > 90.0 { "critical" } else { "ok" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_processes(system: &mut System, processes: &[&str]) -> Result<Vec<serde_json::Value>, ()> {
|
||||
system.refresh_processes(sysinfo::ProcessesToUpdate::All, true);
|
||||
|
||||
let process_statuses: Vec<_> = processes.iter().map(|&name| {
|
||||
// Adjust process names based on the platform and check if they are running
|
||||
let adjusted_name = if cfg!(target_os = "windows") {
|
||||
match name {
|
||||
"postgres" => "postgres.exe", // Postgres on Windows
|
||||
"Code" => "Code.exe", // Visual Studio Code on Windows
|
||||
_ => name, // For other platforms, use the name as is
|
||||
}
|
||||
} else {
|
||||
name // For non-Windows platforms, use the name as is
|
||||
};
|
||||
|
||||
// Check if the translated (adjusted) process is running
|
||||
let is_running = system.processes().iter().any(|(_, proc)| proc.name() == adjusted_name);
|
||||
|
||||
// Return a JSON object for each process with its status
|
||||
json!({
|
||||
"name": name,
|
||||
"status": if is_running { "running" } else { "not running" }
|
||||
})
|
||||
}).collect();
|
||||
|
||||
Ok(process_statuses)
|
||||
}
|
||||
|
||||
async fn check_database_connection(pool: &PgPool) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query("SELECT 1").fetch_one(pool).await.map(|_| true).or_else(|_| Ok(false))
|
||||
}
|
||||
|
||||
fn check_network_connection() -> Result<bool, ()> {
|
||||
Ok(std::net::TcpStream::connect("8.8.8.8:53").is_ok())
|
||||
}
|
||||
43
src/routes/get_todos.rs
Normal file
43
src/routes/get_todos.rs
Normal file
@ -0,0 +1,43 @@
|
||||
use axum::extract::{State, Path};
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::todo::*;
|
||||
|
||||
// Get all todos
|
||||
pub async fn get_all_todos(State(pool): State<PgPool>,) -> impl IntoResponse {
|
||||
let todos = sqlx::query_as!(Todo, "SELECT * FROM todos") // Your table name
|
||||
.fetch_all(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match todos {
|
||||
Ok(todos) => Ok(Json(todos)), // Return all todos as JSON
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching todos: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Get a single todo by id
|
||||
pub async fn get_todos_by_id(
|
||||
State(pool): State<PgPool>,
|
||||
Path(id): Path<i32>, // Use Path extractor here
|
||||
) -> impl IntoResponse {
|
||||
let todo = sqlx::query_as!(Todo, "SELECT * FROM todos WHERE id = $1", id)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match todo {
|
||||
Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("Todo with id {} not found", id),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching todo: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
87
src/routes/get_users.rs
Normal file
87
src/routes/get_users.rs
Normal file
@ -0,0 +1,87 @@
|
||||
use axum::extract::{State, Path};
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*; // Import the User struct
|
||||
|
||||
// Get all users
|
||||
pub async fn get_all_users(State(pool): State<PgPool>,) -> impl IntoResponse {
|
||||
let users = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users") // Your table name
|
||||
.fetch_all(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match users {
|
||||
Ok(users) => Ok(Json(users)), // Return all users as JSON
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching users: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by id
|
||||
pub async fn get_users_by_id(
|
||||
State(pool): State<PgPool>,
|
||||
Path(id): Path<i32>, // Use Path extractor here
|
||||
) -> impl IntoResponse {
|
||||
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE id = $1", id)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with id {} not found", id),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by username
|
||||
pub async fn get_user_by_username(
|
||||
State(pool): State<PgPool>,
|
||||
Path(username): Path<String>, // Use Path extractor here for username
|
||||
) -> impl IntoResponse {
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE username = $1", username)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with username {} not found", username),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by email
|
||||
pub async fn get_user_by_email(
|
||||
State(pool): State<PgPool>,
|
||||
Path(email): Path<String>, // Use Path extractor here for email
|
||||
) -> impl IntoResponse {
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE email = $1", email)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with email {} not found", email),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
59
src/routes/mod.rs
Normal file
59
src/routes/mod.rs
Normal file
@ -0,0 +1,59 @@
|
||||
// Module declarations for different route handlers
|
||||
pub mod get_todos;
|
||||
pub mod get_users;
|
||||
pub mod post_todos;
|
||||
pub mod post_users;
|
||||
pub mod get_health;
|
||||
pub mod protected;
|
||||
|
||||
// Re-exporting modules to make their contents available at this level
|
||||
pub use get_todos::*;
|
||||
pub use get_users::*;
|
||||
pub use post_todos::*;
|
||||
pub use post_users::*;
|
||||
pub use get_health::*;
|
||||
pub use protected::*;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::middlewares::auth::{sign_in, authorize};
|
||||
|
||||
/// Function to create and configure all routes
|
||||
pub fn create_routes(database_connection: PgPool) -> Router {
|
||||
// Authentication routes
|
||||
let auth_routes = Router::new()
|
||||
.route("/signin", post(sign_in))
|
||||
.route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| {
|
||||
let allowed_roles = vec![1, 2];
|
||||
authorize(req, next, allowed_roles)
|
||||
})));
|
||||
|
||||
// User-related routes
|
||||
let user_routes = Router::new()
|
||||
.route("/all", get(get_all_users))
|
||||
.route("/{id}", get(get_users_by_id))
|
||||
.route("/", post(post_user));
|
||||
|
||||
// Todo-related routes
|
||||
let todo_routes = Router::new()
|
||||
.route("/all", get(get_all_todos))
|
||||
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
|
||||
let allowed_roles = vec![1, 2];
|
||||
authorize(req, next, allowed_roles)
|
||||
})))
|
||||
.route("/{id}", get(get_todos_by_id));
|
||||
|
||||
// Combine all routes and add middleware
|
||||
Router::new()
|
||||
.merge(auth_routes) // Add authentication routes
|
||||
.nest("/users", user_routes) // Add user routes under /users
|
||||
.nest("/todos", todo_routes) // Add todo routes under /todos
|
||||
.route("/health", get(get_health)) // Add health check route
|
||||
.layer(axum::Extension(database_connection.clone())) // Add database connection to all routes
|
||||
.with_state(database_connection) // Add database connection as state
|
||||
}
|
||||
53
src/routes/post_todos.rs
Normal file
53
src/routes/post_todos.rs
Normal file
@ -0,0 +1,53 @@
|
||||
use axum::{extract::{State, Extension}, Json};
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::todo::*;
|
||||
use crate::models::user::*;
|
||||
use serde::Deserialize;
|
||||
use axum::http::StatusCode;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct TodoBody {
|
||||
pub task: String,
|
||||
pub description: Option<String>,
|
||||
pub user_id: i32,
|
||||
}
|
||||
|
||||
// Add a new todo
|
||||
pub async fn post_todo(
|
||||
State(pool): State<PgPool>,
|
||||
Extension(user): Extension<User>, // Extract current user from the request extensions
|
||||
Json(todo): Json<TodoBody>
|
||||
) -> impl IntoResponse {
|
||||
// Ensure the user_id from the request matches the current user's id
|
||||
if todo.user_id != user.id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": "User is not authorized to create a todo for another user" }))
|
||||
));
|
||||
}
|
||||
|
||||
// Insert the todo into the database
|
||||
let row = sqlx::query!(
|
||||
"INSERT INTO todos (task, description, user_id) VALUES ($1, $2, $3) RETURNING id, task, description, user_id",
|
||||
todo.task,
|
||||
todo.description,
|
||||
todo.user_id
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => Ok(Json(Todo {
|
||||
id: row.id,
|
||||
task: row.task,
|
||||
description: row.description,
|
||||
user_id: row.user_id,
|
||||
})),
|
||||
Err(err) => Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": format!("Error: {}", err) }))
|
||||
)),
|
||||
}
|
||||
}
|
||||
44
src/routes/post_users.rs
Normal file
44
src/routes/post_users.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use axum::extract::State;
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserBody {
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub totp_secret: String,
|
||||
pub role_id: i32,
|
||||
}
|
||||
|
||||
// Add a new user
|
||||
pub async fn post_user(State(pool): State<PgPool>, Json(user): Json<UserBody>, ) -> impl IntoResponse {
|
||||
let row = sqlx::query!(
|
||||
"INSERT INTO users (username, email, password_hash, totp_secret, role_id) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, totp_secret, role_id",
|
||||
user.username,
|
||||
user.email,
|
||||
user.password_hash,
|
||||
user.totp_secret,
|
||||
user.role_id
|
||||
)
|
||||
.fetch_one(&pool) // Use `&pool` to borrow the connection pool
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => Ok(Json(User {
|
||||
id: row.id,
|
||||
username: row.username,
|
||||
email: row.email,
|
||||
password_hash: row.password_hash,
|
||||
totp_secret: row.totp_secret,
|
||||
role_id: row.role_id,
|
||||
})),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
18
src/routes/protected.rs
Normal file
18
src/routes/protected.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use axum::{Extension, Json, response::IntoResponse};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::models::user::User;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UserResponse {
|
||||
id: i32,
|
||||
username: String,
|
||||
email: String
|
||||
}
|
||||
|
||||
pub async fn protected(Extension(user): Extension<User>) -> impl IntoResponse {
|
||||
Json(UserResponse {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
email: user.email
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user