refactor: improve idiomaticity and async correctness

This commit is contained in:
Kristofers Solo 2026-01-03 23:27:27 +02:00
parent da38ec8d69
commit 465f9c49e9
Signed by: kristoferssolo
GPG Key ID: 8687F2D3EEE6F0ED
5 changed files with 125 additions and 109 deletions

View File

@ -40,9 +40,10 @@ impl Config {
/// Load configuration from environment variables. /// Load configuration from environment variables.
#[must_use] #[must_use]
pub fn from_env() -> Self { pub fn from_env() -> Self {
let chat_id: Option<ChatId> = env::var("CHAT_ID") let chat_id = env::var("CHAT_ID")
.ok() .ok()
.and_then(|id| id.parse::<i64>().ok().map(ChatId)); .and_then(|id| id.parse::<i64>().ok())
.map(ChatId);
Self { Self {
chat_id, chat_id,
youtube: YoutubeConfig::from_env(), youtube: YoutubeConfig::from_env(),
@ -64,9 +65,14 @@ impl Config {
} }
} }
/// Get global config (initialized by `Config::init(self)`). /// Get global config (initialized by `Config::init(self)`).
///
/// # Panics
///
/// Panics if config has not been initialized.
#[inline]
#[must_use] #[must_use]
pub fn global_config() -> Config { pub fn global_config() -> &'static Config {
GLOBAL_CONFIG.get().cloned().unwrap_or_default() GLOBAL_CONFIG.get().expect("config not initialized")
} }
impl YoutubeConfig { impl YoutubeConfig {

View File

@ -10,7 +10,7 @@ use futures::{StreamExt, stream};
use std::{ use std::{
cmp::min, cmp::min,
ffi::OsStr, ffi::OsStr,
fs::{self, metadata}, fs,
path::{Path, PathBuf}, path::{Path, PathBuf},
process::Stdio, process::Stdio,
}; };
@ -104,13 +104,14 @@ async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result<DownloadResu
/// ///
/// - Propagates `run_command_in_tempdir` errors. /// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "instagram")] #[cfg(feature = "instagram")]
pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult> { pub async fn download_instagram(url: String) -> Result<DownloadResult> {
let config = global_config(); let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "instagram:"] run_yt_dlp(
.iter() &["-t", "mp4", "--extractor-args", "instagram:"],
.map(ToString::to_string) config.instagram.cookies_path.as_ref(),
.collect(); &url,
run_yt_dlp(args, config.instagram.cookies_path.as_ref(), &url.into()).await )
.await
} }
/// Download a Tiktok URL with yt-dlp. /// Download a Tiktok URL with yt-dlp.
@ -119,13 +120,14 @@ pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult
/// ///
/// - Propagates `run_command_in_tempdir` errors. /// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "tiktok")] #[cfg(feature = "tiktok")]
pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> { pub async fn download_tiktok(url: String) -> Result<DownloadResult> {
let config = global_config(); let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "tiktok:"] run_yt_dlp(
.iter() &["-t", "mp4", "--extractor-args", "tiktok:"],
.map(ToString::to_string) config.tiktok.cookies_path.as_ref(),
.collect(); &url,
run_yt_dlp(args, config.tiktok.cookies_path.as_ref(), &url.into()).await )
.await
} }
/// Download a Twitter URL with yt-dlp. /// Download a Twitter URL with yt-dlp.
@ -134,13 +136,14 @@ pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> {
/// ///
/// - Propagates `run_command_in_tempdir` errors. /// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "twitter")] #[cfg(feature = "twitter")]
pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult> { pub async fn download_twitter(url: String) -> Result<DownloadResult> {
let config = global_config(); let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "twitter:"] run_yt_dlp(
.iter() &["-t", "mp4", "--extractor-args", "twitter:"],
.map(ToString::to_string) config.twitter.cookies_path.as_ref(),
.collect(); &url,
run_yt_dlp(args, config.twitter.cookies_path.as_ref(), &url.into()).await )
.await
} }
/// Download a URL with yt-dlp. /// Download a URL with yt-dlp.
@ -149,19 +152,20 @@ pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult>
/// ///
/// - Propagates `run_command_in_tempdir` errors. /// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "youtube")] #[cfg(feature = "youtube")]
pub async fn download_youtube(url: impl Into<String>) -> Result<DownloadResult> { pub async fn download_youtube(url: String) -> Result<DownloadResult> {
let config = global_config(); let config = global_config();
let args = [ run_yt_dlp(
"--no-playlist", &[
"-t", "--no-playlist",
"mp4", "-t",
"--postprocessor-args", "mp4",
&config.youtube.postprocessor_args, "--postprocessor-args",
] &config.youtube.postprocessor_args,
.iter() ],
.map(ToString::to_string) config.youtube.cookies_path.as_ref(),
.collect(); &url,
run_yt_dlp(args, config.youtube.cookies_path.as_ref(), &url.into()).await )
.await
} }
/// Post-process a `DownloadResult`. /// Post-process a `DownloadResult`.
@ -183,9 +187,17 @@ pub async fn process_download_result(
return Err(Error::NoMediaFound); return Err(Error::NoMediaFound);
} }
// Detect kinds in parallel with limiter concurrency // Detect kinds and validate files in parallel
let concurrency = min(8, dr.files.len()); let concurrency = min(8, dr.files.len());
let results = stream::iter(dr.files.drain(..).map(|path| async move { let results = stream::iter(dr.files.drain(..).map(|path| async move {
// Check file metadata asynchronously
let Ok(meta) = tokio::fs::metadata(&path).await else {
return None;
};
if !meta.is_file() || meta.len() == 0 {
return None;
}
let kind = detect_media_kind_async(&path).await; let kind = detect_media_kind_async(&path).await;
match kind { match kind {
MediaKind::Unknown => None, MediaKind::Unknown => None,
@ -193,18 +205,10 @@ pub async fn process_download_result(
} }
})) }))
.buffer_unordered(concurrency) .buffer_unordered(concurrency)
.collect::<Vec<Option<(PathBuf, MediaKind)>>>() .collect::<Vec<_>>()
.await; .await;
let mut media_items = results let mut media_items = results.into_iter().flatten().collect::<Vec<_>>();
.into_iter()
.flatten()
.filter(|(path, _)| {
metadata(path)
.map(|m| m.is_file() && m.len() > 0)
.unwrap_or(false)
})
.collect::<Vec<_>>();
if media_items.is_empty() { if media_items.is_empty() {
return Err(Error::NoMediaFound); return Err(Error::NoMediaFound);
@ -254,18 +258,21 @@ fn is_potential_media_file(path: &Path) -> bool {
} }
async fn run_yt_dlp( async fn run_yt_dlp(
mut args: Vec<String>, base_args: &[&str],
cookies_path: Option<&PathBuf>, cookies_path: Option<&PathBuf>,
url: &str, url: &str,
) -> Result<DownloadResult> { ) -> Result<DownloadResult> {
if let Some(path) = cookies_path { let cookies_path_str;
args.extend(["--cookies".to_string(), path.to_string_lossy().to_string()]); let mut args = base_args.to_vec();
}
args.push(url.to_string());
debug!(args = ?args, "downloadting content"); if let Some(path) = cookies_path {
let args_ref = args.iter().map(String::as_ref).collect::<Vec<_>>(); cookies_path_str = path.to_string_lossy();
run_command_in_tempdir("yt-dlp", &args_ref).await args.extend(["--cookies", &cookies_path_str]);
}
args.push(url);
debug!(args = ?args, "downloading content");
run_command_in_tempdir("yt-dlp", &args).await
} }
#[cfg(test)] #[cfg(test)]

View File

@ -7,7 +7,7 @@ use std::{pin::Pin, sync::Arc};
use teloxide::{Bot, types::ChatId}; use teloxide::{Bot, types::ChatId};
use tracing::info; use tracing::info;
type DownloadFn = fn(&str) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>; type DownloadFn = fn(String) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Handler { pub struct Handler {
@ -52,7 +52,7 @@ impl Handler {
/// Returns `Error` if download or media processing fails. /// Returns `Error` if download or media processing fails.
pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> { pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> {
info!(handler = %self.name(), url = %url, "handling url"); info!(handler = %self.name(), url = %url, "handling url");
let dr = (self.func)(url).await?; let dr = (self.func)(url.to_owned()).await?;
process_download_result(bot, chat_id, dr).await process_download_result(bot, chat_id, dr).await
} }
} }
@ -60,10 +60,8 @@ impl Handler {
macro_rules! handler { macro_rules! handler {
($feature:expr, $regex:expr, $download_fn:path) => { ($feature:expr, $regex:expr, $download_fn:path) => {
#[cfg(feature = $feature)] #[cfg(feature = $feature)]
Handler::new($feature, $regex, |url| { Handler::new($feature, $regex, |url: String| Box::pin($download_fn(url)))
Box::pin($download_fn(url.to_string())) .expect(concat!("failed to create ", $feature, " handler"))
})
.expect(concat!("failed to create ", $feature, " handler"))
}; };
} }

View File

@ -1,4 +1,5 @@
use dotenv::dotenv; use dotenv::dotenv;
use std::sync::Arc;
use teloxide::{prelude::*, respond, utils::command::BotCommands}; use teloxide::{prelude::*, respond, utils::command::BotCommands};
use tg_relay_rs::{ use tg_relay_rs::{
commands::{Command, answer}, commands::{Command, answer},
@ -17,27 +18,26 @@ async fn main() -> color_eyre::Result<()> {
Comments::load_from_file("comments.txt") Comments::load_from_file("comments.txt")
.await .await
.map_err(|e| { .unwrap_or_else(|e| {
warn!("failed to load comments.txt: {e}; using dummy comments"); warn!("failed to load comments.txt: {e}; using dummy comments");
e Comments::dummy()
}) })
.unwrap_or_else(|_| Comments::dummy())
.init()?; .init()?;
Config::from_env().init()?; Config::from_env().init()?;
let bot = Bot::from_env(); let bot = Bot::from_env();
let bot_name = bot.get_me().await?.username().to_owned(); let bot_name: Arc<str> = bot.get_me().await?.username().into();
info!(name = bot_name, "bot starting"); info!(name = %bot_name, "bot starting");
let handlers = create_handlers(); let handlers = create_handlers();
teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| { teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| {
let handlers = handlers.clone(); let handlers = Arc::clone(&handlers);
let bot_name_cloned = bot_name.clone(); let bot_name = Arc::clone(&bot_name);
async move { async move {
process_cmd(&bot, &msg, &bot_name_cloned).await; process_cmd(&bot, &msg, &bot_name).await;
process_message(&bot, &msg, &handlers).await; process_message(&bot, &msg, &handlers).await;
respond(()) respond(())
} }
@ -60,7 +60,7 @@ async fn process_message(bot: &Bot, msg: &Message, handlers: &[Handler]) {
.send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE) .send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE)
.await; .await;
if let Some(chat_id) = global_config().chat_id { if let Some(chat_id) = global_config().chat_id {
let _ = bot.send_message(chat_id, format!("{err}")).await; let _ = bot.send_message(chat_id, err.to_string()).await;
} }
} }
return; return;

View File

@ -35,26 +35,43 @@ impl MediaKind {
} }
} }
/// Check if extension matches any in the given list (case-insensitive).
#[inline]
fn ext_matches(ext: &str, extensions: &[&str]) -> bool {
extensions.iter().any(|e| e.eq_ignore_ascii_case(ext))
}
/// Detect media kind from file extension.
fn detect_from_extension(path: &Path) -> Option<MediaKind> {
let ext = path.extension().and_then(OsStr::to_str)?;
if ext_matches(ext, VIDEO_EXTSTENSIONS) {
return Some(MediaKind::Video);
}
if ext_matches(ext, IMAGE_EXTSTENSIONS) {
return Some(MediaKind::Image);
}
None
}
/// Detect media kind from MIME type string.
fn detect_from_mime(mime_type: &str) -> MediaKind {
match mime_type.split('/').next() {
Some("video") => MediaKind::Video,
Some("image") => MediaKind::Image,
_ => MediaKind::Unknown,
}
}
/// Detect media kind first by extension, then by content/magic (sync). /// Detect media kind first by extension, then by content/magic (sync).
#[must_use]
pub fn detect_media_kind(path: &Path) -> MediaKind { pub fn detect_media_kind(path: &Path) -> MediaKind {
if let Some(ext) = path.extension().and_then(OsStr::to_str) { if let Some(kind) = detect_from_extension(path) {
let compare = |e: &&str| e.eq_ignore_ascii_case(ext); return kind;
if VIDEO_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Video;
}
if IMAGE_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Image;
}
} }
// Fallback to MIME type detection // Fallback to MIME type detection
if let Ok(Some(kind)) = infer::get_from_path(path) { if let Ok(Some(kind)) = infer::get_from_path(path) {
let mime_type = kind.mime_type(); return detect_from_mime(kind.mime_type());
return match mime_type.split('/').next() {
Some("video") => MediaKind::Video,
Some("image") => MediaKind::Image,
_ => MediaKind::Unknown,
};
} }
MediaKind::Unknown MediaKind::Unknown
@ -63,36 +80,24 @@ pub fn detect_media_kind(path: &Path) -> MediaKind {
/// Async/non-blocking detection: check extension first, otherwise read a small /// Async/non-blocking detection: check extension first, otherwise read a small
/// sample asynchronously and run `infer::get` on the buffer. /// sample asynchronously and run `infer::get` on the buffer.
pub async fn detect_media_kind_async(path: &Path) -> MediaKind { pub async fn detect_media_kind_async(path: &Path) -> MediaKind {
if let Some(ext) = path.extension().and_then(OsStr::to_str) { if let Some(kind) = detect_from_extension(path) {
let compare = |e: &&str| e.eq_ignore_ascii_case(ext); return kind;
if VIDEO_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Video;
}
if IMAGE_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Image;
}
} }
// Read a small prefix (8 KiB) asynchronously and probe // Read a small prefix (8 KiB) asynchronously and probe
match File::open(path).await { let Ok(mut file) = File::open(path).await else {
Ok(mut file) => { warn!(path = ?path.display(), "Failed to open file for media detection");
let mut buffer = vec![0u8; 8192]; return MediaKind::Unknown;
if let Ok(n) = file.read(&mut buffer).await };
&& n > 0
{ let mut buffer = vec![0u8; 8192];
buffer.truncate(n); if let Ok(n) = file.read(&mut buffer).await
if let Some(k) = infer::get(&buffer) { && n > 0
let mt = k.mime_type(); {
if mt.starts_with("video/") { buffer.truncate(n);
return MediaKind::Video; if let Some(k) = infer::get(&buffer) {
} return detect_from_mime(k.mime_type());
if mt.starts_with("image/") {
return MediaKind::Image;
}
}
}
} }
Err(e) => warn!(path = ?path.display(), "Failed to read file for media detection: {e}"),
} }
MediaKind::Unknown MediaKind::Unknown