Compare commits

...

1 Commit

Author SHA1 Message Date
saladday
70504714f0 Keep WebDAV sync from propagating local proxy state
WebDAV sync now separates the portable provider strategy from device-local
proxy state, using the smallest upstream-friendly change that stays coherent.
The backup layer scrubs and restores only the local proxy fields in
proxy_config, while backend restore now rejects when takeover artifacts
or running proxy state are still active.

The command layer was kept thin by routing proxy-setting writes back
through ProxyService, so the same-process restore/mutation boundary has
one owner instead of scattered command-side patches.

Constraint: Must stay upstream-friendly for a large open source codebase without introducing repo-specific multi-process machinery
Constraint: WebDAV restore must not clobber device-local proxy bindings or takeover state
Rejected: Exclude proxy_config from sync entirely | would also stop syncing portable proxy strategy fields
Rejected: Port local cross-process lock and managed-child bootstrap bypass | too local-repo-specific for upstream
Confidence: high
Scope-risk: moderate
Directive: Future writes to device-local proxy fields should continue to flow through ProxyService so the restore boundary remains coherent
Tested: cargo fmt --manifest-path src-tauri/Cargo.toml --check
Tested: cargo check --manifest-path src-tauri/Cargo.toml
Tested: cargo test --manifest-path src-tauri/Cargo.toml sync_import_preserves_local_proxy_config_local_fields -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml sync_export_scrubs_proxy_config_local_fields_but_keeps_strategy_fields -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml has_restore_blocking_proxy_state_is_true_when_live_backup_exists_without_enabled_flag -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml has_restore_blocking_proxy_state_is_true_when_live_config_residue_exists_without_enabled_flag -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml ensure_restore_allowed_rejects_takeover_artifacts_even_when_enabled_flag_is_false -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml proxy_config_update_waits_for_restore_mutation_guard -- --nocapture
Tested: cargo test --manifest-path src-tauri/Cargo.toml start_with_takeover_waits_for_restore_mutation_guard -- --nocapture
Not-tested: Full upstream cargo test suite
Related: SaladDay/cc-switch-cli#111
2026-04-17 15:58:04 +08:00
5 changed files with 764 additions and 43 deletions

View File

@@ -86,10 +86,7 @@ pub async fn update_global_proxy_config(
state: tauri::State<'_, AppState>,
config: GlobalProxyConfig,
) -> Result<(), String> {
let db = &state.db;
db.update_global_proxy_config(config)
.await
.map_err(|e| e.to_string())
state.proxy_service.update_global_proxy_config(config).await
}
/// 获取指定应用的代理配置
@@ -114,10 +111,10 @@ pub async fn update_proxy_config_for_app(
state: tauri::State<'_, AppState>,
config: AppProxyConfig,
) -> Result<(), String> {
let db = &state.db;
db.update_proxy_config_for_app(config)
state
.proxy_service
.update_proxy_config_for_app(config)
.await
.map_err(|e| e.to_string())
}
async fn get_default_cost_multiplier_internal(
@@ -152,8 +149,11 @@ async fn set_default_cost_multiplier_internal(
app_type: &str,
value: &str,
) -> Result<(), AppError> {
let db = &state.db;
db.set_default_cost_multiplier(app_type, value).await
state
.proxy_service
.set_default_cost_multiplier(app_type, value)
.await
.map_err(AppError::Config)
}
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
@@ -209,8 +209,11 @@ async fn set_pricing_model_source_internal(
app_type: &str,
value: &str,
) -> Result<(), AppError> {
let db = &state.db;
db.set_pricing_model_source(app_type, value).await
state
.proxy_service
.set_pricing_model_source(app_type, value)
.await
.map_err(AppError::Config)
}
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]

View File

