From 465f9c49e91d946bcc84fb75b6d6865d1e5d8035 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 3 Jan 2026 23:27:27 +0200 Subject: [PATCH] refactor: improve idiomaticity and async correctness --- src/config.rs | 14 +++++-- src/download.rs | 107 ++++++++++++++++++++++++++---------------------- src/handler.rs | 10 ++--- src/main.rs | 18 ++++---- src/utils.rs | 85 ++++++++++++++++++++------------------ 5 files changed, 125 insertions(+), 109 deletions(-) diff --git a/src/config.rs b/src/config.rs index defa888..598a96e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -40,9 +40,10 @@ impl Config { /// Load configuration from environment variables. #[must_use] pub fn from_env() -> Self { - let chat_id: Option = env::var("CHAT_ID") + let chat_id = env::var("CHAT_ID") .ok() - .and_then(|id| id.parse::().ok().map(ChatId)); + .and_then(|id| id.parse::().ok()) + .map(ChatId); Self { chat_id, youtube: YoutubeConfig::from_env(), @@ -64,9 +65,14 @@ impl Config { } } /// Get global config (initialized by `Config::init(self)`). +/// +/// # Panics +/// +/// Panics if config has not been initialized. +#[inline] #[must_use] -pub fn global_config() -> Config { - GLOBAL_CONFIG.get().cloned().unwrap_or_default() +pub fn global_config() -> &'static Config { + GLOBAL_CONFIG.get().expect("config not initialized") } impl YoutubeConfig { diff --git a/src/download.rs b/src/download.rs index b8777a2..02bda3d 100644 --- a/src/download.rs +++ b/src/download.rs @@ -10,7 +10,7 @@ use futures::{StreamExt, stream}; use std::{ cmp::min, ffi::OsStr, - fs::{self, metadata}, + fs, path::{Path, PathBuf}, process::Stdio, }; @@ -104,13 +104,14 @@ async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result) -> Result { +pub async fn download_instagram(url: String) -> Result { let config = global_config(); - let args = ["-t", "mp4", "--extractor-args", "instagram:"] - .iter() - .map(ToString::to_string) - .collect(); - run_yt_dlp(args, config.instagram.cookies_path.as_ref(), &url.into()).await + run_yt_dlp( + &["-t", "mp4", "--extractor-args", "instagram:"], + config.instagram.cookies_path.as_ref(), + &url, + ) + .await } /// Download a Tiktok URL with yt-dlp. @@ -119,13 +120,14 @@ pub async fn download_instagram(url: impl Into) -> Result) -> Result { +pub async fn download_tiktok(url: String) -> Result { let config = global_config(); - let args = ["-t", "mp4", "--extractor-args", "tiktok:"] - .iter() - .map(ToString::to_string) - .collect(); - run_yt_dlp(args, config.tiktok.cookies_path.as_ref(), &url.into()).await + run_yt_dlp( + &["-t", "mp4", "--extractor-args", "tiktok:"], + config.tiktok.cookies_path.as_ref(), + &url, + ) + .await } /// Download a Twitter URL with yt-dlp. @@ -134,13 +136,14 @@ pub async fn download_tiktok(url: impl Into) -> Result { /// /// - Propagates `run_command_in_tempdir` errors. #[cfg(feature = "twitter")] -pub async fn download_twitter(url: impl Into) -> Result { +pub async fn download_twitter(url: String) -> Result { let config = global_config(); - let args = ["-t", "mp4", "--extractor-args", "twitter:"] - .iter() - .map(ToString::to_string) - .collect(); - run_yt_dlp(args, config.twitter.cookies_path.as_ref(), &url.into()).await + run_yt_dlp( + &["-t", "mp4", "--extractor-args", "twitter:"], + config.twitter.cookies_path.as_ref(), + &url, + ) + .await } /// Download a URL with yt-dlp. @@ -149,19 +152,20 @@ pub async fn download_twitter(url: impl Into) -> Result /// /// - Propagates `run_command_in_tempdir` errors. #[cfg(feature = "youtube")] -pub async fn download_youtube(url: impl Into) -> Result { +pub async fn download_youtube(url: String) -> Result { let config = global_config(); - let args = [ - "--no-playlist", - "-t", - "mp4", - "--postprocessor-args", - &config.youtube.postprocessor_args, - ] - .iter() - .map(ToString::to_string) - .collect(); - run_yt_dlp(args, config.youtube.cookies_path.as_ref(), &url.into()).await + run_yt_dlp( + &[ + "--no-playlist", + "-t", + "mp4", + "--postprocessor-args", + &config.youtube.postprocessor_args, + ], + config.youtube.cookies_path.as_ref(), + &url, + ) + .await } /// Post-process a `DownloadResult`. @@ -183,9 +187,17 @@ pub async fn process_download_result( return Err(Error::NoMediaFound); } - // Detect kinds in parallel with limiter concurrency + // Detect kinds and validate files in parallel let concurrency = min(8, dr.files.len()); let results = stream::iter(dr.files.drain(..).map(|path| async move { + // Check file metadata asynchronously + let Ok(meta) = tokio::fs::metadata(&path).await else { + return None; + }; + if !meta.is_file() || meta.len() == 0 { + return None; + } + let kind = detect_media_kind_async(&path).await; match kind { MediaKind::Unknown => None, @@ -193,18 +205,10 @@ pub async fn process_download_result( } })) .buffer_unordered(concurrency) - .collect::>>() + .collect::>() .await; - let mut media_items = results - .into_iter() - .flatten() - .filter(|(path, _)| { - metadata(path) - .map(|m| m.is_file() && m.len() > 0) - .unwrap_or(false) - }) - .collect::>(); + let mut media_items = results.into_iter().flatten().collect::>(); if media_items.is_empty() { return Err(Error::NoMediaFound); @@ -254,18 +258,21 @@ fn is_potential_media_file(path: &Path) -> bool { } async fn run_yt_dlp( - mut args: Vec, + base_args: &[&str], cookies_path: Option<&PathBuf>, url: &str, ) -> Result { - if let Some(path) = cookies_path { - args.extend(["--cookies".to_string(), path.to_string_lossy().to_string()]); - } - args.push(url.to_string()); + let cookies_path_str; + let mut args = base_args.to_vec(); - debug!(args = ?args, "downloadting content"); - let args_ref = args.iter().map(String::as_ref).collect::>(); - run_command_in_tempdir("yt-dlp", &args_ref).await + if let Some(path) = cookies_path { + cookies_path_str = path.to_string_lossy(); + args.extend(["--cookies", &cookies_path_str]); + } + args.push(url); + + debug!(args = ?args, "downloading content"); + run_command_in_tempdir("yt-dlp", &args).await } #[cfg(test)] diff --git a/src/handler.rs b/src/handler.rs index 54a2abd..25920ad 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -7,7 +7,7 @@ use std::{pin::Pin, sync::Arc}; use teloxide::{Bot, types::ChatId}; use tracing::info; -type DownloadFn = fn(&str) -> Pin> + Send>>; +type DownloadFn = fn(String) -> Pin> + Send>>; #[derive(Debug, Clone)] pub struct Handler { @@ -52,7 +52,7 @@ impl Handler { /// Returns `Error` if download or media processing fails. pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> { info!(handler = %self.name(), url = %url, "handling url"); - let dr = (self.func)(url).await?; + let dr = (self.func)(url.to_owned()).await?; process_download_result(bot, chat_id, dr).await } } @@ -60,10 +60,8 @@ impl Handler { macro_rules! handler { ($feature:expr, $regex:expr, $download_fn:path) => { #[cfg(feature = $feature)] - Handler::new($feature, $regex, |url| { - Box::pin($download_fn(url.to_string())) - }) - .expect(concat!("failed to create ", $feature, " handler")) + Handler::new($feature, $regex, |url: String| Box::pin($download_fn(url))) + .expect(concat!("failed to create ", $feature, " handler")) }; } diff --git a/src/main.rs b/src/main.rs index 39ca7d3..6a2d174 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use dotenv::dotenv; +use std::sync::Arc; use teloxide::{prelude::*, respond, utils::command::BotCommands}; use tg_relay_rs::{ commands::{Command, answer}, @@ -17,27 +18,26 @@ async fn main() -> color_eyre::Result<()> { Comments::load_from_file("comments.txt") .await - .map_err(|e| { + .unwrap_or_else(|e| { warn!("failed to load comments.txt: {e}; using dummy comments"); - e + Comments::dummy() }) - .unwrap_or_else(|_| Comments::dummy()) .init()?; Config::from_env().init()?; let bot = Bot::from_env(); - let bot_name = bot.get_me().await?.username().to_owned(); + let bot_name: Arc = bot.get_me().await?.username().into(); - info!(name = bot_name, "bot starting"); + info!(name = %bot_name, "bot starting"); let handlers = create_handlers(); teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| { - let handlers = handlers.clone(); - let bot_name_cloned = bot_name.clone(); + let handlers = Arc::clone(&handlers); + let bot_name = Arc::clone(&bot_name); async move { - process_cmd(&bot, &msg, &bot_name_cloned).await; + process_cmd(&bot, &msg, &bot_name).await; process_message(&bot, &msg, &handlers).await; respond(()) } @@ -60,7 +60,7 @@ async fn process_message(bot: &Bot, msg: &Message, handlers: &[Handler]) { .send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE) .await; if let Some(chat_id) = global_config().chat_id { - let _ = bot.send_message(chat_id, format!("{err}")).await; + let _ = bot.send_message(chat_id, err.to_string()).await; } } return; diff --git a/src/utils.rs b/src/utils.rs index b49728d..a59d633 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -35,26 +35,43 @@ impl MediaKind { } } +/// Check if extension matches any in the given list (case-insensitive). +#[inline] +fn ext_matches(ext: &str, extensions: &[&str]) -> bool { + extensions.iter().any(|e| e.eq_ignore_ascii_case(ext)) +} + +/// Detect media kind from file extension. +fn detect_from_extension(path: &Path) -> Option { + let ext = path.extension().and_then(OsStr::to_str)?; + if ext_matches(ext, VIDEO_EXTSTENSIONS) { + return Some(MediaKind::Video); + } + if ext_matches(ext, IMAGE_EXTSTENSIONS) { + return Some(MediaKind::Image); + } + None +} + +/// Detect media kind from MIME type string. +fn detect_from_mime(mime_type: &str) -> MediaKind { + match mime_type.split('/').next() { + Some("video") => MediaKind::Video, + Some("image") => MediaKind::Image, + _ => MediaKind::Unknown, + } +} + /// Detect media kind first by extension, then by content/magic (sync). +#[must_use] pub fn detect_media_kind(path: &Path) -> MediaKind { - if let Some(ext) = path.extension().and_then(OsStr::to_str) { - let compare = |e: &&str| e.eq_ignore_ascii_case(ext); - if VIDEO_EXTSTENSIONS.iter().any(compare) { - return MediaKind::Video; - } - if IMAGE_EXTSTENSIONS.iter().any(compare) { - return MediaKind::Image; - } + if let Some(kind) = detect_from_extension(path) { + return kind; } // Fallback to MIME type detection if let Ok(Some(kind)) = infer::get_from_path(path) { - let mime_type = kind.mime_type(); - return match mime_type.split('/').next() { - Some("video") => MediaKind::Video, - Some("image") => MediaKind::Image, - _ => MediaKind::Unknown, - }; + return detect_from_mime(kind.mime_type()); } MediaKind::Unknown @@ -63,36 +80,24 @@ pub fn detect_media_kind(path: &Path) -> MediaKind { /// Async/non-blocking detection: check extension first, otherwise read a small /// sample asynchronously and run `infer::get` on the buffer. pub async fn detect_media_kind_async(path: &Path) -> MediaKind { - if let Some(ext) = path.extension().and_then(OsStr::to_str) { - let compare = |e: &&str| e.eq_ignore_ascii_case(ext); - if VIDEO_EXTSTENSIONS.iter().any(compare) { - return MediaKind::Video; - } - if IMAGE_EXTSTENSIONS.iter().any(compare) { - return MediaKind::Image; - } + if let Some(kind) = detect_from_extension(path) { + return kind; } // Read a small prefix (8 KiB) asynchronously and probe - match File::open(path).await { - Ok(mut file) => { - let mut buffer = vec![0u8; 8192]; - if let Ok(n) = file.read(&mut buffer).await - && n > 0 - { - buffer.truncate(n); - if let Some(k) = infer::get(&buffer) { - let mt = k.mime_type(); - if mt.starts_with("video/") { - return MediaKind::Video; - } - if mt.starts_with("image/") { - return MediaKind::Image; - } - } - } + let Ok(mut file) = File::open(path).await else { + warn!(path = ?path.display(), "Failed to open file for media detection"); + return MediaKind::Unknown; + }; + + let mut buffer = vec![0u8; 8192]; + if let Ok(n) = file.read(&mut buffer).await + && n > 0 + { + buffer.truncate(n); + if let Some(k) = infer::get(&buffer) { + return detect_from_mime(k.mime_type()); } - Err(e) => warn!(path = ?path.display(), "Failed to read file for media detection: {e}"), } MediaKind::Unknown