Files
cstb-next/src-tauri/src/tool/updater.rs

483 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tauri::{Emitter, Manager};
use tauri::path::BaseDirectory;
#[cfg(windows)]
use std::os::windows::process::CommandExt;
#[cfg(windows)]
const CREATE_NO_WINDOW: u32 = 0x08000000;
/// 更新信息结构
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateInfo {
pub version: String,
pub notes: Option<String>,
pub download_url: String,
}
/// gh-info API 响应结构
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GhInfoApiResponse {
repo: String,
latest_version: String,
changelog: Option<String>,
published_at: String,
#[serde(default)]
prerelease: bool,
attachments: serde_json::Value, // 支持两种格式: ["URL1", "URL2"] 或 [["文件名", "URL"], ...]
}
/// 自定义更新服务器 API 响应结构
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CustomUpdateApiResponse {
version: String,
notes: Option<String>,
#[serde(rename = "pub_date")]
pub_date: Option<String>,
download_url: String,
signature: Option<String>,
platforms: Option<std::collections::HashMap<String, PlatformInfo>>,
}
/// 平台特定信息
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PlatformInfo {
url: String,
signature: Option<String>,
}
/// 检查更新(使用自定义 API 端点)
pub async fn check_update(
endpoint: Option<&str>,
current_version: &str,
_use_mirror: bool,
github_repo: Option<&str>,
include_prerelease: bool,
) -> Result<Option<UpdateInfo>> {
println!("[更新检查] 开始检查更新...");
println!("[更新检查] 当前版本: {}", current_version);
println!("[更新检查] 包含预发布版本: {}", include_prerelease);
// 确定使用的 API 端点
let api_url = if let Some(custom_endpoint) = endpoint {
// 如果提供了自定义端点,直接使用
custom_endpoint.to_string()
} else {
// 否则使用默认的 gh-info API
let repo = github_repo.unwrap_or("plsgo/cstb");
if include_prerelease {
format!("https://gh-info.okk.cool/repos/{}/releases/latest/pre", repo)
} else {
format!("https://gh-info.okk.cool/repos/{}/releases/latest", repo)
}
};
println!("[更新检查] API URL: {}", api_url);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.user_agent("cstb-updater/1.0")
.build()?;
let response = client.get(&api_url).send().await?;
if !response.status().is_success() {
println!("[更新检查] ✗ API 请求失败HTTP 状态码: {}", response.status());
return Err(anyhow::anyhow!("API 请求失败HTTP 状态码: {}", response.status()));
}
// 获取响应文本以便尝试不同的解析方式
let response_text = response.text().await?;
// 关闭更新日志的打印
// println!("[更新检查] API 响应: {}", response_text);
// 尝试解析为自定义更新服务器格式
let update_info = if let Ok(custom_resp) = serde_json::from_str::<CustomUpdateApiResponse>(&response_text) {
println!("[更新检查] 检测到自定义更新服务器格式");
// 提取版本号(去掉 'v' 前缀)
let version = custom_resp.version.trim_start_matches('v').to_string();
println!("[更新检查] 远程版本: {}", version);
// 版本比较
let comparison = compare_version(&version, current_version);
println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison);
if comparison > 0 {
println!("[更新检查] ✓ 发现新版本: {}", version);
// 获取下载链接
// 优先使用平台特定的链接
let download_url = if let Some(ref platforms) = custom_resp.platforms {
// 检测当前平台
#[cfg(target_os = "windows")]
let platform_key = "windows-x86_64";
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
let platform_key = "darwin-x86_64";
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let platform_key = "darwin-aarch64";
#[cfg(target_os = "linux")]
let platform_key = "linux-x86_64";
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
let platform_key = "";
if !platform_key.is_empty() {
platforms.get(platform_key)
.map(|p| p.url.clone())
.unwrap_or_else(|| custom_resp.download_url.clone())
} else {
custom_resp.download_url.clone()
}
} else {
custom_resp.download_url.clone()
};
println!("[更新检查] 下载链接: {}", download_url);
Some(UpdateInfo {
version,
notes: custom_resp.notes,
download_url,
})
} else {
println!("[更新检查] ✗ 已是最新版本");
None
}
} else {
// 尝试解析为 gh-info API 格式
println!("[更新检查] 尝试解析为 gh-info API 格式");
let api_resp: GhInfoApiResponse = serde_json::from_str(&response_text)
.context("解析更新 API 响应失败,既不是自定义格式也不是 gh-info 格式")?;
// 提取版本号(去掉 'v' 前缀)
let version = api_resp.latest_version.trim_start_matches('v').to_string();
println!("[更新检查] 远程版本: {}", version);
// 版本比较
let comparison = compare_version(&version, current_version);
println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison);
if comparison > 0 {
println!("[更新检查] ✓ 发现新版本: {}", version);
// 从 attachments 中获取下载链接
// 支持两种格式:
// 1. 字符串数组: ["URL1", "URL2", ...]
// 2. 嵌套数组: [["文件名", "URL"], ...]
let download_url = extract_download_url(&api_resp.attachments)
.ok_or_else(|| anyhow::anyhow!("未找到可下载的安装包"))?;
println!("[更新检查] 下载链接: {}", download_url);
Some(UpdateInfo {
version,
notes: api_resp.changelog,
download_url,
})
} else {
println!("[更新检查] ✗ 已是最新版本");
None
}
};
Ok(update_info)
}
/// 从 attachments 中提取下载 URL
/// 支持两种格式:
/// 1. 字符串数组: ["URL1", "URL2", ...] - 优先选择 .exe 或 .msi 文件
/// 2. 嵌套数组: [["文件名", "URL"], ...] - 优先选择 .exe 或 .msi 文件
fn extract_download_url(attachments: &serde_json::Value) -> Option<String> {
// 尝试解析为字符串数组格式: ["URL1", "URL2", ...]
if let Ok(urls) = serde_json::from_value::<Vec<String>>(attachments.clone()) {
println!("[更新检查] 检测到字符串数组格式的 attachments");
// 优先选择 .exe 或 .msi 文件
if let Some(url) = urls.iter().find(|url| {
url.ends_with(".exe") || url.ends_with(".msi")
}) {
return Some(url.clone());
}
// 如果没有找到 .exe 或 .msi使用第一个 URL
return urls.first().cloned();
}
// 尝试解析为嵌套数组格式: [["文件名", "URL"], ...]
if let Ok(nested) = serde_json::from_value::<Vec<Vec<String>>>(attachments.clone()) {
println!("[更新检查] 检测到嵌套数组格式的 attachments");
// 优先选择 .exe 或 .msi 文件
if let Some(url) = nested.iter().find_map(|attachment| {
if attachment.len() >= 2 {
let filename = &attachment[0];
let url = &attachment[1];
if filename.ends_with(".exe") || filename.ends_with(".msi") {
Some(url.clone())
} else {
None
}
} else {
None
}
}) {
return Some(url);
}
// 如果没有找到 .exe 或 .msi使用第一个附件的 URL
if let Some(attachment) = nested.first() {
if attachment.len() >= 2 {
return Some(attachment[1].clone());
}
}
}
None
}
/// 改进的版本比较函数支持预发布版本beta.5, beta.6等)
fn compare_version(new: &str, current: &str) -> i32 {
println!("[版本比较] 比较版本: '{}' vs '{}'", new, current);
// 解析版本号:支持格式如 "0.0.6-beta.5", "beta.6", "0.0.6" 等
let (new_base, new_pre) = parse_version(new);
let (current_base, current_pre) = parse_version(current);
println!("[版本比较] 新版本 - 基础部分: {:?}, 预发布部分: {:?}", new_base, new_pre);
println!("[版本比较] 当前版本 - 基础部分: {:?}, 预发布部分: {:?}", current_base, current_pre);
// 先比较基础版本号(数字部分)
let base_comparison = compare_version_parts(&new_base, &current_base);
if base_comparison != 0 {
println!("[版本比较] 基础版本不同,返回: {}", base_comparison);
return base_comparison;
}
// 如果基础版本相同(或都为空),比较预发布标识符
// 如果基础版本都为空,说明是纯预发布版本(如 beta.5 vs beta.6
let pre_comparison = compare_prerelease(&new_pre, &current_pre);
println!("[版本比较] 预发布版本比较结果: {}", pre_comparison);
// 如果基础版本都为空且预发布比较结果为0说明版本完全相同
if new_base.is_empty() && current_base.is_empty() && pre_comparison == 0 {
return 0;
}
pre_comparison
}
/// 解析版本号,返回(基础版本号数组,预发布标识符)
fn parse_version(version: &str) -> (Vec<u32>, Option<String>) {
// 去掉 'v' 前缀
let version = version.trim_start_matches('v').trim();
// 检查是否有预发布标识符(如 -beta.5, -alpha.1 等)
let (base_str, pre_str) = if let Some(dash_pos) = version.find('-') {
let (base, pre) = version.split_at(dash_pos);
(base, Some(pre[1..].to_string())) // 跳过 '-' 字符
} else {
(version, None)
};
// 解析基础版本号(数字部分)
let base_parts: Vec<u32> = base_str
.split('.')
.filter_map(|s| s.parse().ok())
.collect();
// 如果基础版本号为空且没有预发布标识符,可能是纯预发布版本(如 "beta.5"
// 这种情况下,整个字符串作为预发布标识符
if base_parts.is_empty() && pre_str.is_none() {
// 检查是否包含非数字字符(可能是预发布版本)
if !version.chars().any(|c| c.is_ascii_digit()) {
return (vec![], Some(version.to_string()));
}
}
(base_parts, pre_str)
}
/// 比较版本号数组(数字部分)
fn compare_version_parts(new: &[u32], current: &[u32]) -> i32 {
let max_len = new.len().max(current.len());
for i in 0..max_len {
let new_val = new.get(i).copied().unwrap_or(0);
let current_val = current.get(i).copied().unwrap_or(0);
if new_val > current_val {
return 1;
} else if new_val < current_val {
return -1;
}
}
0
}
/// 比较预发布标识符
/// 规则:
/// - 有预发布标识符的版本 < 没有预发布标识符的版本
/// - 如果都有预发布标识符,按字典序比较
fn compare_prerelease(new: &Option<String>, current: &Option<String>) -> i32 {
match (new, current) {
// 都没有预发布标识符,版本相同
(None, None) => 0,
// 新版本有预发布,当前版本没有 -> 新版本更旧(预发布版本 < 正式版本)
(Some(_), None) => -1,
// 新版本没有预发布,当前版本有 -> 新版本更新
(None, Some(_)) => 1,
// 都有预发布标识符,按字典序比较
(Some(new_pre), Some(current_pre)) => {
// 尝试提取数字部分进行比较(如 beta.5 -> 5, beta.6 -> 6
let new_num = extract_number_from_prerelease(new_pre);
let current_num = extract_number_from_prerelease(current_pre);
if let (Some(new_n), Some(current_n)) = (new_num, current_num) {
// 如果都能提取数字,比较数字
if new_n > current_n {
1
} else if new_n < current_n {
-1
} else {
// 数字相同,按字符串比较
new_pre.cmp(current_pre) as i32
}
} else {
// 无法提取数字,按字符串比较
new_pre.cmp(current_pre) as i32
}
}
}
}
/// 从预发布标识符中提取数字(如 "beta.5" -> 5, "alpha.1" -> 1
fn extract_number_from_prerelease(pre: &str) -> Option<u32> {
// 尝试从最后一部分提取数字
if let Some(last_part) = pre.split('.').last() {
last_part.parse().ok()
} else {
None
}
}
/// 下载更新文件(带进度追踪和取消支持)
pub async fn download_update(
app: &tauri::AppHandle,
download_url: &str,
cancelled: Arc<AtomicBool>,
) -> Result<PathBuf> {
println!("[下载更新] 开始下载,下载链接: {}", download_url);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()?;
let response = client.get(download_url).send().await?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("下载失败HTTP 状态码: {}", response.status()));
}
// 获取文件总大小
let total_size = response.content_length().unwrap_or(0);
// 获取缓存目录
let cache_dir = app
.path()
.resolve("updates", BaseDirectory::AppCache)
.context("无法获取缓存目录")?;
fs::create_dir_all(&cache_dir)?;
// 从 URL 中提取文件名
let filename = download_url
.split('/')
.last()
.unwrap_or("update")
.split('?')
.next()
.unwrap_or("update");
let file_path = cache_dir.join(filename);
println!("[下载更新] 文件名: {}", filename);
println!("[下载更新] 文件大小: {} bytes ({:.2} MB)", total_size, total_size as f64 / 1024.0 / 1024.0);
println!("[下载更新] 保存路径: {}", file_path.display());
// 下载文件
let mut file = fs::File::create(&file_path)?;
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
use futures_util::StreamExt;
use std::io::Write;
while let Some(item) = stream.next().await {
// 检查是否取消
if cancelled.load(Ordering::Relaxed) {
// 删除部分下载的文件
let _ = fs::remove_file(&file_path);
return Err(anyhow::anyhow!("下载已取消"));
}
let chunk = item?;
file.write_all(&chunk)?;
downloaded += chunk.len() as u64;
// 发送进度事件
if total_size > 0 {
let progress = (downloaded * 100) / total_size;
let _ = app.emit("update-download-progress", progress);
} else {
// 如果无法获取总大小,发送已下载的字节数
let _ = app.emit("update-download-progress", downloaded);
}
}
file.sync_all()?;
// 发送完成事件
let _ = app.emit("update-download-progress", 100u64);
println!("[下载更新] ✓ 下载完成,文件已保存到: {}", file_path.display());
Ok(file_path)
}
/// 安装更新Windows
#[cfg(target_os = "windows")]
pub fn install_update(installer_path: &str) -> Result<()> {
let mut cmd = Command::new(installer_path);
cmd.args(&["/S"]); // 静默安装
cmd.creation_flags(CREATE_NO_WINDOW);
cmd.spawn()?;
Ok(())
}
/// 安装更新macOS
#[cfg(target_os = "macos")]
pub fn install_update(installer_path: &str) -> Result<()> {
Command::new("open").arg(installer_path).spawn()?;
Ok(())
}
/// 安装更新Linux
#[cfg(target_os = "linux")]
pub fn install_update(installer_path: &str) -> Result<()> {
if installer_path.ends_with(".deb") {
Command::new("sudo")
.args(&["dpkg", "-i", installer_path])
.spawn()?;
} else if installer_path.ends_with(".AppImage") {
Command::new("chmod")
.args(&["+x", installer_path])
.spawn()?;
}
Ok(())
}