mirror of
https://github.com/kristoferssolo/tg-relay-rs.git
synced 2026-01-14 12:46:04 +00:00
refactor: improve idiomaticity and async correctness
This commit is contained in:
parent
da38ec8d69
commit
465f9c49e9
@ -40,9 +40,10 @@ impl Config {
|
|||||||
/// Load configuration from environment variables.
|
/// Load configuration from environment variables.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn from_env() -> Self {
|
pub fn from_env() -> Self {
|
||||||
let chat_id: Option<ChatId> = env::var("CHAT_ID")
|
let chat_id = env::var("CHAT_ID")
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|id| id.parse::<i64>().ok().map(ChatId));
|
.and_then(|id| id.parse::<i64>().ok())
|
||||||
|
.map(ChatId);
|
||||||
Self {
|
Self {
|
||||||
chat_id,
|
chat_id,
|
||||||
youtube: YoutubeConfig::from_env(),
|
youtube: YoutubeConfig::from_env(),
|
||||||
@ -64,9 +65,14 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Get global config (initialized by `Config::init(self)`).
|
/// Get global config (initialized by `Config::init(self)`).
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if config has not been initialized.
|
||||||
|
#[inline]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn global_config() -> Config {
|
pub fn global_config() -> &'static Config {
|
||||||
GLOBAL_CONFIG.get().cloned().unwrap_or_default()
|
GLOBAL_CONFIG.get().expect("config not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
impl YoutubeConfig {
|
impl YoutubeConfig {
|
||||||
|
|||||||
107
src/download.rs
107
src/download.rs
@ -10,7 +10,7 @@ use futures::{StreamExt, stream};
|
|||||||
use std::{
|
use std::{
|
||||||
cmp::min,
|
cmp::min,
|
||||||
ffi::OsStr,
|
ffi::OsStr,
|
||||||
fs::{self, metadata},
|
fs,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
process::Stdio,
|
process::Stdio,
|
||||||
};
|
};
|
||||||
@ -104,13 +104,14 @@ async fn run_command_in_tempdir(cmd: &str, args: &[&str]) -> Result<DownloadResu
|
|||||||
///
|
///
|
||||||
/// - Propagates `run_command_in_tempdir` errors.
|
/// - Propagates `run_command_in_tempdir` errors.
|
||||||
#[cfg(feature = "instagram")]
|
#[cfg(feature = "instagram")]
|
||||||
pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult> {
|
pub async fn download_instagram(url: String) -> Result<DownloadResult> {
|
||||||
let config = global_config();
|
let config = global_config();
|
||||||
let args = ["-t", "mp4", "--extractor-args", "instagram:"]
|
run_yt_dlp(
|
||||||
.iter()
|
&["-t", "mp4", "--extractor-args", "instagram:"],
|
||||||
.map(ToString::to_string)
|
config.instagram.cookies_path.as_ref(),
|
||||||
.collect();
|
&url,
|
||||||
run_yt_dlp(args, config.instagram.cookies_path.as_ref(), &url.into()).await
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download a Tiktok URL with yt-dlp.
|
/// Download a Tiktok URL with yt-dlp.
|
||||||
@ -119,13 +120,14 @@ pub async fn download_instagram(url: impl Into<String>) -> Result<DownloadResult
|
|||||||
///
|
///
|
||||||
/// - Propagates `run_command_in_tempdir` errors.
|
/// - Propagates `run_command_in_tempdir` errors.
|
||||||
#[cfg(feature = "tiktok")]
|
#[cfg(feature = "tiktok")]
|
||||||
pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> {
|
pub async fn download_tiktok(url: String) -> Result<DownloadResult> {
|
||||||
let config = global_config();
|
let config = global_config();
|
||||||
let args = ["-t", "mp4", "--extractor-args", "tiktok:"]
|
run_yt_dlp(
|
||||||
.iter()
|
&["-t", "mp4", "--extractor-args", "tiktok:"],
|
||||||
.map(ToString::to_string)
|
config.tiktok.cookies_path.as_ref(),
|
||||||
.collect();
|
&url,
|
||||||
run_yt_dlp(args, config.tiktok.cookies_path.as_ref(), &url.into()).await
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download a Twitter URL with yt-dlp.
|
/// Download a Twitter URL with yt-dlp.
|
||||||
@ -134,13 +136,14 @@ pub async fn download_tiktok(url: impl Into<String>) -> Result<DownloadResult> {
|
|||||||
///
|
///
|
||||||
/// - Propagates `run_command_in_tempdir` errors.
|
/// - Propagates `run_command_in_tempdir` errors.
|
||||||
#[cfg(feature = "twitter")]
|
#[cfg(feature = "twitter")]
|
||||||
pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult> {
|
pub async fn download_twitter(url: String) -> Result<DownloadResult> {
|
||||||
let config = global_config();
|
let config = global_config();
|
||||||
let args = ["-t", "mp4", "--extractor-args", "twitter:"]
|
run_yt_dlp(
|
||||||
.iter()
|
&["-t", "mp4", "--extractor-args", "twitter:"],
|
||||||
.map(ToString::to_string)
|
config.twitter.cookies_path.as_ref(),
|
||||||
.collect();
|
&url,
|
||||||
run_yt_dlp(args, config.twitter.cookies_path.as_ref(), &url.into()).await
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download a URL with yt-dlp.
|
/// Download a URL with yt-dlp.
|
||||||
@ -149,19 +152,20 @@ pub async fn download_twitter(url: impl Into<String>) -> Result<DownloadResult>
|
|||||||
///
|
///
|
||||||
/// - Propagates `run_command_in_tempdir` errors.
|
/// - Propagates `run_command_in_tempdir` errors.
|
||||||
#[cfg(feature = "youtube")]
|
#[cfg(feature = "youtube")]
|
||||||
pub async fn download_youtube(url: impl Into<String>) -> Result<DownloadResult> {
|
pub async fn download_youtube(url: String) -> Result<DownloadResult> {
|
||||||
let config = global_config();
|
let config = global_config();
|
||||||
let args = [
|
run_yt_dlp(
|
||||||
"--no-playlist",
|
&[
|
||||||
"-t",
|
"--no-playlist",
|
||||||
"mp4",
|
"-t",
|
||||||
"--postprocessor-args",
|
"mp4",
|
||||||
&config.youtube.postprocessor_args,
|
"--postprocessor-args",
|
||||||
]
|
&config.youtube.postprocessor_args,
|
||||||
.iter()
|
],
|
||||||
.map(ToString::to_string)
|
config.youtube.cookies_path.as_ref(),
|
||||||
.collect();
|
&url,
|
||||||
run_yt_dlp(args, config.youtube.cookies_path.as_ref(), &url.into()).await
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Post-process a `DownloadResult`.
|
/// Post-process a `DownloadResult`.
|
||||||
@ -183,9 +187,17 @@ pub async fn process_download_result(
|
|||||||
return Err(Error::NoMediaFound);
|
return Err(Error::NoMediaFound);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect kinds in parallel with limiter concurrency
|
// Detect kinds and validate files in parallel
|
||||||
let concurrency = min(8, dr.files.len());
|
let concurrency = min(8, dr.files.len());
|
||||||
let results = stream::iter(dr.files.drain(..).map(|path| async move {
|
let results = stream::iter(dr.files.drain(..).map(|path| async move {
|
||||||
|
// Check file metadata asynchronously
|
||||||
|
let Ok(meta) = tokio::fs::metadata(&path).await else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
if !meta.is_file() || meta.len() == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
let kind = detect_media_kind_async(&path).await;
|
let kind = detect_media_kind_async(&path).await;
|
||||||
match kind {
|
match kind {
|
||||||
MediaKind::Unknown => None,
|
MediaKind::Unknown => None,
|
||||||
@ -193,18 +205,10 @@ pub async fn process_download_result(
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
.buffer_unordered(concurrency)
|
.buffer_unordered(concurrency)
|
||||||
.collect::<Vec<Option<(PathBuf, MediaKind)>>>()
|
.collect::<Vec<_>>()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let mut media_items = results
|
let mut media_items = results.into_iter().flatten().collect::<Vec<_>>();
|
||||||
.into_iter()
|
|
||||||
.flatten()
|
|
||||||
.filter(|(path, _)| {
|
|
||||||
metadata(path)
|
|
||||||
.map(|m| m.is_file() && m.len() > 0)
|
|
||||||
.unwrap_or(false)
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if media_items.is_empty() {
|
if media_items.is_empty() {
|
||||||
return Err(Error::NoMediaFound);
|
return Err(Error::NoMediaFound);
|
||||||
@ -254,18 +258,21 @@ fn is_potential_media_file(path: &Path) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn run_yt_dlp(
|
async fn run_yt_dlp(
|
||||||
mut args: Vec<String>,
|
base_args: &[&str],
|
||||||
cookies_path: Option<&PathBuf>,
|
cookies_path: Option<&PathBuf>,
|
||||||
url: &str,
|
url: &str,
|
||||||
) -> Result<DownloadResult> {
|
) -> Result<DownloadResult> {
|
||||||
if let Some(path) = cookies_path {
|
let cookies_path_str;
|
||||||
args.extend(["--cookies".to_string(), path.to_string_lossy().to_string()]);
|
let mut args = base_args.to_vec();
|
||||||
}
|
|
||||||
args.push(url.to_string());
|
|
||||||
|
|
||||||
debug!(args = ?args, "downloadting content");
|
if let Some(path) = cookies_path {
|
||||||
let args_ref = args.iter().map(String::as_ref).collect::<Vec<_>>();
|
cookies_path_str = path.to_string_lossy();
|
||||||
run_command_in_tempdir("yt-dlp", &args_ref).await
|
args.extend(["--cookies", &cookies_path_str]);
|
||||||
|
}
|
||||||
|
args.push(url);
|
||||||
|
|
||||||
|
debug!(args = ?args, "downloading content");
|
||||||
|
run_command_in_tempdir("yt-dlp", &args).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@ -7,7 +7,7 @@ use std::{pin::Pin, sync::Arc};
|
|||||||
use teloxide::{Bot, types::ChatId};
|
use teloxide::{Bot, types::ChatId};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
type DownloadFn = fn(&str) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>;
|
type DownloadFn = fn(String) -> Pin<Box<dyn Future<Output = Result<DownloadResult>> + Send>>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Handler {
|
pub struct Handler {
|
||||||
@ -52,7 +52,7 @@ impl Handler {
|
|||||||
/// Returns `Error` if download or media processing fails.
|
/// Returns `Error` if download or media processing fails.
|
||||||
pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> {
|
pub async fn handle(&self, bot: &Bot, chat_id: ChatId, url: &str) -> Result<()> {
|
||||||
info!(handler = %self.name(), url = %url, "handling url");
|
info!(handler = %self.name(), url = %url, "handling url");
|
||||||
let dr = (self.func)(url).await?;
|
let dr = (self.func)(url.to_owned()).await?;
|
||||||
process_download_result(bot, chat_id, dr).await
|
process_download_result(bot, chat_id, dr).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -60,10 +60,8 @@ impl Handler {
|
|||||||
macro_rules! handler {
|
macro_rules! handler {
|
||||||
($feature:expr, $regex:expr, $download_fn:path) => {
|
($feature:expr, $regex:expr, $download_fn:path) => {
|
||||||
#[cfg(feature = $feature)]
|
#[cfg(feature = $feature)]
|
||||||
Handler::new($feature, $regex, |url| {
|
Handler::new($feature, $regex, |url: String| Box::pin($download_fn(url)))
|
||||||
Box::pin($download_fn(url.to_string()))
|
.expect(concat!("failed to create ", $feature, " handler"))
|
||||||
})
|
|
||||||
.expect(concat!("failed to create ", $feature, " handler"))
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
18
src/main.rs
18
src/main.rs
@ -1,4 +1,5 @@
|
|||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
|
use std::sync::Arc;
|
||||||
use teloxide::{prelude::*, respond, utils::command::BotCommands};
|
use teloxide::{prelude::*, respond, utils::command::BotCommands};
|
||||||
use tg_relay_rs::{
|
use tg_relay_rs::{
|
||||||
commands::{Command, answer},
|
commands::{Command, answer},
|
||||||
@ -17,27 +18,26 @@ async fn main() -> color_eyre::Result<()> {
|
|||||||
|
|
||||||
Comments::load_from_file("comments.txt")
|
Comments::load_from_file("comments.txt")
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.unwrap_or_else(|e| {
|
||||||
warn!("failed to load comments.txt: {e}; using dummy comments");
|
warn!("failed to load comments.txt: {e}; using dummy comments");
|
||||||
e
|
Comments::dummy()
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|_| Comments::dummy())
|
|
||||||
.init()?;
|
.init()?;
|
||||||
|
|
||||||
Config::from_env().init()?;
|
Config::from_env().init()?;
|
||||||
|
|
||||||
let bot = Bot::from_env();
|
let bot = Bot::from_env();
|
||||||
let bot_name = bot.get_me().await?.username().to_owned();
|
let bot_name: Arc<str> = bot.get_me().await?.username().into();
|
||||||
|
|
||||||
info!(name = bot_name, "bot starting");
|
info!(name = %bot_name, "bot starting");
|
||||||
|
|
||||||
let handlers = create_handlers();
|
let handlers = create_handlers();
|
||||||
|
|
||||||
teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| {
|
teloxide::repl(bot.clone(), move |bot: Bot, msg: Message| {
|
||||||
let handlers = handlers.clone();
|
let handlers = Arc::clone(&handlers);
|
||||||
let bot_name_cloned = bot_name.clone();
|
let bot_name = Arc::clone(&bot_name);
|
||||||
async move {
|
async move {
|
||||||
process_cmd(&bot, &msg, &bot_name_cloned).await;
|
process_cmd(&bot, &msg, &bot_name).await;
|
||||||
process_message(&bot, &msg, &handlers).await;
|
process_message(&bot, &msg, &handlers).await;
|
||||||
respond(())
|
respond(())
|
||||||
}
|
}
|
||||||
@ -60,7 +60,7 @@ async fn process_message(bot: &Bot, msg: &Message, handlers: &[Handler]) {
|
|||||||
.send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE)
|
.send_message(msg.chat.id, FAILED_FETCH_MEDIA_MESSAGE)
|
||||||
.await;
|
.await;
|
||||||
if let Some(chat_id) = global_config().chat_id {
|
if let Some(chat_id) = global_config().chat_id {
|
||||||
let _ = bot.send_message(chat_id, format!("{err}")).await;
|
let _ = bot.send_message(chat_id, err.to_string()).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
|||||||
85
src/utils.rs
85
src/utils.rs
@ -35,26 +35,43 @@ impl MediaKind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if extension matches any in the given list (case-insensitive).
|
||||||
|
#[inline]
|
||||||
|
fn ext_matches(ext: &str, extensions: &[&str]) -> bool {
|
||||||
|
extensions.iter().any(|e| e.eq_ignore_ascii_case(ext))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect media kind from file extension.
|
||||||
|
fn detect_from_extension(path: &Path) -> Option<MediaKind> {
|
||||||
|
let ext = path.extension().and_then(OsStr::to_str)?;
|
||||||
|
if ext_matches(ext, VIDEO_EXTSTENSIONS) {
|
||||||
|
return Some(MediaKind::Video);
|
||||||
|
}
|
||||||
|
if ext_matches(ext, IMAGE_EXTSTENSIONS) {
|
||||||
|
return Some(MediaKind::Image);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect media kind from MIME type string.
|
||||||
|
fn detect_from_mime(mime_type: &str) -> MediaKind {
|
||||||
|
match mime_type.split('/').next() {
|
||||||
|
Some("video") => MediaKind::Video,
|
||||||
|
Some("image") => MediaKind::Image,
|
||||||
|
_ => MediaKind::Unknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Detect media kind first by extension, then by content/magic (sync).
|
/// Detect media kind first by extension, then by content/magic (sync).
|
||||||
|
#[must_use]
|
||||||
pub fn detect_media_kind(path: &Path) -> MediaKind {
|
pub fn detect_media_kind(path: &Path) -> MediaKind {
|
||||||
if let Some(ext) = path.extension().and_then(OsStr::to_str) {
|
if let Some(kind) = detect_from_extension(path) {
|
||||||
let compare = |e: &&str| e.eq_ignore_ascii_case(ext);
|
return kind;
|
||||||
if VIDEO_EXTSTENSIONS.iter().any(compare) {
|
|
||||||
return MediaKind::Video;
|
|
||||||
}
|
|
||||||
if IMAGE_EXTSTENSIONS.iter().any(compare) {
|
|
||||||
return MediaKind::Image;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to MIME type detection
|
// Fallback to MIME type detection
|
||||||
if let Ok(Some(kind)) = infer::get_from_path(path) {
|
if let Ok(Some(kind)) = infer::get_from_path(path) {
|
||||||
let mime_type = kind.mime_type();
|
return detect_from_mime(kind.mime_type());
|
||||||
return match mime_type.split('/').next() {
|
|
||||||
Some("video") => MediaKind::Video,
|
|
||||||
Some("image") => MediaKind::Image,
|
|
||||||
_ => MediaKind::Unknown,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MediaKind::Unknown
|
MediaKind::Unknown
|
||||||
@ -63,36 +80,24 @@ pub fn detect_media_kind(path: &Path) -> MediaKind {
|
|||||||
/// Async/non-blocking detection: check extension first, otherwise read a small
|
/// Async/non-blocking detection: check extension first, otherwise read a small
|
||||||
/// sample asynchronously and run `infer::get` on the buffer.
|
/// sample asynchronously and run `infer::get` on the buffer.
|
||||||
pub async fn detect_media_kind_async(path: &Path) -> MediaKind {
|
pub async fn detect_media_kind_async(path: &Path) -> MediaKind {
|
||||||
if let Some(ext) = path.extension().and_then(OsStr::to_str) {
|
if let Some(kind) = detect_from_extension(path) {
|
||||||
let compare = |e: &&str| e.eq_ignore_ascii_case(ext);
|
return kind;
|
||||||
if VIDEO_EXTSTENSIONS.iter().any(compare) {
|
|
||||||
return MediaKind::Video;
|
|
||||||
}
|
|
||||||
if IMAGE_EXTSTENSIONS.iter().any(compare) {
|
|
||||||
return MediaKind::Image;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read a small prefix (8 KiB) asynchronously and probe
|
// Read a small prefix (8 KiB) asynchronously and probe
|
||||||
match File::open(path).await {
|
let Ok(mut file) = File::open(path).await else {
|
||||||
Ok(mut file) => {
|
warn!(path = ?path.display(), "Failed to open file for media detection");
|
||||||
let mut buffer = vec![0u8; 8192];
|
return MediaKind::Unknown;
|
||||||
if let Ok(n) = file.read(&mut buffer).await
|
};
|
||||||
&& n > 0
|
|
||||||
{
|
let mut buffer = vec![0u8; 8192];
|
||||||
buffer.truncate(n);
|
if let Ok(n) = file.read(&mut buffer).await
|
||||||
if let Some(k) = infer::get(&buffer) {
|
&& n > 0
|
||||||
let mt = k.mime_type();
|
{
|
||||||
if mt.starts_with("video/") {
|
buffer.truncate(n);
|
||||||
return MediaKind::Video;
|
if let Some(k) = infer::get(&buffer) {
|
||||||
}
|
return detect_from_mime(k.mime_type());
|
||||||
if mt.starts_with("image/") {
|
|
||||||
return MediaKind::Image;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Err(e) => warn!(path = ?path.display(), "Failed to read file for media detection: {e}"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MediaKind::Unknown
|
MediaKind::Unknown
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user