483 lines
17 KiB
Rust
483 lines
17 KiB
Rust
use anyhow::{Context, Result};
|
||
use serde::{Deserialize, Serialize};
|
||
use std::fs;
|
||
use std::path::PathBuf;
|
||
use std::process::Command;
|
||
use std::sync::atomic::{AtomicBool, Ordering};
|
||
use std::sync::Arc;
|
||
use tauri::{Emitter, Manager};
|
||
use tauri::path::BaseDirectory;
|
||
|
||
#[cfg(windows)]
|
||
use std::os::windows::process::CommandExt;
|
||
|
||
#[cfg(windows)]
|
||
const CREATE_NO_WINDOW: u32 = 0x08000000;
|
||
|
||
/// 更新信息结构
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct UpdateInfo {
|
||
pub version: String,
|
||
pub notes: Option<String>,
|
||
pub download_url: String,
|
||
}
|
||
|
||
/// gh-info API 响应结构
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
struct GhInfoApiResponse {
|
||
repo: String,
|
||
latest_version: String,
|
||
changelog: Option<String>,
|
||
published_at: String,
|
||
#[serde(default)]
|
||
prerelease: bool,
|
||
attachments: serde_json::Value, // 支持两种格式: ["URL1", "URL2"] 或 [["文件名", "URL"], ...]
|
||
}
|
||
|
||
/// 自定义更新服务器 API 响应结构
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
struct CustomUpdateApiResponse {
|
||
version: String,
|
||
notes: Option<String>,
|
||
#[serde(rename = "pub_date")]
|
||
pub_date: Option<String>,
|
||
download_url: String,
|
||
signature: Option<String>,
|
||
platforms: Option<std::collections::HashMap<String, PlatformInfo>>,
|
||
}
|
||
|
||
/// 平台特定信息
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
struct PlatformInfo {
|
||
url: String,
|
||
signature: Option<String>,
|
||
}
|
||
|
||
/// 检查更新(使用自定义 API 端点)
|
||
pub async fn check_update(
|
||
endpoint: Option<&str>,
|
||
current_version: &str,
|
||
_use_mirror: bool,
|
||
github_repo: Option<&str>,
|
||
include_prerelease: bool,
|
||
) -> Result<Option<UpdateInfo>> {
|
||
println!("[更新检查] 开始检查更新...");
|
||
println!("[更新检查] 当前版本: {}", current_version);
|
||
println!("[更新检查] 包含预发布版本: {}", include_prerelease);
|
||
|
||
// 确定使用的 API 端点
|
||
let api_url = if let Some(custom_endpoint) = endpoint {
|
||
// 如果提供了自定义端点,直接使用
|
||
custom_endpoint.to_string()
|
||
} else {
|
||
// 否则使用默认的 gh-info API
|
||
let repo = github_repo.unwrap_or("plsgo/cstb");
|
||
if include_prerelease {
|
||
format!("https://gh-info.okk.cool/repos/{}/releases/latest/pre", repo)
|
||
} else {
|
||
format!("https://gh-info.okk.cool/repos/{}/releases/latest", repo)
|
||
}
|
||
};
|
||
|
||
println!("[更新检查] API URL: {}", api_url);
|
||
|
||
let client = reqwest::Client::builder()
|
||
.timeout(std::time::Duration::from_secs(10))
|
||
.user_agent("cstb-updater/1.0")
|
||
.build()?;
|
||
|
||
let response = client.get(&api_url).send().await?;
|
||
|
||
if !response.status().is_success() {
|
||
println!("[更新检查] ✗ API 请求失败,HTTP 状态码: {}", response.status());
|
||
return Err(anyhow::anyhow!("API 请求失败,HTTP 状态码: {}", response.status()));
|
||
}
|
||
|
||
// 获取响应文本以便尝试不同的解析方式
|
||
let response_text = response.text().await?;
|
||
// 关闭更新日志的打印
|
||
// println!("[更新检查] API 响应: {}", response_text);
|
||
|
||
// 尝试解析为自定义更新服务器格式
|
||
let update_info = if let Ok(custom_resp) = serde_json::from_str::<CustomUpdateApiResponse>(&response_text) {
|
||
println!("[更新检查] 检测到自定义更新服务器格式");
|
||
|
||
// 提取版本号(去掉 'v' 前缀)
|
||
let version = custom_resp.version.trim_start_matches('v').to_string();
|
||
println!("[更新检查] 远程版本: {}", version);
|
||
|
||
// 版本比较
|
||
let comparison = compare_version(&version, current_version);
|
||
println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison);
|
||
|
||
if comparison > 0 {
|
||
println!("[更新检查] ✓ 发现新版本: {}", version);
|
||
|
||
// 获取下载链接
|
||
// 优先使用平台特定的链接
|
||
let download_url = if let Some(ref platforms) = custom_resp.platforms {
|
||
// 检测当前平台
|
||
#[cfg(target_os = "windows")]
|
||
let platform_key = "windows-x86_64";
|
||
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
|
||
let platform_key = "darwin-x86_64";
|
||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||
let platform_key = "darwin-aarch64";
|
||
#[cfg(target_os = "linux")]
|
||
let platform_key = "linux-x86_64";
|
||
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
|
||
let platform_key = "";
|
||
|
||
if !platform_key.is_empty() {
|
||
platforms.get(platform_key)
|
||
.map(|p| p.url.clone())
|
||
.unwrap_or_else(|| custom_resp.download_url.clone())
|
||
} else {
|
||
custom_resp.download_url.clone()
|
||
}
|
||
} else {
|
||
custom_resp.download_url.clone()
|
||
};
|
||
|
||
println!("[更新检查] 下载链接: {}", download_url);
|
||
|
||
Some(UpdateInfo {
|
||
version,
|
||
notes: custom_resp.notes,
|
||
download_url,
|
||
})
|
||
} else {
|
||
println!("[更新检查] ✗ 已是最新版本");
|
||
None
|
||
}
|
||
} else {
|
||
// 尝试解析为 gh-info API 格式
|
||
println!("[更新检查] 尝试解析为 gh-info API 格式");
|
||
let api_resp: GhInfoApiResponse = serde_json::from_str(&response_text)
|
||
.context("解析更新 API 响应失败,既不是自定义格式也不是 gh-info 格式")?;
|
||
|
||
// 提取版本号(去掉 'v' 前缀)
|
||
let version = api_resp.latest_version.trim_start_matches('v').to_string();
|
||
println!("[更新检查] 远程版本: {}", version);
|
||
|
||
// 版本比较
|
||
let comparison = compare_version(&version, current_version);
|
||
println!("[更新检查] 版本比较结果: {} (1=有新版本, 0=相同, -1=当前更新)", comparison);
|
||
|
||
if comparison > 0 {
|
||
println!("[更新检查] ✓ 发现新版本: {}", version);
|
||
|
||
// 从 attachments 中获取下载链接
|
||
// 支持两种格式:
|
||
// 1. 字符串数组: ["URL1", "URL2", ...]
|
||
// 2. 嵌套数组: [["文件名", "URL"], ...]
|
||
let download_url = extract_download_url(&api_resp.attachments)
|
||
.ok_or_else(|| anyhow::anyhow!("未找到可下载的安装包"))?;
|
||
|
||
println!("[更新检查] 下载链接: {}", download_url);
|
||
|
||
Some(UpdateInfo {
|
||
version,
|
||
notes: api_resp.changelog,
|
||
download_url,
|
||
})
|
||
} else {
|
||
println!("[更新检查] ✗ 已是最新版本");
|
||
None
|
||
}
|
||
};
|
||
|
||
Ok(update_info)
|
||
}
|
||
|
||
/// 从 attachments 中提取下载 URL
|
||
/// 支持两种格式:
|
||
/// 1. 字符串数组: ["URL1", "URL2", ...] - 优先选择 .exe 或 .msi 文件
|
||
/// 2. 嵌套数组: [["文件名", "URL"], ...] - 优先选择 .exe 或 .msi 文件
|
||
fn extract_download_url(attachments: &serde_json::Value) -> Option<String> {
|
||
// 尝试解析为字符串数组格式: ["URL1", "URL2", ...]
|
||
if let Ok(urls) = serde_json::from_value::<Vec<String>>(attachments.clone()) {
|
||
println!("[更新检查] 检测到字符串数组格式的 attachments");
|
||
// 优先选择 .exe 或 .msi 文件
|
||
if let Some(url) = urls.iter().find(|url| {
|
||
url.ends_with(".exe") || url.ends_with(".msi")
|
||
}) {
|
||
return Some(url.clone());
|
||
}
|
||
// 如果没有找到 .exe 或 .msi,使用第一个 URL
|
||
return urls.first().cloned();
|
||
}
|
||
|
||
// 尝试解析为嵌套数组格式: [["文件名", "URL"], ...]
|
||
if let Ok(nested) = serde_json::from_value::<Vec<Vec<String>>>(attachments.clone()) {
|
||
println!("[更新检查] 检测到嵌套数组格式的 attachments");
|
||
// 优先选择 .exe 或 .msi 文件
|
||
if let Some(url) = nested.iter().find_map(|attachment| {
|
||
if attachment.len() >= 2 {
|
||
let filename = &attachment[0];
|
||
let url = &attachment[1];
|
||
if filename.ends_with(".exe") || filename.ends_with(".msi") {
|
||
Some(url.clone())
|
||
} else {
|
||
None
|
||
}
|
||
} else {
|
||
None
|
||
}
|
||
}) {
|
||
return Some(url);
|
||
}
|
||
// 如果没有找到 .exe 或 .msi,使用第一个附件的 URL
|
||
if let Some(attachment) = nested.first() {
|
||
if attachment.len() >= 2 {
|
||
return Some(attachment[1].clone());
|
||
}
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
/// 改进的版本比较函数,支持预发布版本(beta.5, beta.6等)
|
||
fn compare_version(new: &str, current: &str) -> i32 {
|
||
println!("[版本比较] 比较版本: '{}' vs '{}'", new, current);
|
||
|
||
// 解析版本号:支持格式如 "0.0.6-beta.5", "beta.6", "0.0.6" 等
|
||
let (new_base, new_pre) = parse_version(new);
|
||
let (current_base, current_pre) = parse_version(current);
|
||
|
||
println!("[版本比较] 新版本 - 基础部分: {:?}, 预发布部分: {:?}", new_base, new_pre);
|
||
println!("[版本比较] 当前版本 - 基础部分: {:?}, 预发布部分: {:?}", current_base, current_pre);
|
||
|
||
// 先比较基础版本号(数字部分)
|
||
let base_comparison = compare_version_parts(&new_base, ¤t_base);
|
||
|
||
if base_comparison != 0 {
|
||
println!("[版本比较] 基础版本不同,返回: {}", base_comparison);
|
||
return base_comparison;
|
||
}
|
||
|
||
// 如果基础版本相同(或都为空),比较预发布标识符
|
||
// 如果基础版本都为空,说明是纯预发布版本(如 beta.5 vs beta.6)
|
||
let pre_comparison = compare_prerelease(&new_pre, ¤t_pre);
|
||
println!("[版本比较] 预发布版本比较结果: {}", pre_comparison);
|
||
|
||
// 如果基础版本都为空且预发布比较结果为0,说明版本完全相同
|
||
if new_base.is_empty() && current_base.is_empty() && pre_comparison == 0 {
|
||
return 0;
|
||
}
|
||
|
||
pre_comparison
|
||
}
|
||
|
||
/// 解析版本号,返回(基础版本号数组,预发布标识符)
|
||
fn parse_version(version: &str) -> (Vec<u32>, Option<String>) {
|
||
// 去掉 'v' 前缀
|
||
let version = version.trim_start_matches('v').trim();
|
||
|
||
// 检查是否有预发布标识符(如 -beta.5, -alpha.1 等)
|
||
let (base_str, pre_str) = if let Some(dash_pos) = version.find('-') {
|
||
let (base, pre) = version.split_at(dash_pos);
|
||
(base, Some(pre[1..].to_string())) // 跳过 '-' 字符
|
||
} else {
|
||
(version, None)
|
||
};
|
||
|
||
// 解析基础版本号(数字部分)
|
||
let base_parts: Vec<u32> = base_str
|
||
.split('.')
|
||
.filter_map(|s| s.parse().ok())
|
||
.collect();
|
||
|
||
// 如果基础版本号为空且没有预发布标识符,可能是纯预发布版本(如 "beta.5")
|
||
// 这种情况下,整个字符串作为预发布标识符
|
||
if base_parts.is_empty() && pre_str.is_none() {
|
||
// 检查是否包含非数字字符(可能是预发布版本)
|
||
if !version.chars().any(|c| c.is_ascii_digit()) {
|
||
return (vec![], Some(version.to_string()));
|
||
}
|
||
}
|
||
|
||
(base_parts, pre_str)
|
||
}
|
||
|
||
/// 比较版本号数组(数字部分)
|
||
fn compare_version_parts(new: &[u32], current: &[u32]) -> i32 {
|
||
let max_len = new.len().max(current.len());
|
||
|
||
for i in 0..max_len {
|
||
let new_val = new.get(i).copied().unwrap_or(0);
|
||
let current_val = current.get(i).copied().unwrap_or(0);
|
||
|
||
if new_val > current_val {
|
||
return 1;
|
||
} else if new_val < current_val {
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
0
|
||
}
|
||
|
||
/// 比较预发布标识符
|
||
/// 规则:
|
||
/// - 有预发布标识符的版本 < 没有预发布标识符的版本
|
||
/// - 如果都有预发布标识符,按字典序比较
|
||
fn compare_prerelease(new: &Option<String>, current: &Option<String>) -> i32 {
|
||
match (new, current) {
|
||
// 都没有预发布标识符,版本相同
|
||
(None, None) => 0,
|
||
// 新版本有预发布,当前版本没有 -> 新版本更旧(预发布版本 < 正式版本)
|
||
(Some(_), None) => -1,
|
||
// 新版本没有预发布,当前版本有 -> 新版本更新
|
||
(None, Some(_)) => 1,
|
||
// 都有预发布标识符,按字典序比较
|
||
(Some(new_pre), Some(current_pre)) => {
|
||
// 尝试提取数字部分进行比较(如 beta.5 -> 5, beta.6 -> 6)
|
||
let new_num = extract_number_from_prerelease(new_pre);
|
||
let current_num = extract_number_from_prerelease(current_pre);
|
||
|
||
if let (Some(new_n), Some(current_n)) = (new_num, current_num) {
|
||
// 如果都能提取数字,比较数字
|
||
if new_n > current_n {
|
||
1
|
||
} else if new_n < current_n {
|
||
-1
|
||
} else {
|
||
// 数字相同,按字符串比较
|
||
new_pre.cmp(current_pre) as i32
|
||
}
|
||
} else {
|
||
// 无法提取数字,按字符串比较
|
||
new_pre.cmp(current_pre) as i32
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 从预发布标识符中提取数字(如 "beta.5" -> 5, "alpha.1" -> 1)
|
||
fn extract_number_from_prerelease(pre: &str) -> Option<u32> {
|
||
// 尝试从最后一部分提取数字
|
||
if let Some(last_part) = pre.split('.').last() {
|
||
last_part.parse().ok()
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
/// 下载更新文件(带进度追踪和取消支持)
|
||
pub async fn download_update(
|
||
app: &tauri::AppHandle,
|
||
download_url: &str,
|
||
cancelled: Arc<AtomicBool>,
|
||
) -> Result<PathBuf> {
|
||
println!("[下载更新] 开始下载,下载链接: {}", download_url);
|
||
|
||
let client = reqwest::Client::builder()
|
||
.timeout(std::time::Duration::from_secs(300))
|
||
.build()?;
|
||
|
||
let response = client.get(download_url).send().await?;
|
||
|
||
if !response.status().is_success() {
|
||
return Err(anyhow::anyhow!("下载失败,HTTP 状态码: {}", response.status()));
|
||
}
|
||
|
||
// 获取文件总大小
|
||
let total_size = response.content_length().unwrap_or(0);
|
||
|
||
// 获取缓存目录
|
||
let cache_dir = app
|
||
.path()
|
||
.resolve("updates", BaseDirectory::AppCache)
|
||
.context("无法获取缓存目录")?;
|
||
|
||
fs::create_dir_all(&cache_dir)?;
|
||
|
||
// 从 URL 中提取文件名
|
||
let filename = download_url
|
||
.split('/')
|
||
.last()
|
||
.unwrap_or("update")
|
||
.split('?')
|
||
.next()
|
||
.unwrap_or("update");
|
||
|
||
let file_path = cache_dir.join(filename);
|
||
|
||
println!("[下载更新] 文件名: {}", filename);
|
||
println!("[下载更新] 文件大小: {} bytes ({:.2} MB)", total_size, total_size as f64 / 1024.0 / 1024.0);
|
||
println!("[下载更新] 保存路径: {}", file_path.display());
|
||
|
||
// 下载文件
|
||
let mut file = fs::File::create(&file_path)?;
|
||
let mut stream = response.bytes_stream();
|
||
let mut downloaded: u64 = 0;
|
||
|
||
use futures_util::StreamExt;
|
||
use std::io::Write;
|
||
|
||
while let Some(item) = stream.next().await {
|
||
// 检查是否取消
|
||
if cancelled.load(Ordering::Relaxed) {
|
||
// 删除部分下载的文件
|
||
let _ = fs::remove_file(&file_path);
|
||
return Err(anyhow::anyhow!("下载已取消"));
|
||
}
|
||
|
||
let chunk = item?;
|
||
file.write_all(&chunk)?;
|
||
downloaded += chunk.len() as u64;
|
||
|
||
// 发送进度事件
|
||
if total_size > 0 {
|
||
let progress = (downloaded * 100) / total_size;
|
||
let _ = app.emit("update-download-progress", progress);
|
||
} else {
|
||
// 如果无法获取总大小,发送已下载的字节数
|
||
let _ = app.emit("update-download-progress", downloaded);
|
||
}
|
||
}
|
||
|
||
file.sync_all()?;
|
||
|
||
// 发送完成事件
|
||
let _ = app.emit("update-download-progress", 100u64);
|
||
|
||
println!("[下载更新] ✓ 下载完成,文件已保存到: {}", file_path.display());
|
||
|
||
Ok(file_path)
|
||
}
|
||
|
||
/// 安装更新(Windows)
|
||
#[cfg(target_os = "windows")]
|
||
pub fn install_update(installer_path: &str) -> Result<()> {
|
||
let mut cmd = Command::new(installer_path);
|
||
cmd.args(&["/S"]); // 静默安装
|
||
cmd.creation_flags(CREATE_NO_WINDOW);
|
||
cmd.spawn()?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 安装更新(macOS)
|
||
#[cfg(target_os = "macos")]
|
||
pub fn install_update(installer_path: &str) -> Result<()> {
|
||
Command::new("open").arg(installer_path).spawn()?;
|
||
Ok(())
|
||
}
|
||
|
||
/// 安装更新(Linux)
|
||
#[cfg(target_os = "linux")]
|
||
pub fn install_update(installer_path: &str) -> Result<()> {
|
||
if installer_path.ends_with(".deb") {
|
||
Command::new("sudo")
|
||
.args(&["dpkg", "-i", installer_path])
|
||
.spawn()?;
|
||
} else if installer_path.ends_with(".AppImage") {
|
||
Command::new("chmod")
|
||
.args(&["+x", installer_path])
|
||
.spawn()?;
|
||
}
|
||
Ok(())
|
||
}
|