use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; use std::process::Command; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tauri::{Emitter, Manager}; use tauri::path::BaseDirectory; #[cfg(windows)] use std::os::windows::process::CommandExt; #[cfg(windows)] const CREATE_NO_WINDOW: u32 = 0x08000000; /// 更新信息结构 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdateInfo { pub version: String, pub notes: Option, pub download_url: String, } /// gh-info API 响应结构 #[derive(Debug, Clone, Serialize, Deserialize)] struct GhInfoApiResponse { repo: String, latest_version: String, changelog: Option, published_at: String, #[serde(default)] prerelease: bool, attachments: serde_json::Value, // 支持两种格式: ["URL1", "URL2"] 或 [["文件名", "URL"], ...] } /// 自定义更新服务器 API 响应结构 #[derive(Debug, Clone, Serialize, Deserialize)] struct CustomUpdateApiResponse { version: String, notes: Option, #[serde(rename = "pub_date")] pub_date: Option, download_url: String, signature: Option, platforms: Option>, } /// 平台特定信息 #[derive(Debug, Clone, Serialize, Deserialize)] struct PlatformInfo { url: String, signature: Option, } /// 检查更新(使用自定义 API 端点) pub async fn check_update( endpoint: Option<&str>, current_version: &str, _use_mirror: bool, github_repo: Option<&str>, include_prerelease: bool, ) -> Result> { println!("[更新检查] 开始检查更新..."); println!("[更新检查] 当前版本: {}", current_version); println!("[更新检查] 包含预发布版本: {}", include_prerelease); // 确定使用的 API 端点 let api_url = if let Some(custom_endpoint) = endpoint { // 如果提供了自定义端点,直接使用 custom_endpoint.to_string() } else { // 否则使用默认的 gh-info API let repo = github_repo.unwrap_or("plsgo/cstb"); if include_prerelease { format!("https://gh-info.okk.cool/repos/{}/releases/latest/pre", repo) } else { format!("https://gh-info.okk.cool/repos/{}/releases/latest", repo) } }; println!("[更新检查] API URL: {}", api_url); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(10)) .user_agent("cstb-updater/1.0") .build()?; let response = client.get(&api_url).send().await?; if !response.status().is_success() { println!("[更新检查] ✗ API 请求失败,HTTP 状态码: {}", response.status()); return Err(anyhow::anyhow!("API 请求失败,HTTP 状态码: {}", response.status())); } // 获取响应文本以便尝试不同的解析方式 let response_text = response.text().await?; // 关闭更新日志的打印 // println!("[更新检查] API 响应: {}", response_text); // 尝试解析为自定义更新服务器格式 let update_info = if let Ok(custom_resp) = serde_json::from_str::(&response_text) { println!("[更新检查] 检测到自定义更新服务器格式"); // 提取版本号(去掉 'v' 前缀) let version = custom_resp.version.trim_start_matches('v').to_string(); println!("[更新检查] 远程版本: {}", version); // 版本比较 let comparison = compare_version(&version, current_version); println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison); if comparison > 0 { println!("[更新检查] ✓ 发现新版本: {}", version); // 获取下载链接 // 优先使用平台特定的链接 let download_url = if let Some(ref platforms) = custom_resp.platforms { // 检测当前平台 #[cfg(target_os = "windows")] let platform_key = "windows-x86_64"; #[cfg(all(target_os = "macos", target_arch = "x86_64"))] let platform_key = "darwin-x86_64"; #[cfg(all(target_os = "macos", target_arch = "aarch64"))] let platform_key = "darwin-aarch64"; #[cfg(target_os = "linux")] let platform_key = "linux-x86_64"; #[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))] let platform_key = ""; if !platform_key.is_empty() { platforms.get(platform_key) .map(|p| p.url.clone()) .unwrap_or_else(|| custom_resp.download_url.clone()) } else { custom_resp.download_url.clone() } } else { custom_resp.download_url.clone() }; println!("[更新检查] 下载链接: {}", download_url); Some(UpdateInfo { version, notes: custom_resp.notes, download_url, }) } else { println!("[更新检查] ✗ 已是最新版本"); None } } else { // 尝试解析为 gh-info API 格式 println!("[更新检查] 尝试解析为 gh-info API 格式"); let api_resp: GhInfoApiResponse = serde_json::from_str(&response_text) .context("解析更新 API 响应失败,既不是自定义格式也不是 gh-info 格式")?; // 提取版本号(去掉 'v' 前缀) let version = api_resp.latest_version.trim_start_matches('v').to_string(); println!("[更新检查] 远程版本: {}", version); // 版本比较 let comparison = compare_version(&version, current_version); println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison); if comparison > 0 { println!("[更新检查] ✓ 发现新版本: {}", version); // 从 attachments 中获取下载链接 // 支持两种格式: // 1. 字符串数组: ["URL1", "URL2", ...] // 2. 嵌套数组: [["文件名", "URL"], ...] let download_url = extract_download_url(&api_resp.attachments) .ok_or_else(|| anyhow::anyhow!("未找到可下载的安装包"))?; println!("[更新检查] 下载链接: {}", download_url); Some(UpdateInfo { version, notes: api_resp.changelog, download_url, }) } else { println!("[更新检查] ✗ 已是最新版本"); None } }; Ok(update_info) } /// 从 attachments 中提取下载 URL /// 支持两种格式: /// 1. 字符串数组: ["URL1", "URL2", ...] - 优先选择 .exe 或 .msi 文件 /// 2. 嵌套数组: [["文件名", "URL"], ...] - 优先选择 .exe 或 .msi 文件 fn extract_download_url(attachments: &serde_json::Value) -> Option { // 尝试解析为字符串数组格式: ["URL1", "URL2", ...] if let Ok(urls) = serde_json::from_value::>(attachments.clone()) { println!("[更新检查] 检测到字符串数组格式的 attachments"); // 优先选择 .exe 或 .msi 文件 if let Some(url) = urls.iter().find(|url| { url.ends_with(".exe") || url.ends_with(".msi") }) { return Some(url.clone()); } // 如果没有找到 .exe 或 .msi,使用第一个 URL return urls.first().cloned(); } // 尝试解析为嵌套数组格式: [["文件名", "URL"], ...] if let Ok(nested) = serde_json::from_value::>>(attachments.clone()) { println!("[更新检查] 检测到嵌套数组格式的 attachments"); // 优先选择 .exe 或 .msi 文件 if let Some(url) = nested.iter().find_map(|attachment| { if attachment.len() >= 2 { let filename = &attachment[0]; let url = &attachment[1]; if filename.ends_with(".exe") || filename.ends_with(".msi") { Some(url.clone()) } else { None } } else { None } }) { return Some(url); } // 如果没有找到 .exe 或 .msi,使用第一个附件的 URL if let Some(attachment) = nested.first() { if attachment.len() >= 2 { return Some(attachment[1].clone()); } } } None } /// 改进的版本比较函数,支持预发布版本(beta.5, beta.6等) fn compare_version(new: &str, current: &str) -> i32 { println!("[版本比较] 比较版本: '{}' vs '{}'", new, current); // 解析版本号:支持格式如 "0.0.6-beta.5", "beta.6", "0.0.6" 等 let (new_base, new_pre) = parse_version(new); let (current_base, current_pre) = parse_version(current); println!("[版本比较] 新版本 - 基础部分: {:?}, 预发布部分: {:?}", new_base, new_pre); println!("[版本比较] 当前版本 - 基础部分: {:?}, 预发布部分: {:?}", current_base, current_pre); // 先比较基础版本号(数字部分) let base_comparison = compare_version_parts(&new_base, ¤t_base); if base_comparison != 0 { println!("[版本比较] 基础版本不同,返回: {}", base_comparison); return base_comparison; } // 如果基础版本相同(或都为空),比较预发布标识符 // 如果基础版本都为空,说明是纯预发布版本(如 beta.5 vs beta.6) let pre_comparison = compare_prerelease(&new_pre, ¤t_pre); println!("[版本比较] 预发布版本比较结果: {}", pre_comparison); // 如果基础版本都为空且预发布比较结果为0,说明版本完全相同 if new_base.is_empty() && current_base.is_empty() && pre_comparison == 0 { return 0; } pre_comparison } /// 解析版本号,返回(基础版本号数组,预发布标识符) fn parse_version(version: &str) -> (Vec, Option) { // 去掉 'v' 前缀 let version = version.trim_start_matches('v').trim(); // 检查是否有预发布标识符(如 -beta.5, -alpha.1 等) let (base_str, pre_str) = if let Some(dash_pos) = version.find('-') { let (base, pre) = version.split_at(dash_pos); (base, Some(pre[1..].to_string())) // 跳过 '-' 字符 } else { (version, None) }; // 解析基础版本号(数字部分) let base_parts: Vec = base_str .split('.') .filter_map(|s| s.parse().ok()) .collect(); // 如果基础版本号为空且没有预发布标识符,可能是纯预发布版本(如 "beta.5") // 这种情况下,整个字符串作为预发布标识符 if base_parts.is_empty() && pre_str.is_none() { // 检查是否包含非数字字符(可能是预发布版本) if !version.chars().any(|c| c.is_ascii_digit()) { return (vec![], Some(version.to_string())); } } (base_parts, pre_str) } /// 比较版本号数组(数字部分) fn compare_version_parts(new: &[u32], current: &[u32]) -> i32 { let max_len = new.len().max(current.len()); for i in 0..max_len { let new_val = new.get(i).copied().unwrap_or(0); let current_val = current.get(i).copied().unwrap_or(0); if new_val > current_val { return 1; } else if new_val < current_val { return -1; } } 0 } /// 比较预发布标识符 /// 规则: /// - 有预发布标识符的版本 < 没有预发布标识符的版本 /// - 如果都有预发布标识符,按字典序比较 fn compare_prerelease(new: &Option, current: &Option) -> i32 { match (new, current) { // 都没有预发布标识符,版本相同 (None, None) => 0, // 新版本有预发布,当前版本没有 -> 新版本更旧(预发布版本 < 正式版本) (Some(_), None) => -1, // 新版本没有预发布,当前版本有 -> 新版本更新 (None, Some(_)) => 1, // 都有预发布标识符,按字典序比较 (Some(new_pre), Some(current_pre)) => { // 尝试提取数字部分进行比较(如 beta.5 -> 5, beta.6 -> 6) let new_num = extract_number_from_prerelease(new_pre); let current_num = extract_number_from_prerelease(current_pre); if let (Some(new_n), Some(current_n)) = (new_num, current_num) { // 如果都能提取数字,比较数字 if new_n > current_n { 1 } else if new_n < current_n { -1 } else { // 数字相同,按字符串比较 new_pre.cmp(current_pre) as i32 } } else { // 无法提取数字,按字符串比较 new_pre.cmp(current_pre) as i32 } } } } /// 从预发布标识符中提取数字(如 "beta.5" -> 5, "alpha.1" -> 1) fn extract_number_from_prerelease(pre: &str) -> Option { // 尝试从最后一部分提取数字 if let Some(last_part) = pre.split('.').last() { last_part.parse().ok() } else { None } } /// 下载更新文件(带进度追踪和取消支持) pub async fn download_update( app: &tauri::AppHandle, download_url: &str, cancelled: Arc, ) -> Result { println!("[下载更新] 开始下载,下载链接: {}", download_url); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(300)) .build()?; let response = client.get(download_url).send().await?; if !response.status().is_success() { return Err(anyhow::anyhow!("下载失败,HTTP 状态码: {}", response.status())); } // 获取文件总大小 let total_size = response.content_length().unwrap_or(0); // 获取缓存目录 let cache_dir = app .path() .resolve("updates", BaseDirectory::AppCache) .context("无法获取缓存目录")?; fs::create_dir_all(&cache_dir)?; // 从 URL 中提取文件名 let filename = download_url .split('/') .last() .unwrap_or("update") .split('?') .next() .unwrap_or("update"); let file_path = cache_dir.join(filename); println!("[下载更新] 文件名: {}", filename); println!("[下载更新] 文件大小: {} bytes ({:.2} MB)", total_size, total_size as f64 / 1024.0 / 1024.0); println!("[下载更新] 保存路径: {}", file_path.display()); // 下载文件 let mut file = fs::File::create(&file_path)?; let mut stream = response.bytes_stream(); let mut downloaded: u64 = 0; use futures_util::StreamExt; use std::io::Write; while let Some(item) = stream.next().await { // 检查是否取消 if cancelled.load(Ordering::Relaxed) { // 删除部分下载的文件 let _ = fs::remove_file(&file_path); return Err(anyhow::anyhow!("下载已取消")); } let chunk = item?; file.write_all(&chunk)?; downloaded += chunk.len() as u64; // 发送进度事件 if total_size > 0 { let progress = (downloaded * 100) / total_size; let _ = app.emit("update-download-progress", progress); } else { // 如果无法获取总大小,发送已下载的字节数 let _ = app.emit("update-download-progress", downloaded); } } file.sync_all()?; // 发送完成事件 let _ = app.emit("update-download-progress", 100u64); println!("[下载更新] ✓ 下载完成,文件已保存到: {}", file_path.display()); Ok(file_path) } /// 安装更新(Windows) #[cfg(target_os = "windows")] pub fn install_update(installer_path: &str) -> Result<()> { let mut cmd = Command::new(installer_path); cmd.args(&["/S"]); // 静默安装 cmd.creation_flags(CREATE_NO_WINDOW); cmd.spawn()?; Ok(()) } /// 安装更新(macOS) #[cfg(target_os = "macos")] pub fn install_update(installer_path: &str) -> Result<()> { Command::new("open").arg(installer_path).spawn()?; Ok(()) } /// 安装更新(Linux) #[cfg(target_os = "linux")] pub fn install_update(installer_path: &str) -> Result<()> { if installer_path.ends_with(".deb") { Command::new("sudo") .args(&["dpkg", "-i", installer_path]) .spawn()?; } else if installer_path.ends_with(".AppImage") { Command::new("chmod") .args(&["+x", installer_path]) .spawn()?; } Ok(()) }