refactor: improve idiomaticity and async correctness

This commit is contained in:
Kristofers Solo 2026-01-03 23:27:27 +02:00
parent da38ec8d69
commit 465f9c49e9
Signed by: kristoferssolo
GPG Key ID: 8687F2D3EEE6F0ED
5 changed files with 125 additions and 109 deletions

View File

@ -40,9 +40,10 @@ impl Config {
/// Load configuration from environment variables.
#[must_use]
pub fn from_env() -> Self {
let chat_id: Option<ChatId> = env::var("CHAT_ID")
let chat_id = env::var("CHAT_ID")
.ok()
.and_then(|id| id.parse::<i64>().ok().map(ChatId));
.and_then(|id| id.parse::<i64>().ok())
.map(ChatId);
Self {
chat_id,
youtube: YoutubeConfig::from_env(),
@ -64,9 +65,14 @@ impl Config {
}
}
/// Get global config (initialized by `Config::init(self)`).
///
/// # Panics
///
/// Panics if config has not been initialized.
#[inline]
#[must_use]
pub fn global_config() -> Config {
GLOBAL_CONFIG.get().cloned().unwrap_or_default()
pub fn global_config() -> &'static Config {
GLOBAL_CONFIG.get().expect("config not initialized")
}
impl YoutubeConfig {

View File

@ -10,7 +10,7 @@ use futures::{StreamExt, stream};
use std::{
cmp::min,
ffi::OsStr,
fs::{self, metadata},
fs,
path::{Path, PathBuf},
process::Stdio,
};
@ -104,13 +104,14 @@ async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result<DownloadResu
///
/// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "instagram")]
pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult> {
pub async fn download_instagram(url: String) -> Result<DownloadResult> {
let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "instagram:"]
.iter()
.map(ToString::to_string)
.collect();
run_yt_dlp(args, config.instagram.cookies_path.as_ref(), &url.into()).await
run_yt_dlp(
&["-t", "mp4", "--extractor-args", "instagram:"],
config.instagram.cookies_path.as_ref(),
&url,
)
.await
}
/// Download a TikTok URL with yt-dlp.
@ -119,13 +120,14 @@ pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult
///
/// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "tiktok")]
pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> {
pub async fn download_tiktok(url: String) -> Result<DownloadResult> {
let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "tiktok:"]
.iter()
.map(ToString::to_string)
.collect();
run_yt_dlp(args, config.tiktok.cookies_path.as_ref(), &url.into()).await
run_yt_dlp(
&["-t", "mp4", "--extractor-args", "tiktok:"],
config.tiktok.cookies_path.as_ref(),
&url,
)
.await
}
/// Download a Twitter URL with yt-dlp.
@ -134,13 +136,14 @@ pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> {
///
/// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "twitter")]
pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult> {
pub async fn download_twitter(url: String) -> Result<DownloadResult> {
let config = global_config();
let args = ["-t", "mp4", "--extractor-args", "twitter:"]
.iter()
.map(ToString::to_string)
.collect();
run_yt_dlp(args, config.twitter.cookies_path.as_ref(), &url.into()).await
run_yt_dlp(
&["-t", "mp4", "--extractor-args", "twitter:"],
config.twitter.cookies_path.as_ref(),
&url,
)
.await
}
/// Download a YouTube URL with yt-dlp.
@ -149,19 +152,20 @@ pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult>
///
/// - Propagates `run_command_in_tempdir` errors.
#[cfg(feature = "youtube")]
pub async fn download_youtube(url: impl Into<String>) -> Result<DownloadResult> {
pub async fn download_youtube(url: String) -> Result<DownloadResult> {
let config = global_config();
let args = [
"--no-playlist",
"-t",
"mp4",
"--postprocessor-args",
&config.youtube.postprocessor_args,
]
.iter()
.map(ToString::to_string)
.collect();
run_yt_dlp(args, config.youtube.cookies_path.as_ref(), &url.into()).await
run_yt_dlp(
&[
"--no-playlist",
"-t",
"mp4",
"--postprocessor-args",
&config.youtube.postprocessor_args,
],
config.youtube.cookies_path.as_ref(),
&url,
)
.await
}
/// Post-process a `DownloadResult`.
@ -183,9 +187,17 @@ pub async fn process_download_result(
return Err(Error::NoMediaFound);
}
// Detect kinds in parallel with limiter concurrency
// Detect kinds and validate files in parallel
let concurrency = min(8, dr.files.len());
let results = stream::iter(dr.files.drain(..).map(|path| async move {
// Check file metadata asynchronously
let Ok(meta) = tokio::fs::metadata(&path).await else {
return None;
};
if !meta.is_file() || meta.len() == 0 {
return None;
}
let kind = detect_media_kind_async(&path).await;
match kind {
MediaKind::Unknown => None,
@ -193,18 +205,10 @@ pub async fn process_download_result(
}
}))
.buffer_unordered(concurrency)
.collect::<Vec<Option<(PathBuf, MediaKind)>>>()
.collect::<Vec<_>>()
.await;
let mut media_items = results
.into_iter()
.flatten()
.filter(|(path, _)| {
metadata(path)
.map(|m| m.is_file() && m.len() > 0)
.unwrap_or(false)
})
.collect::<Vec<_>>();
let mut media_items = results.into_iter().flatten().collect::<Vec<_>>();
if media_items.is_empty() {
return Err(Error::NoMediaFound);
@ -254,18 +258,21 @@ fn is_potential_media_file(path: &Path) -> bool {
}
async fn run_yt_dlp(
mut args: Vec<String>,
base_args: &[&str],
cookies_path: Option<&PathBuf>,
url: &str,
) -> Result<DownloadResult> {
if let Some(path) = cookies_path {
args.extend(["--cookies".to_string(), path.to_string_lossy().to_string()]);
}
args.push(url.to_string());
let cookies_path_str;
let mut args = base_args.to_vec();
debug!(args = ?args, "downloadting content");
let args_ref = args.iter().map(String::as_ref).collect::<Vec<_>>();
run_command_in_tempdir("yt-dlp", &args_ref).await
if let Some(path) = cookies_path {
cookies_path_str = path.to_string_lossy();
args.extend(["--cookies", &cookies_path_str]);
}
args.push(url);
debug!(args = ?args, "downloading content");
run_command_in_tempdir("yt-dlp", &args).await
}
#[cfg(test)]

View File

@ -7,7 +7,7 @@ use std::{pin::Pin, sync::Arc};
use teloxide::{Bot, types::ChatId};
use tracing::info;
type DownloadFn = fn(&str) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>;
type DownloadFn = fn(String) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>;
#[derive(Debug, Clone)]
pub struct Handler {
@ -52,7 +52,7 @@ impl Handler {
/// Returns `Error` if download or media processing fails.
pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> {
info!(handler = %self.name(), url = %url, "handling url");
let dr = (self.func)(url).await?;
let dr = (self.func)(url.to_owned()).await?;
process_download_result(bot, chat_id, dr).await
}
}
@ -60,10 +60,8 @@ impl Handler {
macro_rules! handler {
($feature:expr, $regex:expr, $download_fn:path) => {
#[cfg(feature = $feature)]
Handler::new($feature, $regex, |url| {
Box::pin($download_fn(url.to_string()))
})
.expect(concat!("failed to create ", $feature, " handler"))
Handler::new($feature, $regex, |url: String| Box::pin($download_fn(url)))
.expect(concat!("failed to create ", $feature, " handler"))
};
}

View File

@ -1,4 +1,5 @@
use dotenv::dotenv;
use std::sync::Arc;
use teloxide::{prelude::*, respond, utils::command::BotCommands};
use tg_relay_rs::{
commands::{Command, answer},
@ -17,27 +18,26 @@ async fn main() -> color_eyre::Result<()> {
Comments::load_from_file("comments.txt")
.await
.map_err(|e| {
.unwrap_or_else(|e| {
warn!("failed to load comments.txt: {e}; using dummy comments");
e
Comments::dummy()
})
.unwrap_or_else(|_| Comments::dummy())
.init()?;
Config::from_env().init()?;
let bot = Bot::from_env();
let bot_name = bot.get_me().await?.username().to_owned();
let bot_name: Arc<str> = bot.get_me().await?.username().into();
info!(name = bot_name, "bot starting");
info!(name = %bot_name, "bot starting");
let handlers = create_handlers();
teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| {
let handlers = handlers.clone();
let bot_name_cloned = bot_name.clone();
let handlers = Arc::clone(&handlers);
let bot_name = Arc::clone(&bot_name);
async move {
process_cmd(&bot, &msg, &bot_name_cloned).await;
process_cmd(&bot, &msg, &bot_name).await;
process_message(&bot, &msg, &handlers).await;
respond(())
}
@ -60,7 +60,7 @@ async fn process_message(bot: &Bot, msg: &Message, handlers: &[Handler]) {
.send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE)
.await;
if let Some(chat_id) = global_config().chat_id {
let _ = bot.send_message(chat_id, format!("{err}")).await;
let _ = bot.send_message(chat_id, err.to_string()).await;
}
}
return;

View File

@ -35,26 +35,43 @@ impl MediaKind {
}
}
/// Report whether `ext` equals any entry of `extensions`, ignoring ASCII case.
///
/// Returns `false` for an empty `extensions` slice.
#[inline]
fn ext_matches(ext: &str, extensions: &[&str]) -> bool {
    for &candidate in extensions {
        if candidate.eq_ignore_ascii_case(ext) {
            return true;
        }
    }
    false
}
/// Detect media kind from file extension.
fn detect_from_extension(path: &Path) -> Option<MediaKind> {
let ext = path.extension().and_then(OsStr::to_str)?;
if ext_matches(ext, VIDEO_EXTSTENSIONS) {
return Some(MediaKind::Video);
}
if ext_matches(ext, IMAGE_EXTSTENSIONS) {
return Some(MediaKind::Image);
}
None
}
/// Map a MIME type string to a media kind using its top-level type.
///
/// Only the text before the first `/` is inspected: `video` yields
/// `MediaKind::Video`, `image` yields `MediaKind::Image`, and anything else
/// yields `MediaKind::Unknown`.
fn detect_from_mime(mime_type: &str) -> MediaKind {
    // `split` always yields at least one piece, so the default is never
    // actually used; it just avoids an explicit `Option` match.
    let top_level = mime_type.split('/').next().unwrap_or_default();

    if top_level == "video" {
        MediaKind::Video
    } else if top_level == "image" {
        MediaKind::Image
    } else {
        MediaKind::Unknown
    }
}
/// Detect media kind first by extension, then by content/magic (sync).
#[must_use]
pub fn detect_media_kind(path: &Path) -> MediaKind {
if let Some(ext) = path.extension().and_then(OsStr::to_str) {
let compare = |e: &&str| e.eq_ignore_ascii_case(ext);
if VIDEO_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Video;
}
if IMAGE_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Image;
}
if let Some(kind) = detect_from_extension(path) {
return kind;
}
// Fallback to MIME type detection
if let Ok(Some(kind)) = infer::get_from_path(path) {
let mime_type = kind.mime_type();
return match mime_type.split('/').next() {
Some("video") => MediaKind::Video,
Some("image") => MediaKind::Image,
_ => MediaKind::Unknown,
};
return detect_from_mime(kind.mime_type());
}
MediaKind::Unknown
@ -63,36 +80,24 @@ pub fn detect_media_kind(path: &Path) -> MediaKind {
/// Async/non-blocking detection: check extension first, otherwise read a small
/// sample asynchronously and run `infer::get` on the buffer.
pub async fn detect_media_kind_async(path: &Path) -> MediaKind {
if let Some(ext) = path.extension().and_then(OsStr::to_str) {
let compare = |e: &&str| e.eq_ignore_ascii_case(ext);
if VIDEO_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Video;
}
if IMAGE_EXTSTENSIONS.iter().any(compare) {
return MediaKind::Image;
}
if let Some(kind) = detect_from_extension(path) {
return kind;
}
// Read a small prefix (8 KiB) asynchronously and probe
match File::open(path).await {
Ok(mut file) => {
let mut buffer = vec![0u8; 8192];
if let Ok(n) = file.read(&mut buffer).await
&& n > 0
{
buffer.truncate(n);
if let Some(k) = infer::get(&buffer) {
let mt = k.mime_type();
if mt.starts_with("video/") {
return MediaKind::Video;
}
if mt.starts_with("image/") {
return MediaKind::Image;
}
}
}
let Ok(mut file) = File::open(path).await else {
warn!(path = ?path.display(), "Failed to open file for media detection");
return MediaKind::Unknown;
};
let mut buffer = vec![0u8; 8192];
if let Ok(n) = file.read(&mut buffer).await
&& n > 0
{
buffer.truncate(n);
if let Some(k) = infer::get(&buffer) {
return detect_from_mime(k.mime_type());
}
Err(e) => warn!(path = ?path.display(), "Failed to read file for media detection: {e}"),
}
MediaKind::Unknown