mirror of
https://github.com/kristoferssolo/Axium.git
synced 2026-02-04 13:32:02 +00:00
first commit
This commit is contained in:
55
src/core/config.rs
Normal file
55
src/core/config.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
// Import the standard library's environment module
|
||||
use std::env;
|
||||
|
||||
/// Retrieves the value of an environment variable as a `String`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The value of the environment variable if it exists.
|
||||
/// * Panics if the environment variable is missing.
|
||||
pub fn get_env(key: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| panic!("Missing required environment variable: {}", key))
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `String`, with a default value if not found.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The value of the environment variable if it exists.
|
||||
/// * The `default` value if the environment variable is missing.
|
||||
pub fn get_env_with_default(key: &str, default: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| default.to_string())
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `bool`, with a default value if not found.
|
||||
///
|
||||
/// The environment variable is considered `true` if its value is "true" (case-insensitive), otherwise `false`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found.
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if the environment variable is "true" (case-insensitive).
|
||||
/// * `false` otherwise, or if the variable is missing, the `default` value is returned.
|
||||
pub fn get_env_bool(key: &str, default: bool) -> bool {
|
||||
env::var(key).map(|v| v.to_lowercase() == "true").unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Retrieves the value of an environment variable as a `u16`, with a default value if not found.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `key` - The name of the environment variable to retrieve.
|
||||
/// * `default` - The value to return if the environment variable is not found or cannot be parsed.
|
||||
///
|
||||
/// # Returns
|
||||
/// * The parsed `u16` value of the environment variable if it exists and is valid.
|
||||
/// * The `default` value if the variable is missing or invalid.
|
||||
pub fn get_env_u16(key: &str, default: u16) -> u16 {
|
||||
env::var(key).unwrap_or_else(|_| default.to_string()).parse().unwrap_or(default)
|
||||
}
|
||||
4
src/core/mod.rs
Normal file
4
src/core/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod config;
|
||||
pub mod server;
|
||||
pub mod tls;
|
||||
|
||||
38
src/core/server.rs
Normal file
38
src/core/server.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
// Axum for web server and routing
|
||||
use axum::Router;
|
||||
|
||||
// Middleware layers from tower_http
|
||||
use tower_http::compression::{CompressionLayer, CompressionLevel}; // For HTTP response compression
|
||||
use tower_http::trace::TraceLayer; // For HTTP request/response tracing
|
||||
|
||||
// Local crate imports for database connection and configuration
|
||||
use crate::database::connect::connect_to_database; // Function to connect to the database
|
||||
use crate::config; // Environment configuration helper
|
||||
|
||||
/// Function to create and configure the Axum server.
|
||||
pub async fn create_server() -> Router {
|
||||
// Establish a connection to the database
|
||||
let db = connect_to_database().await.expect("Failed to connect to database.");
|
||||
|
||||
// Initialize the routes for the server
|
||||
let mut app = crate::routes::create_routes(db);
|
||||
|
||||
// Enable tracing middleware if configured
|
||||
if config::get_env_bool("SERVER_TRACE_ENABLED", true) {
|
||||
app = app.layer(TraceLayer::new_for_http());
|
||||
println!("✔️ Trace hads been enabled.");
|
||||
}
|
||||
|
||||
// Enable compression middleware if configured
|
||||
if config::get_env_bool("SERVER_COMPRESSION_ENABLED", true) {
|
||||
// Parse compression level from environment or default to level 6
|
||||
let level = config::get_env("SERVER_COMPRESSION_LEVEL").parse().unwrap_or(6);
|
||||
// Apply compression layer with Brotli (br) enabled and the specified compression level
|
||||
app = app.layer(CompressionLayer::new().br(true).quality(CompressionLevel::Precise(level)));
|
||||
println!("✔️ Brotli compression enabled with compression quality level {}.", level);
|
||||
|
||||
}
|
||||
|
||||
// Return the fully configured application
|
||||
app
|
||||
}
|
||||
145
src/core/tls.rs
Normal file
145
src/core/tls.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
// Standard library imports
|
||||
use std::{
|
||||
future::Future,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
io::BufReader,
|
||||
fs::File,
|
||||
iter,
|
||||
};
|
||||
|
||||
// External crate imports
|
||||
use axum::serve::Listener;
|
||||
use rustls::{self, server::ServerConfig, pki_types::{PrivateKeyDer, CertificateDer}};
|
||||
use rustls_pemfile::{Item, read_one, certs};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing;
|
||||
|
||||
// Local crate imports
|
||||
use crate::config; // Import env config helper
|
||||
|
||||
// Function to load TLS configuration from files
|
||||
pub fn load_tls_config() -> ServerConfig {
|
||||
// Get certificate and key file paths from the environment
|
||||
let cert_path = config::get_env("SERVER_HTTPS_CERT_FILE_PATH");
|
||||
let key_path = config::get_env("SERVER_HTTPS_KEY_FILE_PATH");
|
||||
|
||||
// Open the certificate and key files
|
||||
let cert_file = File::open(cert_path).expect("❌ Failed to open certificate file.");
|
||||
let key_file = File::open(key_path).expect("❌ Failed to open private key file.");
|
||||
|
||||
// Read the certificate chain and private key from the files
|
||||
let mut cert_reader = BufReader::new(cert_file);
|
||||
let mut key_reader = BufReader::new(key_file);
|
||||
|
||||
// Read and parse the certificate chain
|
||||
let cert_chain: Vec<CertificateDer> = certs(&mut cert_reader)
|
||||
.map(|cert| cert.expect("❌ Failed to read certificate."))
|
||||
.map(CertificateDer::from)
|
||||
.collect();
|
||||
|
||||
// Ensure certificates are found
|
||||
if cert_chain.is_empty() {
|
||||
panic!("❌ No valid certificates found.");
|
||||
}
|
||||
|
||||
// Read the private key from the file
|
||||
let key = iter::from_fn(|| read_one(&mut key_reader).transpose())
|
||||
.find_map(|item| match item.unwrap() {
|
||||
Item::Pkcs1Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
Item::Pkcs8Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
Item::Sec1Key(key) => Some(PrivateKeyDer::from(key)),
|
||||
_ => None,
|
||||
})
|
||||
.expect("Failed to read a valid private key.");
|
||||
|
||||
// Build and return the TLS server configuration
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth() // No client authentication
|
||||
.with_single_cert(cert_chain, key) // Use the provided cert and key
|
||||
.expect("Failed to create TLS configuration")
|
||||
}
|
||||
|
||||
// Custom listener that implements axum::serve::Listener
|
||||
#[derive(Clone)]
|
||||
pub struct TlsListener {
|
||||
pub inner: Arc<tokio::net::TcpListener>, // Inner TCP listener
|
||||
pub acceptor: tokio_rustls::TlsAcceptor, // TLS acceptor for handling TLS handshakes
|
||||
}
|
||||
|
||||
impl Listener for TlsListener {
|
||||
type Io = TlsStreamWrapper; // Type of I/O stream
|
||||
type Addr = SocketAddr; // Type of address (Socket address)
|
||||
|
||||
// Method to accept incoming connections and establish a TLS handshake
|
||||
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send {
|
||||
let acceptor = self.acceptor.clone(); // Clone the acceptor for async use
|
||||
|
||||
async move {
|
||||
loop {
|
||||
// Accept a TCP connection
|
||||
let (stream, addr) = match self.inner.accept().await {
|
||||
Ok((stream, addr)) => (stream, addr),
|
||||
Err(e) => {
|
||||
tracing::error!("❌ Error accepting TCP connection: {}", e);
|
||||
continue; // Retry on error
|
||||
}
|
||||
};
|
||||
|
||||
// Perform TLS handshake
|
||||
match acceptor.accept(stream).await {
|
||||
Ok(tls_stream) => {
|
||||
tracing::info!("Successful TLS handshake with {}", addr);
|
||||
return (TlsStreamWrapper(tls_stream), addr); // Return TLS stream and address
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("TLS handshake failed: {} (Client may not trust certificate)", e);
|
||||
continue; // Retry on error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Method to retrieve the local address of the listener
|
||||
fn local_addr(&self) -> std::io::Result<Self::Addr> {
|
||||
self.inner.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapper for a TLS stream, implementing AsyncRead and AsyncWrite
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStreamWrapper(tokio_rustls::server::TlsStream<tokio::net::TcpStream>);
|
||||
|
||||
impl AsyncRead for TlsStreamWrapper {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf) // Delegate read operation to the underlying TLS stream
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TlsStreamWrapper {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf) // Delegate write operation to the underlying TLS stream
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx) // Flush operation for the TLS stream
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_shutdown(cx) // Shutdown operation for the TLS stream
|
||||
}
|
||||
}
|
||||
|
||||
// Allow the TLS stream wrapper to be used in non-blocking contexts (needed for async operations)
|
||||
impl Unpin for TlsStreamWrapper {}
|
||||
51
src/database/connect.rs
Normal file
51
src/database/connect.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use dotenvy::dotenv;
|
||||
use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions};
|
||||
use std::fs;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
/// Connects to the database using the DATABASE_URL environment variable.
|
||||
pub async fn connect_to_database() -> Result<PgPool, sqlx::Error> {
|
||||
dotenv().ok();
|
||||
let database_url = &env::var("DATABASE_URL").expect("❌ 'DATABASE_URL' environment variable not fount.");
|
||||
|
||||
// Read max and min connection values from environment variables, with defaults
|
||||
let max_connections: u32 = env::var("DATABASE_MAX_CONNECTIONS")
|
||||
.unwrap_or_else(|_| "10".to_string()) // Default to 10
|
||||
.parse()
|
||||
.expect("❌ Invalid 'DATABASE_MAX_CONNECTIONS' value; must be a number.");
|
||||
|
||||
let min_connections: u32 = env::var("DATABASE_MIN_CONNECTIONS")
|
||||
.unwrap_or_else(|_| "2".to_string()) // Default to 2
|
||||
.parse()
|
||||
.expect("❌ Invalid 'DATABASE_MIN_CONNECTIONS' value; must be a number.");
|
||||
|
||||
// Create and configure the connection pool
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.min_connections(min_connections)
|
||||
.connect(&database_url)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Run database migrations
|
||||
pub async fn run_database_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
|
||||
// Define the path to the migrations folder
|
||||
let migrations_path = Path::new("./migrations");
|
||||
|
||||
// Check if the migrations folder exists, and if not, create it
|
||||
if !migrations_path.exists() {
|
||||
fs::create_dir_all(migrations_path).expect("❌ Failed to create migrations directory. Make sure you have the necessary permissions.");
|
||||
println!("✔️ Created migrations directory: {:?}", migrations_path);
|
||||
}
|
||||
|
||||
// Create a migrator instance that looks for migrations in the `./migrations` folder
|
||||
let migrator = Migrator::new(migrations_path).await?;
|
||||
|
||||
// Run all pending migrations
|
||||
migrator.run(pool).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
25
src/database/get_users.rs
Normal file
25
src/database/get_users.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*; // Import the User struct
|
||||
|
||||
// Get all users
|
||||
pub async fn get_user_by_email(pool: &PgPool, email: String) -> Result<User, String> {
|
||||
// Use a string literal directly in the macro
|
||||
let user = sqlx::query_as!(
|
||||
User, // Struct type to map the query result
|
||||
r#"
|
||||
SELECT id, username, email, password_hash, totp_secret, role_id
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
"#,
|
||||
email // Bind the `email` parameter
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| format!("Database error: {}", e))?; // Handle database errors
|
||||
|
||||
// Handle optional result
|
||||
match user {
|
||||
Some(user) => Ok(user),
|
||||
None => Err(format!("User with email '{}' not found.", email)),
|
||||
}
|
||||
}
|
||||
3
src/database/mod.rs
Normal file
3
src/database/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// Module declarations
|
||||
pub mod connect;
|
||||
pub mod get_users;
|
||||
6
src/handlers/mod.rs
Normal file
6
src/handlers/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
// Module declarations
|
||||
// pub mod auth;
|
||||
|
||||
|
||||
// Re-exporting modules
|
||||
// pub use auth::*;
|
||||
91
src/main.rs
Normal file
91
src/main.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
#[allow(dead_code)]
|
||||
|
||||
// Core modules for the configuration, TLS setup, and server creation
|
||||
mod core;
|
||||
use core::{config, tls, server};
|
||||
use core::tls::TlsListener;
|
||||
|
||||
// Other modules for database, routes, models, and middlewares
|
||||
mod database;
|
||||
mod routes;
|
||||
mod models;
|
||||
mod middlewares;
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use axum::serve;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
dotenvy::dotenv().ok(); // Load environment variables from a .env file
|
||||
tracing_subscriber::fmt::init(); // Initialize the logging system
|
||||
|
||||
// Print a cool startup message with ASCII art and emojis
|
||||
println!("{}", r#"
|
||||
##### ## ##
|
||||
## ## ##
|
||||
## ## ## ## ##### ####### ##### #### ####
|
||||
##### ## ## ## ## ## ## ## ## ##
|
||||
#### ## ## ## ## ## ## ## ## ##
|
||||
## ## ## ### ## ## ## ### ## ## ##
|
||||
## ## ### ## ##### ### ### ## ##### ######
|
||||
##
|
||||
|
||||
Rustapi - An example API built with Rust, Axum, SQLx, and PostgreSQL
|
||||
GitHub: https://github.com/Riktastic/rustapi
|
||||
"#);
|
||||
|
||||
println!("🚀 Starting Rustapi...");
|
||||
|
||||
// Retrieve server IP and port from the environment, default to 0.0.0.0:3000
|
||||
let ip: IpAddr = config::get_env_with_default("SERVER_IP", "0.0.0.0")
|
||||
.parse()
|
||||
.expect("❌ Invalid IP address format. Please provide a valid IPv4 address. For example 0.0.0.0 or 127.0.0.1.");
|
||||
let port: u16 = config::get_env_u16("SERVER_PORT", 3000);
|
||||
let socket_addr = SocketAddr::new(ip, port);
|
||||
|
||||
// Create the Axum app instance using the server configuration
|
||||
let app = server::create_server().await;
|
||||
|
||||
// Check if HTTPS is enabled in the environment configuration
|
||||
if config::get_env_bool("SERVER_HTTPS_ENABLED", false) {
|
||||
// If HTTPS is enabled, start the server with secure HTTPS.
|
||||
|
||||
// Bind TCP listener for incoming connections
|
||||
let tcp_listener = TcpListener::bind(socket_addr)
|
||||
.await
|
||||
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
|
||||
|
||||
// Load the TLS configuration for secure HTTPS connections
|
||||
let tls_config = tls::load_tls_config();
|
||||
let acceptor = TlsAcceptor::from(Arc::new(tls_config)); // Create a TLS acceptor
|
||||
let listener = TlsListener {
|
||||
inner: Arc::new(tcp_listener), // Wrap TCP listener in TlsListener
|
||||
acceptor: acceptor,
|
||||
};
|
||||
|
||||
println!("🔒 Server started with HTTPS at: https://{}:{}", ip, port);
|
||||
|
||||
// Serve the app using the TLS listener (HTTPS)
|
||||
serve(listener, app.into_make_service())
|
||||
.await
|
||||
.expect("❌ Server failed to start with HTTPS. Did you provide valid certificate and key files?");
|
||||
|
||||
} else {
|
||||
// If HTTPS is not enabled, start the server with non-secure HTTP.
|
||||
|
||||
// Bind TCP listener for non-secure HTTP connections
|
||||
let listener = TcpListener::bind(socket_addr)
|
||||
.await
|
||||
.expect("❌ Failed to bind to socket. Port might allready be in use."); // Explicit error handling
|
||||
|
||||
println!("🔓 Server started with HTTP at: http://{}:{}", ip, port);
|
||||
|
||||
// Serve the app using the non-secure TCP listener (HTTP)
|
||||
serve(listener, app.into_make_service())
|
||||
.await
|
||||
.expect("❌ Server failed to start without HTTPS.");
|
||||
}
|
||||
}
|
||||
211
src/middlewares/auth.rs
Normal file
211
src/middlewares/auth.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
// Standard library imports for working with HTTP, environment variables, and other necessary utilities
|
||||
use axum::{
|
||||
body::Body,
|
||||
response::IntoResponse,
|
||||
extract::{Request, Json}, // Extractor for request and JSON body
|
||||
http::{self, Response, StatusCode}, // HTTP response and status codes
|
||||
middleware::Next, // For adding middleware layers to the request handling pipeline
|
||||
};
|
||||
|
||||
// Importing `State` for sharing application state (such as a database connection) across request handlers
|
||||
use axum::extract::State;
|
||||
|
||||
// Importing necessary libraries for password hashing, JWT handling, and date/time management
|
||||
use std::env; // For accessing environment variables
|
||||
use argon2::{
|
||||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, // For password hashing and verification
|
||||
Argon2,
|
||||
};
|
||||
|
||||
use chrono::{Duration, Utc}; // For working with time (JWT expiration, etc.)
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation}; // For encoding and decoding JWT tokens
|
||||
use serde::{Deserialize, Serialize}; // For serializing and deserializing JSON data
|
||||
use serde_json::json; // For constructing JSON data
|
||||
use sqlx::PgPool; // For interacting with PostgreSQL databases asynchronously
|
||||
|
||||
// Importing custom database query functions
|
||||
use crate::database::get_users::get_user_by_email;
|
||||
|
||||
// Define the structure for JWT claims to be included in the token payload
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub exp: usize, // Expiration timestamp (in seconds)
|
||||
pub iat: usize, // Issued-at timestamp (in seconds)
|
||||
pub email: String, // User's email
|
||||
}
|
||||
|
||||
// Custom error type for handling authentication errors
|
||||
pub struct AuthError {
|
||||
message: String,
|
||||
status_code: StatusCode, // HTTP status code to be returned with the error
|
||||
}
|
||||
|
||||
// Function to verify a password against a stored hash using the Argon2 algorithm
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
|
||||
let parsed_hash = PasswordHash::new(hash)?; // Parse the hash
|
||||
// Verify the password using Argon2
|
||||
Ok(Argon2::default().verify_password(password.as_bytes(), &parsed_hash).is_ok())
|
||||
}
|
||||
|
||||
// Function to hash a password using Argon2 and a salt retrieved from the environment variables
|
||||
pub fn hash_password(password: &str) -> Result<String, argon2::password_hash::Error> {
|
||||
// Get the salt from environment variables (must be set)
|
||||
let salt = env::var("AUTHENTICATION_ARGON2_SALT").expect("AUTHENTICATION_ARGON2_SALT must be set");
|
||||
let salt = SaltString::from_b64(&salt).unwrap(); // Convert base64 string to SaltString
|
||||
let argon2 = Argon2::default(); // Create an Argon2 instance
|
||||
// Hash the password with the salt
|
||||
let password_hash = argon2.hash_password(password.as_bytes(), &salt)?.to_string();
|
||||
Ok(password_hash)
|
||||
}
|
||||
|
||||
// Implement the IntoResponse trait for AuthError to allow it to be returned as a response from the handler
|
||||
impl IntoResponse for AuthError {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let body = Json(json!( { "error": self.message } )); // Create a JSON response body with the error message
|
||||
|
||||
// Return a response with the appropriate status code and error message
|
||||
(self.status_code, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Function to encode a JWT token for the given email address
|
||||
pub fn encode_jwt(email: String) -> Result<String, StatusCode> {
|
||||
let jwt_token: String = "randomstring".to_string(); // Secret key for JWT (should be more secure in production)
|
||||
|
||||
let now = Utc::now(); // Get current time
|
||||
let expire = Duration::hours(24); // Set token expiration to 24 hours
|
||||
let exp: usize = (now + expire).timestamp() as usize; // Expiration timestamp
|
||||
let iat: usize = now.timestamp() as usize; // Issued-at timestamp
|
||||
|
||||
let claim = Claims { iat, exp, email }; // Create JWT claims with timestamps and user email
|
||||
let secret = jwt_token.clone(); // Secret key to sign the token
|
||||
|
||||
// Encode the claims into a JWT token
|
||||
encode(
|
||||
&Header::default(),
|
||||
&claim,
|
||||
&EncodingKey::from_secret(secret.as_ref()),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if encoding fails
|
||||
}
|
||||
|
||||
// Function to decode a JWT token and extract the claims
|
||||
pub fn decode_jwt(jwt: String) -> Result<TokenData<Claims>, StatusCode> {
|
||||
let secret = "randomstring".to_string(); // Secret key to verify the JWT (should be more secure in production)
|
||||
|
||||
// Decode the JWT token using the secret key and extract the claims
|
||||
decode(
|
||||
&jwt,
|
||||
&DecodingKey::from_secret(secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) // Return error if decoding fails
|
||||
}
|
||||
|
||||
// Middleware for role-based access control (RBAC)
|
||||
// Ensures that only users with specific roles are authorized to access certain resources
|
||||
pub async fn authorize(
|
||||
mut req: Request<Body>,
|
||||
next: Next,
|
||||
allowed_roles: Vec<i32>, // Accept a vector of allowed roles
|
||||
) -> Result<Response<Body>, AuthError> {
|
||||
// Retrieve the database pool from request extensions (shared application state)
|
||||
let pool = req.extensions().get::<PgPool>().expect("Database pool not found in request extensions");
|
||||
|
||||
// Retrieve the Authorization header from the request
|
||||
let auth_header = req.headers().get(http::header::AUTHORIZATION);
|
||||
|
||||
// Ensure the header exists and is correctly formatted
|
||||
let auth_header = match auth_header {
|
||||
Some(header) => header.to_str().map_err(|_| AuthError {
|
||||
message: "Invalid header format".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
})?,
|
||||
None => return Err(AuthError {
|
||||
message: "Authorization header missing.".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
}),
|
||||
};
|
||||
|
||||
// Extract the token from the Authorization header (Bearer token format)
|
||||
let mut header = auth_header.split_whitespace();
|
||||
let (_, token) = (header.next(), header.next());
|
||||
|
||||
// Decode the JWT token
|
||||
let token_data = match decode_jwt(token.unwrap().to_string()) {
|
||||
Ok(data) => data,
|
||||
Err(_) => return Err(AuthError {
|
||||
message: "Unable to decode token.".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
}),
|
||||
};
|
||||
|
||||
// Fetch the user from the database using the email from the decoded token
|
||||
let current_user = match get_user_by_email(&pool, token_data.claims.email).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => return Err(AuthError {
|
||||
message: "Unauthorized user.".to_string(),
|
||||
status_code: StatusCode::UNAUTHORIZED,
|
||||
}),
|
||||
};
|
||||
|
||||
// Check if the user's role is in the list of allowed roles
|
||||
if !allowed_roles.contains(¤t_user.role_id) {
|
||||
return Err(AuthError {
|
||||
message: "Forbidden: insufficient role.".to_string(),
|
||||
status_code: StatusCode::FORBIDDEN,
|
||||
});
|
||||
}
|
||||
|
||||
// Insert the current user into the request extensions for use in subsequent handlers
|
||||
req.extensions_mut().insert(current_user);
|
||||
|
||||
// Proceed to the next middleware or handler
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
// Structure to hold the data from the sign-in request
|
||||
#[derive(Deserialize)]
|
||||
pub struct SignInData {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
// Handler for user sign-in (authentication)
|
||||
pub async fn sign_in(
|
||||
State(pool): State<PgPool>, // Database connection pool injected as state
|
||||
Json(user_data): Json<SignInData>, // Deserialize the JSON body into SignInData
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
|
||||
// 1. Retrieve user from the database using the provided email
|
||||
let user = match get_user_by_email(&pool, user_data.email).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({ "error": "Incorrect credentials." }))
|
||||
)),
|
||||
};
|
||||
|
||||
// 2. Verify the password using the stored hash
|
||||
if !verify_password(&user_data.password, &user.password_hash)
|
||||
.map_err(|_| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Internal server error." }))
|
||||
))?
|
||||
{
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(json!({ "error": "Incorrect credentials." }))
|
||||
));
|
||||
}
|
||||
|
||||
// 3. Generate a JWT token for the authenticated user
|
||||
let token = encode_jwt(user.email)
|
||||
.map_err(|_| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Internal server error." }))
|
||||
))?;
|
||||
|
||||
// 4. Return the JWT token to the client
|
||||
Ok(Json(json!({ "token": token })))
|
||||
}
|
||||
2
src/middlewares/mod.rs
Normal file
2
src/middlewares/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
// Module declarations
|
||||
pub mod auth;
|
||||
5
src/models/mod.rs
Normal file
5
src/models/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// Module for to-do related models.
|
||||
pub mod todo;
|
||||
/// Module for user related models.
|
||||
pub mod user;
|
||||
|
||||
18
src/models/role.rs
Normal file
18
src/models/role.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a user role in the system.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct Role {
|
||||
/// ID of the role.
|
||||
pub id: i32,
|
||||
/// Level of the role.
|
||||
pub level: i32,
|
||||
/// System name of the role.
|
||||
pub role: String,
|
||||
/// The name of the role.
|
||||
pub name: String,
|
||||
/// Description of the role
|
||||
pub Description: String,
|
||||
}
|
||||
16
src/models/todo.rs
Normal file
16
src/models/todo.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a to-do item.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct Todo {
|
||||
/// The unique identifier for the to-do item.
|
||||
pub id: i32,
|
||||
/// The task description.
|
||||
pub task: String,
|
||||
/// An optional detailed description of the task.
|
||||
pub description: Option<String>,
|
||||
/// The unique identifier of the user who created the to-do item.
|
||||
pub user_id: i32,
|
||||
}
|
||||
20
src/models/user.rs
Normal file
20
src/models/user.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// Represents a user in the system.
|
||||
#[derive(Deserialize, Debug, Serialize, FromRow, Clone)]
|
||||
#[sqlx(rename_all = "snake_case")] // Ensures that field names are mapped to snake_case in SQL
|
||||
pub struct User {
|
||||
/// The unique identifier for the user.
|
||||
pub id: i32,
|
||||
/// The username of the user.
|
||||
pub username: String,
|
||||
/// The email of the user.
|
||||
pub email: String,
|
||||
/// The hashed password for the user.
|
||||
pub password_hash: String,
|
||||
/// The TOTP secret for the user.
|
||||
pub totp_secret: Option<String>,
|
||||
/// Current role of the user..
|
||||
pub role_id: i32,
|
||||
}
|
||||
200
src/routes/get_health.rs
Normal file
200
src/routes/get_health.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
use axum::{response::IntoResponse, Json, extract::State};
|
||||
use serde_json::json;
|
||||
use sqlx::PgPool;
|
||||
use sysinfo::{System, RefreshKind, Disks};
|
||||
use tokio::{task, join};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub async fn get_health(State(database_connection): State<PgPool>) -> impl IntoResponse {
|
||||
// Use Arc and Mutex to allow sharing System between tasks
|
||||
let system = Arc::new(Mutex::new(System::new_with_specifics(RefreshKind::everything())));
|
||||
|
||||
// Run checks in parallel
|
||||
let (cpu_result, mem_result, disk_result, process_result, db_result, net_result) = join!(
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_cpu_usage(&mut system) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_memory(&mut system) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
move || {
|
||||
check_disk_usage() // Does not need a system reference.
|
||||
}
|
||||
}),
|
||||
task::spawn_blocking({
|
||||
let system = Arc::clone(&system);
|
||||
move || {
|
||||
let mut system = system.lock().unwrap(); // Lock the mutex and get a mutable reference
|
||||
check_processes(&mut system, &["postgres", "Code"]) // Pass the mutable reference
|
||||
}
|
||||
}),
|
||||
check_database_connection(&database_connection), // Async function
|
||||
task::spawn_blocking(check_network_connection) // Blocking, okay in spawn_blocking
|
||||
);
|
||||
|
||||
let mut status = "healthy";
|
||||
let mut details = json!({});
|
||||
|
||||
// Process CPU result
|
||||
if let Ok(Ok(cpu_details)) = cpu_result {
|
||||
details["cpu_usage"] = json!(cpu_details);
|
||||
if cpu_details["status"] == "low" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["cpu_usage"] = json!({ "status": "error", "error": "Failed to retrieve CPU usage" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Memory result
|
||||
if let Ok(Ok(mem_details)) = mem_result {
|
||||
details["memory"] = json!(mem_details);
|
||||
if mem_details["status"] == "low" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["memory"] = json!({ "status": "error", "error": "Failed to retrieve memory information" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Disk result
|
||||
if let Ok(Ok(disk_details)) = disk_result {
|
||||
details["disk_usage"] = json!(disk_details);
|
||||
if disk_details["status"] == "critical" {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["disk_usage"] = json!({ "status": "error", "error": "Failed to retrieve disk usage" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Process result
|
||||
if let Ok(Ok(process_details)) = process_result {
|
||||
details["important_processes"] = json!(process_details);
|
||||
if process_details.iter().any(|p| p["status"] == "not running") {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["important_processes"] = json!({ "status": "error", "error": "Failed to retrieve process information" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Database result
|
||||
if let Ok(db_status) = db_result {
|
||||
details["database"] = json!({ "status": if db_status { "ok" } else { "degraded" } });
|
||||
if !db_status {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["database"] = json!({ "status": "error", "error": "Failed to retrieve database status" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
// Process Network result
|
||||
if let Ok(Ok(net_status)) = net_result {
|
||||
details["network"] = json!({ "status": if net_status { "ok" } else { "degraded" } });
|
||||
if !net_status {
|
||||
status = "degraded";
|
||||
}
|
||||
} else {
|
||||
details["network"] = json!({ "status": "error", "error": "Failed to retrieve network status" });
|
||||
status = "degraded";
|
||||
}
|
||||
|
||||
Json(json!({
|
||||
"status": status,
|
||||
"details": details,
|
||||
}))
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
fn check_cpu_usage(system: &mut System) -> Result<serde_json::Value, ()> {
|
||||
system.refresh_cpu_usage();
|
||||
let usage = system.global_cpu_usage();
|
||||
let available = 100.0 - usage;
|
||||
Ok(json!( {
|
||||
"usage_percentage": format!("{:.2}", usage),
|
||||
"available_percentage": format!("{:.2}", available),
|
||||
"status": if available < 10.0 { "low" } else { "normal" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_memory(system: &mut System) -> Result<serde_json::Value, ()> {
|
||||
system.refresh_memory();
|
||||
let available = system.available_memory() / 1024 / 1024; // Convert to MB
|
||||
Ok(json!( {
|
||||
"available_mb": available,
|
||||
"status": if available < 512 { "low" } else { "normal" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_disk_usage() -> Result<serde_json::Value, ()> {
|
||||
// Create a new Disks object and refresh the disk information
|
||||
let mut disks = Disks::new();
|
||||
disks.refresh(false); // Refresh disk information without performing a full refresh
|
||||
|
||||
// Iterate through the list of disks and check the usage for each one
|
||||
let usage: Vec<_> = disks.list().iter().map(|disk| {
|
||||
let total = disk.total_space() as f64;
|
||||
let available = disk.available_space() as f64;
|
||||
let used_percentage = ((total - available) / total) * 100.0;
|
||||
used_percentage
|
||||
}).collect();
|
||||
|
||||
// Get the maximum usage percentage
|
||||
let max_usage = usage.into_iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Return the result as a JSON object
|
||||
Ok(json!( {
|
||||
"used_percentage": format!("{:.2}", max_usage),
|
||||
"status": if max_usage > 90.0 { "critical" } else { "ok" },
|
||||
}))
|
||||
}
|
||||
|
||||
fn check_processes(system: &mut System, processes: &[&str]) -> Result<Vec<serde_json::Value>, ()> {
|
||||
system.refresh_processes(sysinfo::ProcessesToUpdate::All, true);
|
||||
|
||||
let process_statuses: Vec<_> = processes.iter().map(|&name| {
|
||||
// Adjust process names based on the platform and check if they are running
|
||||
let adjusted_name = if cfg!(target_os = "windows") {
|
||||
match name {
|
||||
"postgres" => "postgres.exe", // Postgres on Windows
|
||||
"Code" => "Code.exe", // Visual Studio Code on Windows
|
||||
_ => name, // For other platforms, use the name as is
|
||||
}
|
||||
} else {
|
||||
name // For non-Windows platforms, use the name as is
|
||||
};
|
||||
|
||||
// Check if the translated (adjusted) process is running
|
||||
let is_running = system.processes().iter().any(|(_, proc)| proc.name() == adjusted_name);
|
||||
|
||||
// Return a JSON object for each process with its status
|
||||
json!({
|
||||
"name": name,
|
||||
"status": if is_running { "running" } else { "not running" }
|
||||
})
|
||||
}).collect();
|
||||
|
||||
Ok(process_statuses)
|
||||
}
|
||||
|
||||
async fn check_database_connection(pool: &PgPool) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query("SELECT 1").fetch_one(pool).await.map(|_| true).or_else(|_| Ok(false))
|
||||
}
|
||||
|
||||
fn check_network_connection() -> Result<bool, ()> {
|
||||
Ok(std::net::TcpStream::connect("8.8.8.8:53").is_ok())
|
||||
}
|
||||
43
src/routes/get_todos.rs
Normal file
43
src/routes/get_todos.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use axum::extract::{State, Path};
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::todo::*;
|
||||
|
||||
// Get all todos
|
||||
pub async fn get_all_todos(State(pool): State<PgPool>,) -> impl IntoResponse {
|
||||
let todos = sqlx::query_as!(Todo, "SELECT * FROM todos") // Your table name
|
||||
.fetch_all(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match todos {
|
||||
Ok(todos) => Ok(Json(todos)), // Return all todos as JSON
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching todos: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Get a single todo by id
|
||||
pub async fn get_todos_by_id(
|
||||
State(pool): State<PgPool>,
|
||||
Path(id): Path<i32>, // Use Path extractor here
|
||||
) -> impl IntoResponse {
|
||||
let todo = sqlx::query_as!(Todo, "SELECT * FROM todos WHERE id = $1", id)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match todo {
|
||||
Ok(Some(todo)) => Ok(Json(todo)), // Return the todo as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("Todo with id {} not found", id),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching todo: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
87
src/routes/get_users.rs
Normal file
87
src/routes/get_users.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use axum::extract::{State, Path};
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*; // Import the User struct
|
||||
|
||||
// Get all users
|
||||
pub async fn get_all_users(State(pool): State<PgPool>,) -> impl IntoResponse {
|
||||
let users = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users") // Your table name
|
||||
.fetch_all(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match users {
|
||||
Ok(users) => Ok(Json(users)), // Return all users as JSON
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching users: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by id
|
||||
pub async fn get_users_by_id(
|
||||
State(pool): State<PgPool>,
|
||||
Path(id): Path<i32>, // Use Path extractor here
|
||||
) -> impl IntoResponse {
|
||||
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE id = $1", id)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with id {} not found", id),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by username
|
||||
pub async fn get_user_by_username(
|
||||
State(pool): State<PgPool>,
|
||||
Path(username): Path<String>, // Use Path extractor here for username
|
||||
) -> impl IntoResponse {
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE username = $1", username)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with username {} not found", username),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Get a single user by email
|
||||
pub async fn get_user_by_email(
|
||||
State(pool): State<PgPool>,
|
||||
Path(email): Path<String>, // Use Path extractor here for email
|
||||
) -> impl IntoResponse {
|
||||
let user = sqlx::query_as!(User, "SELECT id, username, email, password_hash, totp_secret, role_id FROM users WHERE email = $1", email)
|
||||
.fetch_optional(&pool) // Borrow the connection pool
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(Some(user)) => Ok(Json(user)), // Return the user as JSON if found
|
||||
Ok(None) => Err((
|
||||
axum::http::StatusCode::NOT_FOUND,
|
||||
format!("User with email {} not found", email),
|
||||
)),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error fetching user: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
59
src/routes/mod.rs
Normal file
59
src/routes/mod.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
// Module declarations for different route handlers
|
||||
pub mod get_todos;
|
||||
pub mod get_users;
|
||||
pub mod post_todos;
|
||||
pub mod post_users;
|
||||
pub mod get_health;
|
||||
pub mod protected;
|
||||
|
||||
// Re-exporting modules to make their contents available at this level
|
||||
pub use get_todos::*;
|
||||
pub use get_users::*;
|
||||
pub use post_todos::*;
|
||||
pub use post_users::*;
|
||||
pub use get_health::*;
|
||||
pub use protected::*;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::middlewares::auth::{sign_in, authorize};
|
||||
|
||||
/// Function to create and configure all routes
|
||||
pub fn create_routes(database_connection: PgPool) -> Router {
|
||||
// Authentication routes
|
||||
let auth_routes = Router::new()
|
||||
.route("/signin", post(sign_in))
|
||||
.route("/protected", get(protected).route_layer(axum::middleware::from_fn(|req, next| {
|
||||
let allowed_roles = vec![1, 2];
|
||||
authorize(req, next, allowed_roles)
|
||||
})));
|
||||
|
||||
// User-related routes
|
||||
let user_routes = Router::new()
|
||||
.route("/all", get(get_all_users))
|
||||
.route("/{id}", get(get_users_by_id))
|
||||
.route("/", post(post_user));
|
||||
|
||||
// Todo-related routes
|
||||
let todo_routes = Router::new()
|
||||
.route("/all", get(get_all_todos))
|
||||
.route("/", post(post_todo).layer(axum::middleware::from_fn(|req, next| {
|
||||
let allowed_roles = vec![1, 2];
|
||||
authorize(req, next, allowed_roles)
|
||||
})))
|
||||
.route("/{id}", get(get_todos_by_id));
|
||||
|
||||
// Combine all routes and add middleware
|
||||
Router::new()
|
||||
.merge(auth_routes) // Add authentication routes
|
||||
.nest("/users", user_routes) // Add user routes under /users
|
||||
.nest("/todos", todo_routes) // Add todo routes under /todos
|
||||
.route("/health", get(get_health)) // Add health check route
|
||||
.layer(axum::Extension(database_connection.clone())) // Add database connection to all routes
|
||||
.with_state(database_connection) // Add database connection as state
|
||||
}
|
||||
53
src/routes/post_todos.rs
Normal file
53
src/routes/post_todos.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use axum::{extract::{State, Extension}, Json};
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::todo::*;
|
||||
use crate::models::user::*;
|
||||
use serde::Deserialize;
|
||||
use axum::http::StatusCode;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct TodoBody {
|
||||
pub task: String,
|
||||
pub description: Option<String>,
|
||||
pub user_id: i32,
|
||||
}
|
||||
|
||||
// Add a new todo
|
||||
pub async fn post_todo(
|
||||
State(pool): State<PgPool>,
|
||||
Extension(user): Extension<User>, // Extract current user from the request extensions
|
||||
Json(todo): Json<TodoBody>
|
||||
) -> impl IntoResponse {
|
||||
// Ensure the user_id from the request matches the current user's id
|
||||
if todo.user_id != user.id {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(json!({ "error": "User is not authorized to create a todo for another user" }))
|
||||
));
|
||||
}
|
||||
|
||||
// Insert the todo into the database
|
||||
let row = sqlx::query!(
|
||||
"INSERT INTO todos (task, description, user_id) VALUES ($1, $2, $3) RETURNING id, task, description, user_id",
|
||||
todo.task,
|
||||
todo.description,
|
||||
todo.user_id
|
||||
)
|
||||
.fetch_one(&pool)
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => Ok(Json(Todo {
|
||||
id: row.id,
|
||||
task: row.task,
|
||||
description: row.description,
|
||||
user_id: row.user_id,
|
||||
})),
|
||||
Err(err) => Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": format!("Error: {}", err) }))
|
||||
)),
|
||||
}
|
||||
}
|
||||
44
src/routes/post_users.rs
Normal file
44
src/routes/post_users.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use axum::extract::State;
|
||||
use axum::Json;
|
||||
use axum::response::IntoResponse;
|
||||
use sqlx::postgres::PgPool;
|
||||
use crate::models::user::*;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserBody {
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub totp_secret: String,
|
||||
pub role_id: i32,
|
||||
}
|
||||
|
||||
// Add a new user
|
||||
pub async fn post_user(State(pool): State<PgPool>, Json(user): Json<UserBody>, ) -> impl IntoResponse {
|
||||
let row = sqlx::query!(
|
||||
"INSERT INTO users (username, email, password_hash, totp_secret, role_id) VALUES ($1, $2, $3, $4, $5) RETURNING id, username, email, password_hash, totp_secret, role_id",
|
||||
user.username,
|
||||
user.email,
|
||||
user.password_hash,
|
||||
user.totp_secret,
|
||||
user.role_id
|
||||
)
|
||||
.fetch_one(&pool) // Use `&pool` to borrow the connection pool
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => Ok(Json(User {
|
||||
id: row.id,
|
||||
username: row.username,
|
||||
email: row.email,
|
||||
password_hash: row.password_hash,
|
||||
totp_secret: row.totp_secret,
|
||||
role_id: row.role_id,
|
||||
})),
|
||||
Err(err) => Err((
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Error: {}", err),
|
||||
)),
|
||||
}
|
||||
}
|
||||
18
src/routes/protected.rs
Normal file
18
src/routes/protected.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use axum::{Extension, Json, response::IntoResponse};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::models::user::User;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UserResponse {
|
||||
id: i32,
|
||||
username: String,
|
||||
email: String
|
||||
}
|
||||
|
||||
pub async fn protected(Extension(user): Extension<User>) -> impl IntoResponse {
|
||||
Json(UserResponse {
|
||||
id: user.id,
|
||||
username: user.username,
|
||||
email: user.email
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user