From dd3b2b618b409445136cd7ca489f58ae5e759e1d Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Mon, 22 Sep 2025 11:18:36 +0300 Subject: [PATCH] test: add tests --- Cargo.lock | 2 + Cargo.toml | 2 + src/comments.rs | 112 +++++++++++++++++++++++------ src/download.rs | 159 +++++++++++++++++++++++++++++++++++------- src/error.rs | 14 ++++ src/lib.rs | 1 + src/utils.rs | 93 +++++++++++++----------- src/validate/mod.rs | 23 ++++++ src/validate/utils.rs | 10 +++ 9 files changed, 331 insertions(+), 85 deletions(-) create mode 100644 src/validate/mod.rs create mode 100644 src/validate/utils.rs diff --git a/Cargo.lock b/Cargo.lock index d59ea43..c896e75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1746,6 +1746,7 @@ dependencies = [ "rand", "regex", "serde", + "shlex", "teloxide", "tempfile", "thiserror 2.0.16", @@ -1755,6 +1756,7 @@ dependencies = [ "tracing-bunyan-formatter", "tracing-log 0.2.0", "tracing-subscriber", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 2b2b990..9f8ce0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ infer = "0.19" rand = "0.9" regex = "1.11" serde = { version = "1.0", features = ["derive"] } +shlex = "1.3.0" teloxide = { version = "0.17", features = ["macros"] } tempfile = "3" thiserror = "2.0" @@ -28,6 +29,7 @@ tracing-appender = "0.2" tracing-bunyan-formatter = { version = "0.3", default-features = false } tracing-log = "0.2.0" tracing-subscriber = { version = "0.3", features = ["registry", "env-filter"] } +url = "2.5" [lints.clippy] pedantic = "warn" diff --git a/src/comments.rs b/src/comments.rs index 5774aea..77aaef2 100644 --- a/src/comments.rs +++ b/src/comments.rs @@ -1,11 +1,20 @@ use crate::error::{Error, Result}; use rand::{rng, seq::IndexedRandom}; use std::{ + fmt::Display, path::Path, sync::{Arc, OnceLock}, }; use tokio::fs::read_to_string; -static DISCLAIMER: &str = "(Roleplay — fictional messages for entertainment.)"; + +const DISCLAIMER: &str = "(Roleplay — fictional messages for entertainment.)"; +pub const TELEGRAM_CAPTION_LIMIT: usize = 4096; +const FALLBACK_COMMENTS: &[&str] = &[ + "Oh come on, that's brilliant — and slightly chaotic, like always.", + "That is a proper bit of craftsmanship — then someone presses the red button.", + "Nice shot — looks good on the trailer, not so good on the gearbox.", + "Here you go. Judge for yourself.", +]; #[derive(Debug)] pub struct Comments { @@ -17,12 +26,10 @@ impl Comments { /// Create a small dummy/default Comments instance (useful for tests or fallback). #[must_use] pub fn dummy() -> Self { - let lines = vec![ - "Oh come on, that's brilliant — and slightly chaotic, like always.".into(), - "That is a proper bit of craftsmanship — then someone presses the red button.".into(), - "Nice shot — looks good on the trailer, not so good on the gearbox.".into(), - "Here you go. Judge for yourself.".into(), - ]; + let lines = FALLBACK_COMMENTS + .iter() + .map(ToString::to_string) + .collect::>(); Self { disclaimer: DISCLAIMER.into(), lines: lines.into(), @@ -33,14 +40,12 @@ impl Comments { /// /// # Errors /// - /// - Returns `Error::Io` if reading the file fails (propagated from - /// `tokio::fs::read_to_string`). - /// - Returns `Error::Other` if the file contains no usable lines after - /// filtering (empty or all-comment file). + /// - Returns `Error::Io` if reading the file fails. + /// - Returns `Error::Other` if the file contains no usable lines. pub async fn load_from_file>(path: P) -> Result { - let s = read_to_string(path).await?; + let content = read_to_string(path).await?; - let lines = s + let lines = content .lines() .map(str::trim) .filter(|l| !l.is_empty() && !l.starts_with('#')) @@ -57,20 +62,35 @@ impl Comments { }) } - /// Pick a random comment as &str (no allocation). Falls back to a small static - /// string if the list is unexpectedly empty. + /// Pick a random comment. Falls back to a default if the list is empty. #[must_use] pub fn pick(&self) -> &str { let mut rng = rng(); self.lines .choose(&mut rng) - .map_or("Here you go.", String::as_str) + .map_or(FALLBACK_COMMENTS[0], AsRef::as_ref) } + /// Build a caption by picking a random comment and truncating if necessary. #[must_use] - #[inline] pub fn build_caption(&self) -> String { - self.pick().to_string() + let mut caption = self.pick().to_string(); + + // Trancate if too long for Telegram + if caption.chars().count() > TELEGRAM_CAPTION_LIMIT { + let truncated = caption + .chars() + .take(TELEGRAM_CAPTION_LIMIT.saturating_sub(3)) + .collect::(); + caption = format!("{truncated}..."); + } + caption + } + + /// Get a reference to the underlying lines for debugging or testing. + #[cfg(test)] + pub fn lines(&self) -> &[String] { + &self.lines } } @@ -80,8 +100,7 @@ static GLOBAL_COMMENTS: OnceLock = OnceLock::new(); /// /// # Errors /// -/// - Returns `Error::Other` when the global is already initialized (the -/// underlying `OnceLock::set` fails). +/// - Returns `Error::Other` when the global is already initialized. pub fn init_global_comments(comments: Comments) -> Result<()> { GLOBAL_COMMENTS .set(comments) @@ -89,6 +108,59 @@ pub fn init_global_comments(comments: Comments) -> Result<()> { } /// Get global comments (if initialized). Returns Option<&'static Comments>. +#[inline] +#[must_use] pub fn global_comments() -> Option<&'static Comments> { GLOBAL_COMMENTS.get() } + +impl Display for Comments { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.build_caption()) + } +} + +impl From for String { + fn from(value: Comments) -> Self { + value.to_string() + } +} + +impl From<&Comments> for String { + fn from(value: &Comments) -> Self { + value.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn dummy_comments() { + let comments = Comments::dummy(); + assert_eq!(comments.lines.len(), FALLBACK_COMMENTS.len()); + assert!(!comments.lines.is_empty()); + } + + #[test] + fn build_caption_truncation() { + let long_comment = "A".repeat(TELEGRAM_CAPTION_LIMIT + 10); + let comments = Comments { + disclaimer: DISCLAIMER.into(), + lines: Arc::new(vec![long_comment]), + }; + + let caption = comments.build_caption(); + assert_eq!(caption.chars().count(), TELEGRAM_CAPTION_LIMIT); + assert!(caption.ends_with("...")) + } + + #[test] + fn pick_fallbakc() { + let empty_comment = Comments { + disclaimer: DISCLAIMER.into(), + lines: Arc::new(Vec::new()), + }; + assert_eq!(empty_comment.pick(), FALLBACK_COMMENTS[0]); + } +} diff --git a/src/download.rs b/src/download.rs index d9db32d..d22dbd0 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,12 +1,25 @@ use crate::{ error::{Error, Result}, - utils::{MediaKind, detect_media_kind_async, send_media_from_path}, + utils::{ + IMAGE_EXTSTENSIONS, MediaKind, VIDEO_EXTSTENSIONS, detect_media_kind_async, + send_media_from_path, + }, }; use futures::{StreamExt, stream}; -use std::{path::PathBuf, process::Stdio}; +use std::{ + cmp::min, + env, + ffi::OsStr, + fs::{self, metadata}, + path::{Path, PathBuf}, + process::Stdio, +}; use teloxide::{Bot, types::ChatId}; use tempfile::{TempDir, tempdir}; use tokio::{fs::read_dir, process::Command}; +use tracing::{info, warn}; + +const FORBIDDEN_EXTENSIONS: &[&str] = &["json", "txt", "log"]; /// `TempDir` guard + downloaded files. Keep this value alive until you're /// done sending files so the temporary directory is not deleted. @@ -19,6 +32,8 @@ pub struct DownloadResult { /// Run a command in a freshly created temporary directory and collect /// regular files produced there. /// +/// # Arguments +/// /// `cmd` is the command name (e.g. "yt-dlp" or "instaloader"). /// `args` are the command arguments (owned Strings so callers can build dynamic args). /// @@ -29,7 +44,7 @@ pub struct DownloadResult { /// - `Error::NoMediaFound` if no files were produced. #[allow(clippy::similar_names)] async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result { - let tmp = tempdir().map_err(Error::from)?; + let tmp = tempdir()?; let cwd = tmp.path().to_path_buf(); let output = Command::new(cmd) @@ -39,27 +54,50 @@ async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result Error::instaloader_failed(stderr), + "yt-dlp" => Error::ytdlp_failed(stderr), + _ => Error::Other(format!("{cmd} failed: {stderr}")), + }; + return Err(err); } - // collect files produced in tempdir (async) + // Collect files produced in tempdir (async) let mut rd = read_dir(&cwd).await?; let mut files = Vec::new(); while let Some(entry) = rd.next_entry().await? { - if entry.file_type().await?.is_file() { - files.push(entry.path()); + let path = entry.path(); + // Filter out non-media files (logs, metadata, etc.) + if is_potential_media_file(&path) { + files.push(path); } } + info!(files = files.len(), "Collected files from tempdir"); + if files.is_empty() { + let dir_contents = fs::read_dir(&cwd) + .map(|rd| { + rd.filter_map(std::result::Result::ok) + .map(|e| e.path()) + .collect::>() + }) + .unwrap_or_default(); + warn!(dir_contents = ?dir_contents, "No media files found in tempdir"); return Err(Error::NoMediaFound); } + files.sort(); + Ok(DownloadResult { tempdir: tmp, files, @@ -90,23 +128,32 @@ pub async fn download_instaloader(shortcode: &str) -> Result { /// /// - Propagates `run_command_in_tempdir` errors. pub async fn download_ytdlp(url: &str, cookies: Option<&str>) -> Result { + let default_format = "bestvideo[ext=mp4][vcodec^=avc1]+bestaudio/best"; + let format_selector = env::var("YTDLP_FORMAT").unwrap_or_else(|_| default_format.into()); + let mut args = vec![ "--no-playlist", "--merge-output-format", "mp4", "-f", - "bestvideo[ext=mp4][vcodec^=avc1]+bestaudio/best", + &format_selector, "--restrict-filenames", "-o", "%(id)s.%(ext)s", + "--no-warnings", + "--quiet", ]; - if let Some(c) = cookies { - args.push("--cookies"); - args.push(c); + if let Some(cookie_path) = cookies { + if Path::new(cookie_path).exists() { + args.extend(["--cookies", cookie_path]); + } else { + warn!("Cookies file not found: {cookie_path}"); + } } - args.push(url); + let quoted_url = shlex::try_quote(url)?; + args.push("ed_url); run_command_in_tempdir("yt-dlp", &args).await } @@ -119,10 +166,20 @@ pub async fn download_ytdlp(url: &str, cookies: Option<&str>) -> Result Result<()> { - // detect kinds in parallel - let concurrency = 8; - let results = stream::iter(dr.files.into_iter().map(|path| async move { +pub async fn process_download_result( + bot: &Bot, + chat_id: ChatId, + mut dr: DownloadResult, +) -> Result<()> { + info!(files = dr.files.len(), "Processing download result"); + + if dr.files.is_empty() { + return Err(Error::NoMediaFound); + } + + // Detect kinds in parallel with limiter concurrency + let concurrency = min(8, dr.files.len()); + let results = stream::iter(dr.files.drain(..).map(|path| async move { let kind = detect_media_kind_async(&path).await; match kind { MediaKind::Unknown => None, @@ -133,26 +190,76 @@ pub async fn process_download_result(bot: &Bot, chat_id: ChatId, dr: DownloadRes .collect::>>() .await; - let mut media = results + let mut media_items = results .into_iter() .flatten() - .collect::>(); + .filter(|(path, _)| { + metadata(path) + .map(|m| m.is_file() && m.len() > 0) + .unwrap_or(false) + }) + .collect::>(); - if media.is_empty() { + if media_items.is_empty() { return Err(Error::NoMediaFound); } // deterministic ordering - media.sort_by_key(|(p, _)| p.clone()); + media_items.sort_by(|(p1, _), (p2, _)| p1.cmp(p2)); + + info!(media_items = media_items.len(), "Sending media to chat"); // prefer video over image - if let Some((path, MediaKind::Video)) = media.iter().find(|(_, k)| *k == MediaKind::Video) { - return send_media_from_path(bot, chat_id, path.clone(), Some(MediaKind::Video)).await; + if let Some((path, MediaKind::Video)) = media_items.iter().find(|(_, k)| *k == MediaKind::Video) + { + return send_media_from_path(bot, chat_id, path.clone(), MediaKind::Video).await; } - if let Some((path, MediaKind::Image)) = media.iter().find(|(_, k)| *k == MediaKind::Image) { - return send_media_from_path(bot, chat_id, path.clone(), Some(MediaKind::Image)).await; + if let Some((path, MediaKind::Image)) = media_items.iter().find(|(_, k)| *k == MediaKind::Image) + { + return send_media_from_path(bot, chat_id, path.clone(), MediaKind::Image).await; } Err(Error::NoMediaFound) } + +/// Filter function to determine if a file is potentially media based on name/extension. +fn is_potential_media_file(path: &Path) -> bool { + if let Some(filename) = path.file_name().and_then(OsStr::to_str) { + // Skip common non-media files + if filename.starts_with('.') || filename.to_lowercase().contains("metadata") { + return false; + } + } + + let ext = match path.extension().and_then(OsStr::to_str) { + Some(e) => e.to_lowercase(), + None => return false, + }; + + if FORBIDDEN_EXTENSIONS + .iter() + .any(|forbidden| forbidden.eq_ignore_ascii_case(&ext)) + { + return false; + } + + VIDEO_EXTSTENSIONS + .iter() + .chain(IMAGE_EXTSTENSIONS.iter()) + .any(|allowed| allowed.eq_ignore_ascii_case(&ext)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn is_potential_media_file_() { + assert!(is_potential_media_file(Path::new("video.mp4"))); + assert!(is_potential_media_file(Path::new("image.jpg"))); + assert!(!is_potential_media_file(Path::new(".DS_Store"))); + assert!(!is_potential_media_file(Path::new("metadata.json"))); + assert!(!is_potential_media_file(Path::new("download.log"))); + } +} diff --git a/src/error.rs b/src/error.rs index 36fcafe..d71979b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,12 +17,21 @@ pub enum Error { #[error("unknown media kind")] UnknownMediaKind, + #[error("validation failed: {0}")] + ValidationFailed(String), + #[error("teloxide error: {0}")] Teloxide(#[from] teloxide::RequestError), #[error("join error: {0}")] Join(#[from] tokio::task::JoinError), + #[error("rate limit exceeded")] + RateLimit, + + #[error("")] + QuoteError(#[from] shlex::QuoteError), + #[error("other: {0}")] Other(String), } @@ -42,6 +51,11 @@ impl Error { pub fn ytdlp_failed(text: impl Into) -> Self { Self::YTDLPFailed(text.into()) } + + #[inline] + pub fn validation_falied(text: impl Into) -> Self { + Self::ValidationFailed(text.into()) + } } pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index e7f3b95..c4d3cdf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,3 +4,4 @@ pub mod error; pub mod handlers; pub mod telemetry; pub mod utils; +pub mod validate; diff --git a/src/utils.rs b/src/utils.rs index 6f86a0d..b0915c0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ use crate::{ - comments::global_comments, + comments::{Comments, global_comments}, error::{Error, Result}, }; use std::{ @@ -13,10 +13,10 @@ use teloxide::{ types::{ChatId, InputFile}, }; use tokio::{fs::File, io::AsyncReadExt}; +use tracing::warn; -const TELEGRAM_CAPTION_LIMIT: usize = 1024; -static VIDEO_EXTS: &[&str] = &["mp4", "webm", "mov", "mkv", "avi"]; -static IMAGE_EXTS: &[&str] = &["jpg", "jpeg", "png", "webp"]; +pub const VIDEO_EXTSTENSIONS: &[&str] = &["mp4", "webm", "mov", "mkv", "avi", "m4v", "3gp"]; +pub const IMAGE_EXTSTENSIONS: &[&str] = &["jpg", "jpeg", "png", "webp", "gif", "bmp"]; /// Simple media kind enum shared by handlers. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,26 +27,25 @@ pub enum MediaKind { } /// Detect media kind first by extension, then by content/magic (sync). -/// NOTE: `infer::get_from_path` is blocking — use `detect_media_kind_async` in -/// async contexts to avoid blocking the Tokio runtime. pub fn detect_media_kind(path: &Path) -> MediaKind { if let Some(ext) = path.extension().and_then(OsStr::to_str) { - if VIDEO_EXTS.iter().any(|e| e.eq_ignore_ascii_case(ext)) { + let compare = |e: &&str| e.eq_ignore_ascii_case(ext); + if VIDEO_EXTSTENSIONS.iter().any(compare) { return MediaKind::Video; } - if IMAGE_EXTS.iter().any(|e| e.eq_ignore_ascii_case(ext)) { + if IMAGE_EXTSTENSIONS.iter().any(compare) { return MediaKind::Image; } } + // Fallback to MIME type detection if let Ok(Some(kind)) = infer::get_from_path(path) { - let mt = kind.mime_type(); - if mt.starts_with("video/") { - return MediaKind::Video; - } - if mt.starts_with("image/") { - return MediaKind::Image; - } + let mime_type = kind.mime_type(); + return match mime_type.split('/').next() { + Some("video") => MediaKind::Video, + Some("image") => MediaKind::Image, + _ => MediaKind::Unknown, + }; } MediaKind::Unknown @@ -56,21 +55,24 @@ pub fn detect_media_kind(path: &Path) -> MediaKind { /// sample asynchronously and run `infer::get` on the buffer. pub async fn detect_media_kind_async(path: &Path) -> MediaKind { if let Some(ext) = path.extension().and_then(OsStr::to_str) { - if VIDEO_EXTS.iter().any(|e| e.eq_ignore_ascii_case(ext)) { + let compare = |e: &&str| e.eq_ignore_ascii_case(ext); + if VIDEO_EXTSTENSIONS.iter().any(compare) { return MediaKind::Video; } - if IMAGE_EXTS.iter().any(|e| e.eq_ignore_ascii_case(ext)) { + if IMAGE_EXTSTENSIONS.iter().any(compare) { return MediaKind::Image; } } // Read a small prefix (8 KiB) asynchronously and probe - if let Ok(mut f) = File::open(path).await { - let mut buf = vec![0u8; 8192]; - match f.read(&mut buf).await { - Ok(n) if n > 0 => { - buf.truncate(n); - if let Some(k) = infer::get(&buf) { + match File::open(path).await { + Ok(mut file) => { + let mut buffer = vec![0u8; 8192]; + if let Ok(n) = file.read(&mut buffer).await + && n > 0 + { + buffer.truncate(n); + if let Some(k) = infer::get(&buffer) { let mt = k.mime_type(); if mt.starts_with("video/") { return MediaKind::Video; @@ -80,8 +82,8 @@ pub async fn detect_media_kind_async(path: &Path) -> MediaKind { } } } - _ => {} } + Err(e) => warn!(path = ?path.display(), "Failed to read file for media detection: {e}"), } MediaKind::Unknown @@ -96,18 +98,11 @@ pub async fn send_media_from_path( bot: &Bot, chat_id: ChatId, path: PathBuf, - kind: Option, + kind: MediaKind, ) -> Result<()> { - let kind = kind.unwrap_or_else(|| detect_media_kind(&path)); - - let caption_opt = global_comments().map(|c| { - let mut caption = c.build_caption(); - if caption.chars().count() > TELEGRAM_CAPTION_LIMIT { - caption = caption.chars().take(TELEGRAM_CAPTION_LIMIT - 1).collect(); - caption.push_str("..."); - } - caption - }); + let caption_opt = global_comments() + .map(Comments::build_caption) + .filter(|caption| !caption.is_empty()); match kind { MediaKind::Video => { @@ -116,7 +111,7 @@ pub async fn send_media_from_path( if let Some(c) = caption_opt { req = req.caption(c); } - req.await.map_err(Error::from)?; + req.await?; } MediaKind::Image => { let photo = InputFile::file(path); @@ -124,14 +119,34 @@ pub async fn send_media_from_path( if let Some(c) = caption_opt { req = req.caption(c); } - req.await.map_err(Error::from)?; + req.await?; } MediaKind::Unknown => { bot.send_message(chat_id, "No supported media found") - .await - .map_err(Error::from)?; + .await?; return Err(Error::UnknownMediaKind); } } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detect_media_kind_by_extension() { + assert_eq!(detect_media_kind(Path::new("video.mp4")), MediaKind::Video); + assert_eq!(detect_media_kind(Path::new("image.jpg")), MediaKind::Image); + assert_eq!( + detect_media_kind(Path::new("unknown.txt")), + MediaKind::Unknown + ); + } + + #[test] + fn media_kind_case_insensitive() { + assert_eq!(detect_media_kind(Path::new("VIDEO.MP4")), MediaKind::Video); + assert_eq!(detect_media_kind(Path::new("IMAGE.JPG")), MediaKind::Image); + } +} diff --git a/src/validate/mod.rs b/src/validate/mod.rs new file mode 100644 index 0000000..8e58995 --- /dev/null +++ b/src/validate/mod.rs @@ -0,0 +1,23 @@ +pub mod utils; + +use crate::error::Result; +use regex::Regex; +use std::sync::OnceLock; + +/// Trait for validating platform-specific identifiers (e.g., shortcodes, URLs) +/// extracted from user input. +/// +/// Implementors should: +/// - Check format (e.g., length, characters). +/// - Canonicalize if needed (e.g., trim query params from a URL). +/// - Return `Ok(canonical_id)` on success or `Err(Error::Other(...))` on failure. +pub trait Validate { + /// Validate the input and return a canonicalized String (e.g., cleaned shortcode or URL). + fn validate(&self, input: &str) -> Result; +} + +/// Helper function to create a lazy static Regex (reused across impls). +pub fn lazy_regex(pattern: &str) -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| Regex::new(pattern).expect("failed to compile validation regex")) +} diff --git a/src/validate/utils.rs b/src/validate/utils.rs new file mode 100644 index 0000000..0eda9d8 --- /dev/null +++ b/src/validate/utils.rs @@ -0,0 +1,10 @@ +use crate::error::{Error, Result}; + +/// Trims whitespace and rejects empty strings. +pub fn validate_non_empty(input: &str) -> Result<&str> { + let trimmed = input.trim(); + if trimmed.is_empty() { + return Err(Error::validation_falied("input cannot be empty")); + } + Ok(trimmed) +}