@@ -115,11 +115,17 @@ pub async fn webdav_sync_upload(state: State<'_, AppState>) -> Result<Value, Str
#[tauri::command]
pub async fn webdav_sync_download(state: State<'_, AppState>) -> Result<Value, String> {
let db = state.db.clone();
let proxy_service = state.proxy_service.clone();
let db_for_sync = db.clone();
let mut settings = require_enabled_webdav_settings()?;
let _auto_sync_suppression = crate::services::webdav_auto_sync::AutoSyncSuppressionGuard::new();
let sync_result = run_with_webdav_lock(webdav_sync_service::download(&db, &mut settings)).await;
let sync_result = run_with_webdav_lock(webdav_sync_service::download(
&db,
&proxy_service,
&mut settings,
))
.await;
let mut result = map_sync_result(sync_result, |error| {
persist_sync_error(&mut settings, error, "manual")
})?;

View File

@@ -7,7 +7,7 @@ use crate::config::get_app_config_dir;
use crate::error::AppError;
use chrono::{Local, Utc};
use rusqlite::backup::Backup;
use rusqlite::types::ValueRef;
use rusqlite::types::Value;
use rusqlite::Connection;
use std::fs;
use std::path::{Path, PathBuf};
@@ -33,6 +33,64 @@ const SYNC_PRESERVE_TABLES: &[&str] = &[
"usage_daily_rollups",
];
const PROXY_CONFIG_LOCAL_COLUMNS: &[&str] =
&["proxy_enabled", "listen_address", "listen_port", "enabled"];
#[derive(Clone, Copy)]
enum SyncNeutralValue {
Integer(i64),
Text(&'static str),
}
impl SyncNeutralValue {
fn into_sql_value(self) -> Value {
match self {
Self::Integer(value) => Value::Integer(value),
Self::Text(value) => Value::Text(value.to_string()),
}
}
}
#[derive(Clone, Copy)]
struct SyncNeutralizedColumn {
column: &'static str,
value: SyncNeutralValue,
}
#[derive(Clone, Copy)]
struct SyncRowTransform {
table: &'static str,
key_column: &'static str,
local_columns: &'static [&'static str],
export_defaults: &'static [SyncNeutralizedColumn],
}
const PROXY_CONFIG_EXPORT_DEFAULTS: &[SyncNeutralizedColumn] = &[
SyncNeutralizedColumn {
column: "proxy_enabled",
value: SyncNeutralValue::Integer(0),
},
SyncNeutralizedColumn {
column: "listen_address",
value: SyncNeutralValue::Text("127.0.0.1"),
},
SyncNeutralizedColumn {
column: "listen_port",
value: SyncNeutralValue::Integer(15721),
},
SyncNeutralizedColumn {
column: "enabled",
value: SyncNeutralValue::Integer(0),
},
];
const SYNC_ROW_TRANSFORMS: &[SyncRowTransform] = &[SyncRowTransform {
table: "proxy_config",
key_column: "app_type",
local_columns: PROXY_CONFIG_LOCAL_COLUMNS,
export_defaults: PROXY_CONFIG_EXPORT_DEFAULTS,
}];
/// A database backup entry for the UI
#[derive(Debug, serde::Serialize)]
#[serde(rename_all = "camelCase")]
@@ -46,13 +104,13 @@ impl Database {
/// 导出为 SQLite 兼容的 SQL 文本(内存字符串,完整导出)
pub fn export_sql_string(&self) -> Result<String, AppError> {
let snapshot = self.snapshot_to_memory()?;
Self::dump_sql(&snapshot, &[])
Self::dump_sql(&snapshot, &[], &[])
}
/// Export SQL for sync (WebDAV), skipping local-only tables' data
pub fn export_sql_string_for_sync(&self) -> Result<String, AppError> {
let snapshot = self.snapshot_to_memory()?;
Self::dump_sql(&snapshot, SYNC_SKIP_TABLES)
Self::dump_sql(&snapshot, SYNC_SKIP_TABLES, SYNC_ROW_TRANSFORMS)
}
/// 导出为 SQLite 兼容的 SQL 文本
@@ -82,19 +140,20 @@ impl Database {
/// 从 SQL 字符串导入,返回生成的备份 ID若无备份则为空字符串
pub fn import_sql_string(&self, sql_raw: &str) -> Result<String, AppError> {
self.import_sql_string_inner(sql_raw, &[])
self.import_sql_string_inner(sql_raw, &[], &[])
}
/// Import SQL generated for sync, then restore local-only tables from the
/// current device snapshot before replacing the main database.
pub(crate) fn import_sql_string_for_sync(&self, sql_raw: &str) -> Result<String, AppError> {
self.import_sql_string_inner(sql_raw, SYNC_PRESERVE_TABLES)
self.import_sql_string_inner(sql_raw, SYNC_PRESERVE_TABLES, SYNC_ROW_TRANSFORMS)
}
fn import_sql_string_inner(
&self,
sql_raw: &str,
preserve_tables: &[&str],
row_transforms: &[SyncRowTransform],
) -> Result<String, AppError> {
let sql_content = sql_raw.trim_start_matches('\u{feff}');
Self::validate_cc_switch_sql_export(sql_content)?;
@@ -102,7 +161,7 @@ impl Database {
// 导入前备份现有数据库
let backup_path = self.backup_database_file()?;
let local_snapshot = if preserve_tables.is_empty() {
let local_snapshot = if preserve_tables.is_empty() && row_transforms.is_empty() {
None
} else {
Some(self.snapshot_to_memory()?)
@@ -127,6 +186,7 @@ impl Database {
Self::validate_basic_state(&temp_conn)?;
if let Some(local_snapshot) = local_snapshot.as_ref() {
Self::restore_tables(local_snapshot, &temp_conn, preserve_tables)?;
Self::restore_row_transforms(local_snapshot, &temp_conn, row_transforms)?;
}
// 使用 Backup 将临时库原子写回主库
@@ -232,6 +292,111 @@ impl Database {
Ok(())
}
fn restore_row_transforms(
source_conn: &Connection,
target_conn: &Connection,
transforms: &[SyncRowTransform],
) -> Result<(), AppError> {
for transform in transforms {
Self::restore_row_transform(source_conn, target_conn, transform)?;
}
Ok(())
}
fn restore_row_transform(
source_conn: &Connection,
target_conn: &Connection,
transform: &SyncRowTransform,
) -> Result<(), AppError> {
if !Self::table_exists(source_conn, transform.table)?
|| !Self::table_exists(target_conn, transform.table)?
{
return Ok(());
}
let source_columns = Self::get_table_columns(source_conn, transform.table)?;
let target_columns = Self::get_table_columns(target_conn, transform.table)?;
if !source_columns
.iter()
.any(|column| column == transform.key_column)
|| !target_columns
.iter()
.any(|column| column == transform.key_column)
{
return Ok(());
}
let local_columns = transform
.local_columns
.iter()
.copied()
.filter(|column| {
source_columns.iter().any(|existing| existing == column)
&& target_columns.iter().any(|existing| existing == column)
})
.collect::<Vec<_>>();
if local_columns.is_empty() {
return Ok(());
}
let select_columns = std::iter::once(transform.key_column)
.chain(local_columns.iter().copied())
.map(Self::quote_ident)
.collect::<Vec<_>>()
.join(", ");
let select_sql = format!(
"SELECT {select_columns} FROM {}",
Self::quote_ident(transform.table)
);
let assignments = local_columns
.iter()
.enumerate()
.map(|(idx, column)| format!("{} = ?{}", Self::quote_ident(column), idx + 1))
.collect::<Vec<_>>()
.join(", ");
let update_sql = format!(
"UPDATE {} SET {assignments} WHERE {} = ?{}",
Self::quote_ident(transform.table),
Self::quote_ident(transform.key_column),
local_columns.len() + 1
);
let mut stmt = source_conn.prepare(&select_sql).map_err(|e| {
AppError::Database(format!(
"读取本地表 {} 的同步字段失败: {e}",
transform.table
))
})?;
let mut rows = stmt.query([]).map_err(|e| {
AppError::Database(format!("查询本地表 {} 数据失败: {e}", transform.table))
})?;
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
let mut values = Vec::with_capacity(local_columns.len() + 1);
for idx in 1..=local_columns.len() {
values.push(
row.get::<_, Value>(idx)
.map_err(|e| AppError::Database(e.to_string()))?,
);
}
values.push(
row.get::<_, Value>(0)
.map_err(|e| AppError::Database(e.to_string()))?,
);
target_conn
.execute(&update_sql, rusqlite::params_from_iter(values.iter()))
.map_err(|e| {
AppError::Database(format!(
"恢复本地表 {} 的同步字段失败: {e}",
transform.table
))
})?;
}
Ok(())
}
/// Periodic backup: create a new backup if the latest one is older than the configured interval
pub(crate) fn periodic_backup_if_needed(&self) -> Result<(), AppError> {
let interval_hours = crate::settings::effective_backup_interval_hours();
@@ -384,7 +549,11 @@ impl Database {
}
/// 导出数据库为 SQL 文本
fn dump_sql(conn: &Connection, skip_tables: &[&str]) -> Result<String, AppError> {
fn dump_sql(
conn: &Connection,
skip_tables: &[&str],
row_transforms: &[SyncRowTransform],
) -> Result<String, AppError> {
let mut output = String::new();
let timestamp = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
let user_version: i64 = conn
@@ -450,10 +619,14 @@ impl Database {
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
let mut values = Vec::with_capacity(columns.len());
for idx in 0..columns.len() {
let value = row
.get_ref(idx)
.map_err(|e| AppError::Database(e.to_string()))?;
values.push(Self::format_sql_value(value)?);
values.push(
row.get::<_, Value>(idx)
.map_err(|e| AppError::Database(e.to_string()))?,
);
}
if let Some(transform) = row_transforms.iter().find(|t| t.table == table) {
Self::apply_export_defaults(&columns, &mut values, transform);
}
let cols = columns
@@ -463,7 +636,11 @@ impl Database {
.join(", ");
output.push_str(&format!(
"INSERT INTO \"{table}\" ({cols}) VALUES ({});\n",
values.join(", ")
values
.iter()
.map(Self::format_owned_sql_value)
.collect::<Result<Vec<_>, _>>()?
.join(", ")
));
}
}
@@ -488,19 +665,13 @@ impl Database {
Ok(columns)
}
/// 格式化 SQL 值
fn format_sql_value(value: ValueRef<'_>) -> Result<String, AppError> {
fn format_owned_sql_value(value: &Value) -> Result<String, AppError> {
match value {
ValueRef::Null => Ok("NULL".to_string()),
ValueRef::Integer(i) => Ok(i.to_string()),
ValueRef::Real(f) => Ok(f.to_string()),
ValueRef::Text(t) => {
let text = std::str::from_utf8(t)
.map_err(|e| AppError::Database(format!("文本字段不是有效的 UTF-8: {e}")))?;
let escaped = text.replace('\'', "''");
Ok(format!("'{escaped}'"))
}
ValueRef::Blob(bytes) => {
Value::Null => Ok("NULL".to_string()),
Value::Integer(i) => Ok(i.to_string()),
Value::Real(f) => Ok(f.to_string()),
Value::Text(text) => Ok(format!("'{}'", text.replace('\'', "''"))),
Value::Blob(bytes) => {
let mut s = String::from("X'");
for b in bytes {
use std::fmt::Write;
@@ -512,6 +683,22 @@ impl Database {
}
}
fn apply_export_defaults(
columns: &[String],
values: &mut [Value],
transform: &SyncRowTransform,
) {
for default in transform.export_defaults {
if let Some(idx) = columns.iter().position(|column| column == default.column) {
values[idx] = default.value.into_sql_value();
}
}
}
fn quote_ident(value: &str) -> String {
format!("\"{}\"", value.replace('"', "\"\""))
}
/// List all database backup files, sorted by creation time (newest first)
pub fn list_backups() -> Result<Vec<BackupEntry>, AppError> {
let backup_dir = get_app_config_dir().join("backups");
@@ -692,8 +879,75 @@ mod tests {
use super::Database;
use crate::error::AppError;
use crate::settings::{update_settings, AppSettings};
use rusqlite::Connection;
use serial_test::serial;
fn seed_provider(conn: &Connection, id: &str) -> Result<(), AppError> {
conn.execute(
"INSERT INTO providers (id, app_type, name, settings_config, meta)
VALUES (?1, 'claude', ?2, '{}', '{}')",
rusqlite::params![id, format!("Provider {id}")],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn set_proxy_row(
conn: &Connection,
app_type: &str,
proxy_enabled: bool,
listen_address: &str,
listen_port: i64,
enabled: bool,
auto_failover_enabled: bool,
max_retries: i64,
) -> Result<(), AppError> {
conn.execute(
"UPDATE proxy_config
SET proxy_enabled = ?2,
listen_address = ?3,
listen_port = ?4,
enabled = ?5,
auto_failover_enabled = ?6,
max_retries = ?7
WHERE app_type = ?1",
rusqlite::params![
app_type,
if proxy_enabled { 1 } else { 0 },
listen_address,
listen_port,
if enabled { 1 } else { 0 },
if auto_failover_enabled { 1 } else { 0 },
max_retries,
],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
fn read_proxy_row(
conn: &Connection,
app_type: &str,
) -> Result<(bool, String, i64, bool, bool, i64), AppError> {
conn.query_row(
"SELECT proxy_enabled, listen_address, listen_port, enabled, auto_failover_enabled, max_retries
FROM proxy_config WHERE app_type = ?1",
[app_type],
|row| {
Ok((
row.get::<_, i64>(0)? != 0,
row.get(1)?,
row.get(2)?,
row.get::<_, i64>(3)? != 0,
row.get::<_, i64>(4)? != 0,
row.get(5)?,
))
},
)
.map_err(|e| AppError::Database(e.to_string()))
}
#[test]
fn sync_import_preserves_local_only_tables() -> Result<(), AppError> {
let remote_db = Database::memory()?;
@@ -781,6 +1035,97 @@ mod tests {
Ok(())
}
#[test]
fn sync_import_preserves_local_proxy_config_local_fields() -> Result<(), AppError> {
let remote_db = Database::memory()?;
{
let conn = crate::database::lock_conn!(remote_db.conn);
seed_provider(&conn, "remote-provider")?;
set_proxy_row(
&conn,
"claude",
false,
"192.168.10.10",
31001,
false,
true,
9,
)?;
set_proxy_row(&conn, "codex", true, "192.168.10.11", 31002, true, false, 8)?;
set_proxy_row(
&conn,
"gemini",
false,
"192.168.10.12",
31003,
true,
true,
7,
)?;
}
let remote_sql = remote_db.export_sql_string()?;
let local_db = Database::memory()?;
{
let conn = crate::database::lock_conn!(local_db.conn);
seed_provider(&conn, "local-provider")?;
set_proxy_row(&conn, "claude", true, "10.0.0.1", 21001, true, false, 1)?;
set_proxy_row(&conn, "codex", false, "10.0.0.2", 21002, false, true, 2)?;
set_proxy_row(&conn, "gemini", true, "10.0.0.3", 21003, false, false, 3)?;
}
local_db.import_sql_string_for_sync(&remote_sql)?;
let conn = crate::database::lock_conn!(local_db.conn);
assert_eq!(
read_proxy_row(&conn, "claude")?,
(true, "10.0.0.1".to_string(), 21001, true, true, 9)
);
assert_eq!(
read_proxy_row(&conn, "codex")?,
(false, "10.0.0.2".to_string(), 21002, false, false, 8)
);
assert_eq!(
read_proxy_row(&conn, "gemini")?,
(true, "10.0.0.3".to_string(), 21003, false, true, 7)
);
Ok(())
}
#[test]
fn sync_export_scrubs_proxy_config_local_fields_but_keeps_strategy_fields(
) -> Result<(), AppError> {
let db = Database::memory()?;
{
let conn = crate::database::lock_conn!(db.conn);
seed_provider(&conn, "portable-provider")?;
set_proxy_row(&conn, "claude", true, "10.1.0.1", 41001, true, true, 6)?;
set_proxy_row(&conn, "codex", true, "10.1.0.2", 41002, true, false, 5)?;
set_proxy_row(&conn, "gemini", true, "10.1.0.3", 41003, true, true, 4)?;
}
let sync_sql = db.export_sql_string_for_sync()?;
let old_client_db = Database::memory()?;
old_client_db.import_sql_string(&sync_sql)?;
let conn = crate::database::lock_conn!(old_client_db.conn);
assert_eq!(
read_proxy_row(&conn, "claude")?,
(false, "127.0.0.1".to_string(), 15721, false, true, 6)
);
assert_eq!(
read_proxy_row(&conn, "codex")?,
(false, "127.0.0.1".to_string(), 15721, false, false, 5)
);
assert_eq!(
read_proxy_row(&conn, "gemini")?,
(false, "127.0.0.1".to_string(), 15721, false, true, 4)
);
Ok(())
}
#[test]
#[serial]
fn periodic_maintenance_runs_even_when_auto_backup_disabled() -> Result<(), AppError> {

View File

@@ -14,8 +14,8 @@ use crate::services::provider::{
};
use serde_json::{json, Value};
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::sync::{Arc, OnceLock};
use tokio::sync::{Mutex, RwLock};
/// 用于接管 Live 配置时的占位符(避免客户端提示缺少 key同时不泄露真实 Token
const PROXY_TOKEN_PLACEHOLDER: &str = "PROXY_MANAGED";
@@ -34,6 +34,11 @@ const CLAUDE_MODEL_OVERRIDE_ENV_KEYS: [&str; 6] = [
"ANTHROPIC_SMALL_FAST_MODEL",
];
pub(crate) fn restore_mutation_guard() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[derive(Clone)]
pub struct ProxyService {
db: Arc<Database>,
@@ -154,6 +159,11 @@ impl ProxyService {
/// 启动代理服务器
pub async fn start(&self) -> Result<ProxyServerInfo, String> {
let _guard = restore_mutation_guard().lock().await;
self.start_unlocked().await
}
async fn start_unlocked(&self) -> Result<ProxyServerInfo, String> {
// 1. 启动时自动设置 proxy_enabled = true
let mut global_config = self
.db
@@ -204,6 +214,8 @@ impl ProxyService {
/// 启动代理服务器(带 Live 配置接管)
pub async fn start_with_takeover(&self) -> Result<ProxyServerInfo, String> {
let _guard = restore_mutation_guard().lock().await;
// 1. 备份各应用的 Live 配置
self.backup_live_configs().await?;
@@ -242,7 +254,7 @@ impl ProxyService {
}
// 5. 启动代理服务器
match self.start().await {
match self.start_unlocked().await {
Ok(info) => Ok(info),
Err(e) => {
// 启动失败,恢复原始配置
@@ -300,13 +312,22 @@ impl ProxyService {
/// - 开启:自动启动代理服务,仅接管当前 app 的 Live 配置
/// - 关闭:仅恢复当前 app 的 Live 配置;若无其它接管,则自动停止代理服务
pub async fn set_takeover_for_app(&self, app_type: &str, enabled: bool) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.set_takeover_for_app_unlocked(app_type, enabled).await
}
async fn set_takeover_for_app_unlocked(
&self,
app_type: &str,
enabled: bool,
) -> Result<(), String> {
let app = AppType::from_str(app_type).map_err(|e| format!("无效的应用类型: {e}"))?;
let app_type_str = app.as_str();
if enabled {
// 1) 代理服务未运行则自动启动
if !self.is_running().await {
self.start().await?;
self.start_unlocked().await?;
}
// 2) 已接管则直接返回(幂等);但如果缺少备份或占位符残留,需要重建接管
@@ -429,7 +450,7 @@ impl ProxyService {
if self.is_running().await {
// 此时没有任何 app 处于接管状态,停止服务即可
let _ = self.stop().await;
let _ = self.stop_unlocked().await;
}
}
@@ -699,6 +720,11 @@ impl ProxyService {
/// 停止代理服务器
pub async fn stop(&self) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.stop_unlocked().await
}
async fn stop_unlocked(&self) -> Result<(), String> {
if let Some(server) = self.server.write().await.take() {
server
.stop()
@@ -730,8 +756,9 @@ impl ProxyService {
///
/// 会清除 settings 表中的代理状态,下次启动不会自动恢复。
pub async fn stop_with_restore(&self) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
// 1. 停止代理服务器(即使未运行也继续执行恢复逻辑)
if let Err(e) = self.stop().await {
if let Err(e) = self.stop_unlocked().await {
log::warn!("停止代理服务器失败(将继续恢复 Live 配置): {e}");
}
@@ -777,8 +804,9 @@ impl ProxyService {
///
/// 用于程序正常退出时,保留代理状态以便下次启动时自动恢复
pub async fn stop_with_restore_keep_state(&self) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
// 1. 停止代理服务器(即使未运行也继续执行恢复逻辑)
if let Err(e) = self.stop().await {
if let Err(e) = self.stop_unlocked().await {
log::warn!("停止代理服务器失败(将继续恢复 Live 配置): {e}");
}
@@ -1815,8 +1843,54 @@ impl ProxyService {
.map_err(|e| format!("获取代理配置失败: {e}"))
}
/// 更新全局代理配置(统一字段)
pub async fn update_global_proxy_config(
&self,
config: GlobalProxyConfig,
) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.db
.update_global_proxy_config(config)
.await
.map_err(|e| format!("保存全局代理配置失败: {e}"))
}
/// 更新指定应用的代理配置(应用级字段)
pub async fn update_proxy_config_for_app(&self, config: AppProxyConfig) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.db
.update_proxy_config_for_app(config)
.await
.map_err(|e| format!("保存应用代理配置失败: {e}"))
}
pub async fn set_default_cost_multiplier(
&self,
app_type: &str,
value: &str,
) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.db
.set_default_cost_multiplier(app_type, value)
.await
.map_err(|e| format!("保存默认成本倍率失败: {e}"))
}
pub async fn set_pricing_model_source(
&self,
app_type: &str,
value: &str,
) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
self.db
.set_pricing_model_source(app_type, value)
.await
.map_err(|e| format!("保存计费模式来源失败: {e}"))
}
/// 更新代理配置
pub async fn update_config(&self, config: &ProxyConfig) -> Result<(), String> {
let _guard = restore_mutation_guard().lock().await;
// 记录旧配置用于判定是否需要重启
let previous = self
.db
@@ -1901,6 +1975,34 @@ impl ProxyService {
self.server.read().await.is_some()
}
/// 检查当前是否存在会让 WebDAV restore 不安全的本地代理状态
pub async fn has_restore_blocking_proxy_state(&self) -> Result<bool, String> {
if self.is_running().await {
return Ok(true);
}
if self
.db
.has_any_live_backup()
.await
.map_err(|e| format!("读取 live 备份状态失败: {e}"))?
{
return Ok(true);
}
if self.detect_takeover_in_live_configs() {
return Ok(true);
}
if self
.db
.is_live_takeover_active()
.await
.map_err(|e| format!("读取代理接管状态失败: {e}"))?
{
return Ok(true);
}
Ok(false)
}
/// 热更新熔断器配置
///
/// 如果代理服务器正在运行,将新配置应用到所有已创建的熔断器实例
@@ -1941,6 +2043,9 @@ mod tests {
use crate::provider::ProviderMeta;
use serial_test::serial;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
struct TempHome {
@@ -2108,6 +2213,191 @@ model = "gpt-5.1-codex"
);
}
#[tokio::test]
#[serial]
async fn has_restore_blocking_proxy_state_is_true_when_live_backup_exists_without_enabled_flag()
{
let _home = TempHome::new();
crate::settings::reload_settings().expect("reload settings");
let db = Arc::new(Database::memory().expect("init db"));
let service = ProxyService::new(db.clone());
db.save_live_backup("claude", "{\"env\":{}}")
.await
.expect("seed live backup");
let config = db
.get_proxy_config_for_app("claude")
.await
.expect("get proxy config");
assert!(
!config.enabled,
"enabled flag should remain false for the stronger-artefact test"
);
assert!(
service
.has_restore_blocking_proxy_state()
.await
.expect("check restore blocking proxy state"),
"live backup should block restore even when enabled is false"
);
}
#[tokio::test]
#[serial]
async fn has_restore_blocking_proxy_state_is_true_when_live_config_residue_exists_without_enabled_flag(
) {
let _home = TempHome::new();
crate::settings::reload_settings().expect("reload settings");
let db = Arc::new(Database::memory().expect("init db"));
let service = ProxyService::new(db.clone());
service
.write_claude_live(&json!({
"env": {
"ANTHROPIC_API_KEY": PROXY_TOKEN_PLACEHOLDER
}
}))
.expect("seed taken-over claude live config");
let config = db
.get_proxy_config_for_app("claude")
.await
.expect("get claude proxy config");
assert!(
!config.enabled,
"enabled flag should remain false for the live-residue test"
);
assert!(
service
.has_restore_blocking_proxy_state()
.await
.expect("check restore blocking proxy state"),
"live config residue should block restore even when enabled is false"
);
}
#[tokio::test]
#[serial]
async fn proxy_config_update_waits_for_restore_mutation_guard() {
let db = Arc::new(Database::memory().expect("init db"));
let service = ProxyService::new(db.clone());
let initial = service.get_config().await.expect("read initial config");
let mut updated = initial.clone();
updated.listen_port = if initial.listen_port == 15721 {
15722
} else {
initial.listen_port + 1
};
let expected_port = updated.listen_port;
let guard = restore_mutation_guard().lock().await;
let completed = Arc::new(AtomicBool::new(false));
let completed_bg = Arc::clone(&completed);
let service_bg = service.clone();
let handle = tokio::spawn(async move {
service_bg
.update_config(&updated)
.await
.expect("update config after guard release");
completed_bg.store(true, Ordering::SeqCst);
});
tokio::time::sleep(Duration::from_millis(150)).await;
assert!(
!completed.load(Ordering::SeqCst),
"config update should wait behind the restore/mutation guard"
);
assert!(
!handle.is_finished(),
"config update task should still be blocked by the guard"
);
drop(guard);
tokio::time::timeout(Duration::from_secs(5), handle)
.await
.expect("config update task should finish after guard release")
.expect("config update task should succeed");
assert_eq!(
service
.get_config()
.await
.expect("read config after guard release")
.listen_port,
expected_port
);
}
#[tokio::test]
#[serial]
async fn start_with_takeover_waits_for_restore_mutation_guard() {
let _home = TempHome::new();
crate::settings::reload_settings().expect("reload settings");
let db = Arc::new(Database::memory().expect("init db"));
let service = ProxyService::new(db.clone());
let provider = Provider::with_id(
"claude-provider".to_string(),
"Claude Provider".to_string(),
json!({
"env": {
"ANTHROPIC_API_KEY": "db-key"
}
}),
Some("claude".to_string()),
);
db.save_provider("claude", &provider)
.expect("save claude provider");
db.set_current_provider("claude", &provider.id)
.expect("set current provider");
service
.write_claude_live(&json!({
"env": {
"ANTHROPIC_API_KEY": "live-key"
}
}))
.expect("seed claude live config");
let guard = restore_mutation_guard().lock().await;
let completed = Arc::new(AtomicBool::new(false));
let completed_bg = Arc::clone(&completed);
let service_bg = service.clone();
let handle = tokio::spawn(async move {
service_bg
.start_with_takeover()
.await
.expect("start with takeover after guard release");
completed_bg.store(true, Ordering::SeqCst);
});
tokio::time::sleep(Duration::from_millis(150)).await;
assert!(
!completed.load(Ordering::SeqCst),
"start_with_takeover should wait behind the restore/mutation guard"
);
assert!(
!handle.is_finished(),
"start_with_takeover task should still be blocked by the guard"
);
drop(guard);
tokio::time::timeout(Duration::from_secs(5), handle)
.await
.expect("start_with_takeover task should complete after guard release")
.expect("start_with_takeover task should succeed");
service
.stop_with_restore()
.await
.expect("cleanup started proxy");
}
#[tokio::test]
#[serial]
async fn sync_claude_token_respects_existing_api_key_field() {

View File

@@ -182,6 +182,7 @@ pub async fn upload(
/// Download remote snapshot and apply to local database + skills.
pub async fn download(
db: &crate::database::Database,
proxy_service: &crate::services::ProxyService,
settings: &mut WebDavSyncSettings,
) -> Result<Value, AppError> {
settings.validate()?;
@@ -216,6 +217,11 @@ pub async fn download(
)
.await?;
let _guard = crate::services::proxy::restore_mutation_guard()
.lock()
.await;
ensure_restore_allowed(proxy_service).await?;
// Apply snapshot
apply_snapshot(db, &db_sql, &skills_zip)?;
@@ -233,6 +239,24 @@ pub async fn download(
}))
}
async fn ensure_restore_allowed(
proxy_service: &crate::services::ProxyService,
) -> Result<(), AppError> {
if proxy_service
.has_restore_blocking_proxy_state()
.await
.map_err(AppError::Config)?
{
return Err(localized(
"webdav.sync.restore_blocked_proxy_active",
"当前本地代理或接管状态仍然活跃,请先恢复本地代理状态后再执行 WebDAV 恢复",
"Local proxy or takeover state is still active. Restore local proxy state before running WebDAV restore.",
));
}
Ok(())
}
/// Fetch remote manifest info without downloading artifacts.
pub async fn fetch_remote_info(settings: &WebDavSyncSettings) -> Result<Option<Value>, AppError> {
settings.validate()?;
@@ -669,6 +693,11 @@ fn validate_artifact_size_limit(artifact_name: &str, size: u64) -> Result<(), Ap
#[cfg(test)]
mod tests {
use super::*;
use crate::database::Database;
use crate::provider::Provider;
use crate::services::ProxyService;
use std::sync::Arc;
use tempfile::TempDir;
fn artifact(sha256: &str, size: u64) -> ArtifactMeta {
ArtifactMeta {
@@ -881,4 +910,52 @@ mod tests {
fn validate_artifact_size_limit_accepts_limit_boundary() {
assert!(validate_artifact_size_limit("skills.zip", MAX_SYNC_ARTIFACT_BYTES).is_ok());
}
#[tokio::test]
async fn ensure_restore_allowed_rejects_takeover_artifacts_even_when_enabled_flag_is_false() {
let temp_home = TempDir::new().expect("create temp home");
std::env::set_var("HOME", temp_home.path());
std::env::set_var("USERPROFILE", temp_home.path());
std::env::set_var("CC_SWITCH_TEST_HOME", temp_home.path());
let db = Arc::new(Database::memory().expect("init db"));
let proxy_service = ProxyService::new(db.clone());
let provider = Provider::with_id(
"claude-provider".to_string(),
"Claude Provider".to_string(),
serde_json::json!({
"env": {
"ANTHROPIC_API_KEY": "db-key"
}
}),
Some("claude".to_string()),
);
db.save_provider("claude", &provider)
.expect("save claude provider");
db.set_current_provider("claude", &provider.id)
.expect("set current claude provider");
db.save_live_backup("claude", "{\"env\":{}}")
.await
.expect("seed live backup");
let mut proxy_config = db
.get_proxy_config_for_app("claude")
.await
.expect("get claude proxy config");
proxy_config.enabled = false;
db.update_proxy_config_for_app(proxy_config)
.await
.expect("persist cleared enabled flag");
let err = ensure_restore_allowed(&proxy_service)
.await
.expect_err("live backup should still block restore");
assert!(
err.to_string().contains("restore")
|| err.to_string().contains("恢复")
|| err.to_string().contains("proxy"),
"unexpected error: {err}"
);
}
}