// src/services/batch/meta.rs use std::sync::Arc; use tokio::sync::Mutex; use serde::Serialize; use tracing::{info, warn, error}; use sqlx::SqlitePool; use crate::clients::ads::AdsClient; use crate::clients::arxiv::ArxivClient; use crate::api::handlers::{convert_ads_doc_to_standard, convert_arxiv_to_standard, save_paper_to_db}; // 批量元数据同步进度状态 #[derive(Debug, Clone, Serialize)] pub struct MetaSyncStatus { pub active: bool, pub query: String, pub source: String, pub synced: i32, pub total: i32, } impl MetaSyncStatus { pub fn new() -> Self { MetaSyncStatus { active: false, query: String::new(), source: String::new(), synced: 0, total: 0, } } } pub struct MetaSync; impl MetaSync { // 预估文献总量 pub async fn get_total_count( query: &str, source: &str, ads: &AdsClient, arxiv: &ArxivClient, ) -> anyhow::Result { let mut total = 0; if source == "all" || source == "ads" { match ads.get_total_count(query).await { Ok(count) => { total += count; info!("ADS 预估文献总量: {} 篇", count); } Err(e) => { warn!("获取 ADS 预估总量失败: {}", e); } } } if source == "all" || source == "arxiv" { match arxiv.get_total_count(query).await { Ok(count) => { total += count; info!("arXiv 预估文献总量: {} 篇", count); } Err(e) => { warn!("获取 arXiv 预估总量失败: {}", e); } } } Ok(total) } // 启动后台元数据同步异步任务 pub fn start_harvest( db: SqlitePool, ads: Arc, arxiv: Arc, query: String, source: String, limit: i32, status: Arc>, ) { let query_clone = query.clone(); let source_clone = source.clone(); tokio::spawn(async move { info!("启动后台批量元数据同步任务: 查询词='{}', 源='{}', 上限={}", query_clone, source_clone, limit); // 自动将检索配置存入/更新至 sync_queries 数据库表中进行去重和时间更新 let _ = sqlx::query( "INSERT INTO sync_queries (query, source, limit_count, last_run) \ VALUES (?, ?, ?, CURRENT_TIMESTAMP) \ ON CONFLICT(query, source, limit_count) DO UPDATE SET last_run=excluded.last_run" ) .bind(&query_clone) .bind(&source_clone) .bind(limit) .execute(&db) .await; // 1. 并行获取两端预估总量 let ads_count_fut = { let ads = ads.clone(); let query = query_clone.clone(); let is_active = source_clone == "all" || source_clone == "ads"; async move { if is_active { ads.get_total_count(&query).await.unwrap_or(0) } else { 0 } } }; let arxiv_count_fut = { let arxiv = arxiv.clone(); let query = query_clone.clone(); let is_active = source_clone == "all" || source_clone == "arxiv"; async move { if is_active { arxiv.get_total_count(&query).await.unwrap_or(0) } else { 0 } } }; let (ads_total, arxiv_total) = tokio::join!(ads_count_fut, arxiv_count_fut); let total_count = ads_total + arxiv_total; { let mut s = status.lock().await; s.total = total_count; } // 计算实际需要元数据同步的总上限,并按比例分配或根据实际匹配量上限控制 let limit_to_harvest = if limit > 0 { std::cmp::min(limit, total_count) } else { total_count }; // 共享的 atomic 计数器,以便两端并行同步时独立累加进度 let synced_counter = Arc::new(std::sync::atomic::AtomicI32::new(0)); // 2. 执行并行的同步子任务 let ads_sync_fut = { let db = db.clone(); let ads = ads.clone(); let query = query_clone.clone(); let synced_counter = synced_counter.clone(); let status = status.clone(); let is_active = source_clone == "all" || source_clone == "ads"; // 如果是 all 模式,各平台按比例分摊 limit 额度,或者直接限制自身的最大可用量 let ads_limit = if source_clone == "all" { if ads_total == 0 { 0 } else { let ratio = ads_total as f32 / total_count as f32; ((limit_to_harvest as f32) * ratio).round() as i32 } } else { limit_to_harvest }; async move { if !is_active || ads_limit <= 0 { return; } let mut local_synced = 0; let mut start_offset = 0; while local_synced < ads_limit { let chunk_size = std::cmp::min(2000, ads_limit - local_synced); if chunk_size <= 0 { break; } info!("正在同步 ADS 分批数据: start={}, rows={}", start_offset, chunk_size); match ads.search(&query, start_offset, chunk_size, "relevance").await { Ok(docs) => { if docs.is_empty() { break; } let count = docs.len() as i32; for doc in docs { let paper = convert_ads_doc_to_standard(&doc); let _ = save_paper_to_db(&db, &paper).await; } local_synced += count; start_offset += count; // 累加全局进度并更新状态 let current_global = synced_counter.fetch_add(count, std::sync::atomic::Ordering::SeqCst) + count; { let mut s = status.lock().await; s.synced = current_global; } } Err(e) => { error!("批量同步 ADS 数据出错: {}", e); break; } } } } }; let arxiv_sync_fut = { let db = db.clone(); let arxiv = arxiv.clone(); let query = query_clone.clone(); let synced_counter = synced_counter.clone(); let status = status.clone(); let is_active = source_clone == "all" || source_clone == "arxiv"; let arxiv_limit = if source_clone == "all" { if arxiv_total == 0 { 0 } else { let ratio = arxiv_total as f32 / total_count as f32; ((limit_to_harvest as f32) * ratio).round() as i32 } } else { limit_to_harvest }; async move { if !is_active || arxiv_limit <= 0 { return; } let mut local_synced = 0; let mut start_offset = 0; while local_synced < arxiv_limit { let chunk_size = std::cmp::min(2000, arxiv_limit - local_synced); if chunk_size <= 0 { break; } info!("正在同步 arXiv 分批数据: start={}, max_results={}", start_offset, chunk_size); match arxiv.search(&query, start_offset, chunk_size, "relevance").await { Ok(papers) => { if papers.is_empty() { break; } let count = papers.len() as i32; for p in papers { let paper = convert_arxiv_to_standard(&p); let _ = save_paper_to_db(&db, &paper).await; } local_synced += count; start_offset += count; // 累加全局进度并更新状态 let current_global = synced_counter.fetch_add(count, std::sync::atomic::Ordering::SeqCst) + count; { let mut s = status.lock().await; s.synced = current_global; } } Err(e) => { error!("批量同步 arXiv 数据出错: {}", e); break; } } // 遵循 arXiv API 3 秒间隔要求 tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; } } }; // 使用 tokio::join! 并行驱动两端同步任务 tokio::join!(ads_sync_fut, arxiv_sync_fut); // 4. 收尾并重置状态 let final_synced = synced_counter.load(std::sync::atomic::Ordering::SeqCst); { let mut s = status.lock().await; s.active = false; s.synced = final_synced; info!("后台批量元数据同步任务已结束。共成功同步 {} 篇文献。", final_synced); } }); } }