mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-28 13:42:51 +08:00
Compare commits
1 Commits
codex/issu
...
codex/upst
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70504714f0 |
@@ -86,10 +86,7 @@ pub async fn update_global_proxy_config(
|
||||
state: tauri::State<'_, AppState>,
|
||||
config: GlobalProxyConfig,
|
||||
) -> Result<(), String> {
|
||||
let db = &state.db;
|
||||
db.update_global_proxy_config(config)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
state.proxy_service.update_global_proxy_config(config).await
|
||||
}
|
||||
|
||||
/// 获取指定应用的代理配置
|
||||
@@ -114,10 +111,10 @@ pub async fn update_proxy_config_for_app(
|
||||
state: tauri::State<'_, AppState>,
|
||||
config: AppProxyConfig,
|
||||
) -> Result<(), String> {
|
||||
let db = &state.db;
|
||||
db.update_proxy_config_for_app(config)
|
||||
state
|
||||
.proxy_service
|
||||
.update_proxy_config_for_app(config)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
async fn get_default_cost_multiplier_internal(
|
||||
@@ -152,8 +149,11 @@ async fn set_default_cost_multiplier_internal(
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let db = &state.db;
|
||||
db.set_default_cost_multiplier(app_type, value).await
|
||||
state
|
||||
.proxy_service
|
||||
.set_default_cost_multiplier(app_type, value)
|
||||
.await
|
||||
.map_err(AppError::Config)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
@@ -209,8 +209,11 @@ async fn set_pricing_model_source_internal(
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let db = &state.db;
|
||||
db.set_pricing_model_source(app_type, value).await
|
||||
state
|
||||
.proxy_service
|
||||
.set_pricing_model_source(app_type, value)
|
||||
.await
|
||||
.map_err(AppError::Config)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
|
||||
@@ -115,11 +115,17 @@ pub async fn webdav_sync_upload(state: State<'_, AppState>) -> Result<Value, Str
|
||||
#[tauri::command]
|
||||
pub async fn webdav_sync_download(state: State<'_, AppState>) -> Result<Value, String> {
|
||||
let db = state.db.clone();
|
||||
let proxy_service = state.proxy_service.clone();
|
||||
let db_for_sync = db.clone();
|
||||
let mut settings = require_enabled_webdav_settings()?;
|
||||
let _auto_sync_suppression = crate::services::webdav_auto_sync::AutoSyncSuppressionGuard::new();
|
||||
|
||||
let sync_result = run_with_webdav_lock(webdav_sync_service::download(&db, &mut settings)).await;
|
||||
let sync_result = run_with_webdav_lock(webdav_sync_service::download(
|
||||
&db,
|
||||
&proxy_service,
|
||||
&mut settings,
|
||||
))
|
||||
.await;
|
||||
let mut result = map_sync_result(sync_result, |error| {
|
||||
persist_sync_error(&mut settings, error, "manual")
|
||||
})?;
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::config::get_app_config_dir;
|
||||
use crate::error::AppError;
|
||||
use chrono::{Local, Utc};
|
||||
use rusqlite::backup::Backup;
|
||||
use rusqlite::types::ValueRef;
|
||||
use rusqlite::types::Value;
|
||||
use rusqlite::Connection;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -33,6 +33,64 @@ const SYNC_PRESERVE_TABLES: &[&str] = &[
|
||||
"usage_daily_rollups",
|
||||
];
|
||||
|
||||
const PROXY_CONFIG_LOCAL_COLUMNS: &[&str] =
|
||||
&["proxy_enabled", "listen_address", "listen_port", "enabled"];
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum SyncNeutralValue {
|
||||
Integer(i64),
|
||||
Text(&'static str),
|
||||
}
|
||||
|
||||
impl SyncNeutralValue {
|
||||
fn into_sql_value(self) -> Value {
|
||||
match self {
|
||||
Self::Integer(value) => Value::Integer(value),
|
||||
Self::Text(value) => Value::Text(value.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct SyncNeutralizedColumn {
|
||||
column: &'static str,
|
||||
value: SyncNeutralValue,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct SyncRowTransform {
|
||||
table: &'static str,
|
||||
key_column: &'static str,
|
||||
local_columns: &'static [&'static str],
|
||||
export_defaults: &'static [SyncNeutralizedColumn],
|
||||
}
|
||||
|
||||
const PROXY_CONFIG_EXPORT_DEFAULTS: &[SyncNeutralizedColumn] = &[
|
||||
SyncNeutralizedColumn {
|
||||
column: "proxy_enabled",
|
||||
value: SyncNeutralValue::Integer(0),
|
||||
},
|
||||
SyncNeutralizedColumn {
|
||||
column: "listen_address",
|
||||
value: SyncNeutralValue::Text("127.0.0.1"),
|
||||
},
|
||||
SyncNeutralizedColumn {
|
||||
column: "listen_port",
|
||||
value: SyncNeutralValue::Integer(15721),
|
||||
},
|
||||
SyncNeutralizedColumn {
|
||||
column: "enabled",
|
||||
value: SyncNeutralValue::Integer(0),
|
||||
},
|
||||
];
|
||||
|
||||
const SYNC_ROW_TRANSFORMS: &[SyncRowTransform] = &[SyncRowTransform {
|
||||
table: "proxy_config",
|
||||
key_column: "app_type",
|
||||
local_columns: PROXY_CONFIG_LOCAL_COLUMNS,
|
||||
export_defaults: PROXY_CONFIG_EXPORT_DEFAULTS,
|
||||
}];
|
||||
|
||||
/// A database backup entry for the UI
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -46,13 +104,13 @@ impl Database {
|
||||
/// 导出为 SQLite 兼容的 SQL 文本(内存字符串,完整导出)
|
||||
pub fn export_sql_string(&self) -> Result<String, AppError> {
|
||||
let snapshot = self.snapshot_to_memory()?;
|
||||
Self::dump_sql(&snapshot, &[])
|
||||
Self::dump_sql(&snapshot, &[], &[])
|
||||
}
|
||||
|
||||
/// Export SQL for sync (WebDAV), skipping local-only tables' data
|
||||
pub fn export_sql_string_for_sync(&self) -> Result<String, AppError> {
|
||||
let snapshot = self.snapshot_to_memory()?;
|
||||
Self::dump_sql(&snapshot, SYNC_SKIP_TABLES)
|
||||
Self::dump_sql(&snapshot, SYNC_SKIP_TABLES, SYNC_ROW_TRANSFORMS)
|
||||
}
|
||||
|
||||
/// 导出为 SQLite 兼容的 SQL 文本
|
||||
@@ -82,19 +140,20 @@ impl Database {
|
||||
|
||||
/// 从 SQL 字符串导入,返回生成的备份 ID(若无备份则为空字符串)
|
||||
pub fn import_sql_string(&self, sql_raw: &str) -> Result<String, AppError> {
|
||||
self.import_sql_string_inner(sql_raw, &[])
|
||||
self.import_sql_string_inner(sql_raw, &[], &[])
|
||||
}
|
||||
|
||||
/// Import SQL generated for sync, then restore local-only tables from the
|
||||
/// current device snapshot before replacing the main database.
|
||||
pub(crate) fn import_sql_string_for_sync(&self, sql_raw: &str) -> Result<String, AppError> {
|
||||
self.import_sql_string_inner(sql_raw, SYNC_PRESERVE_TABLES)
|
||||
self.import_sql_string_inner(sql_raw, SYNC_PRESERVE_TABLES, SYNC_ROW_TRANSFORMS)
|
||||
}
|
||||
|
||||
fn import_sql_string_inner(
|
||||
&self,
|
||||
sql_raw: &str,
|
||||
preserve_tables: &[&str],
|
||||
row_transforms: &[SyncRowTransform],
|
||||
) -> Result<String, AppError> {
|
||||
let sql_content = sql_raw.trim_start_matches('\u{feff}');
|
||||
Self::validate_cc_switch_sql_export(sql_content)?;
|
||||
@@ -102,7 +161,7 @@ impl Database {
|
||||
// 导入前备份现有数据库
|
||||
let backup_path = self.backup_database_file()?;
|
||||
|
||||
let local_snapshot = if preserve_tables.is_empty() {
|
||||
let local_snapshot = if preserve_tables.is_empty() && row_transforms.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.snapshot_to_memory()?)
|
||||
@@ -127,6 +186,7 @@ impl Database {
|
||||
Self::validate_basic_state(&temp_conn)?;
|
||||
if let Some(local_snapshot) = local_snapshot.as_ref() {
|
||||
Self::restore_tables(local_snapshot, &temp_conn, preserve_tables)?;
|
||||
Self::restore_row_transforms(local_snapshot, &temp_conn, row_transforms)?;
|
||||
}
|
||||
|
||||
// 使用 Backup 将临时库原子写回主库
|
||||
@@ -232,6 +292,111 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore_row_transforms(
|
||||
source_conn: &Connection,
|
||||
target_conn: &Connection,
|
||||
transforms: &[SyncRowTransform],
|
||||
) -> Result<(), AppError> {
|
||||
for transform in transforms {
|
||||
Self::restore_row_transform(source_conn, target_conn, transform)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore_row_transform(
|
||||
source_conn: &Connection,
|
||||
target_conn: &Connection,
|
||||
transform: &SyncRowTransform,
|
||||
) -> Result<(), AppError> {
|
||||
if !Self::table_exists(source_conn, transform.table)?
|
||||
|| !Self::table_exists(target_conn, transform.table)?
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let source_columns = Self::get_table_columns(source_conn, transform.table)?;
|
||||
let target_columns = Self::get_table_columns(target_conn, transform.table)?;
|
||||
if !source_columns
|
||||
.iter()
|
||||
.any(|column| column == transform.key_column)
|
||||
|| !target_columns
|
||||
.iter()
|
||||
.any(|column| column == transform.key_column)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let local_columns = transform
|
||||
.local_columns
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|column| {
|
||||
source_columns.iter().any(|existing| existing == column)
|
||||
&& target_columns.iter().any(|existing| existing == column)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if local_columns.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let select_columns = std::iter::once(transform.key_column)
|
||||
.chain(local_columns.iter().copied())
|
||||
.map(Self::quote_ident)
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let select_sql = format!(
|
||||
"SELECT {select_columns} FROM {}",
|
||||
Self::quote_ident(transform.table)
|
||||
);
|
||||
let assignments = local_columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, column)| format!("{} = ?{}", Self::quote_ident(column), idx + 1))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let update_sql = format!(
|
||||
"UPDATE {} SET {assignments} WHERE {} = ?{}",
|
||||
Self::quote_ident(transform.table),
|
||||
Self::quote_ident(transform.key_column),
|
||||
local_columns.len() + 1
|
||||
);
|
||||
|
||||
let mut stmt = source_conn.prepare(&select_sql).map_err(|e| {
|
||||
AppError::Database(format!(
|
||||
"读取本地表 {} 的同步字段失败: {e}",
|
||||
transform.table
|
||||
))
|
||||
})?;
|
||||
let mut rows = stmt.query([]).map_err(|e| {
|
||||
AppError::Database(format!("查询本地表 {} 数据失败: {e}", transform.table))
|
||||
})?;
|
||||
|
||||
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
|
||||
let mut values = Vec::with_capacity(local_columns.len() + 1);
|
||||
for idx in 1..=local_columns.len() {
|
||||
values.push(
|
||||
row.get::<_, Value>(idx)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?,
|
||||
);
|
||||
}
|
||||
values.push(
|
||||
row.get::<_, Value>(0)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?,
|
||||
);
|
||||
|
||||
target_conn
|
||||
.execute(&update_sql, rusqlite::params_from_iter(values.iter()))
|
||||
.map_err(|e| {
|
||||
AppError::Database(format!(
|
||||
"恢复本地表 {} 的同步字段失败: {e}",
|
||||
transform.table
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Periodic backup: create a new backup if the latest one is older than the configured interval
|
||||
pub(crate) fn periodic_backup_if_needed(&self) -> Result<(), AppError> {
|
||||
let interval_hours = crate::settings::effective_backup_interval_hours();
|
||||
@@ -384,7 +549,11 @@ impl Database {
|
||||
}
|
||||
|
||||
/// 导出数据库为 SQL 文本
|
||||
fn dump_sql(conn: &Connection, skip_tables: &[&str]) -> Result<String, AppError> {
|
||||
fn dump_sql(
|
||||
conn: &Connection,
|
||||
skip_tables: &[&str],
|
||||
row_transforms: &[SyncRowTransform],
|
||||
) -> Result<String, AppError> {
|
||||
let mut output = String::new();
|
||||
let timestamp = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
let user_version: i64 = conn
|
||||
@@ -450,10 +619,14 @@ impl Database {
|
||||
while let Some(row) = rows.next().map_err(|e| AppError::Database(e.to_string()))? {
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
for idx in 0..columns.len() {
|
||||
let value = row
|
||||
.get_ref(idx)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
values.push(Self::format_sql_value(value)?);
|
||||
values.push(
|
||||
row.get::<_, Value>(idx)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(transform) = row_transforms.iter().find(|t| t.table == table) {
|
||||
Self::apply_export_defaults(&columns, &mut values, transform);
|
||||
}
|
||||
|
||||
let cols = columns
|
||||
@@ -463,7 +636,11 @@ impl Database {
|
||||
.join(", ");
|
||||
output.push_str(&format!(
|
||||
"INSERT INTO \"{table}\" ({cols}) VALUES ({});\n",
|
||||
values.join(", ")
|
||||
values
|
||||
.iter()
|
||||
.map(Self::format_owned_sql_value)
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -488,19 +665,13 @@ impl Database {
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
/// 格式化 SQL 值
|
||||
fn format_sql_value(value: ValueRef<'_>) -> Result<String, AppError> {
|
||||
fn format_owned_sql_value(value: &Value) -> Result<String, AppError> {
|
||||
match value {
|
||||
ValueRef::Null => Ok("NULL".to_string()),
|
||||
ValueRef::Integer(i) => Ok(i.to_string()),
|
||||
ValueRef::Real(f) => Ok(f.to_string()),
|
||||
ValueRef::Text(t) => {
|
||||
let text = std::str::from_utf8(t)
|
||||
.map_err(|e| AppError::Database(format!("文本字段不是有效的 UTF-8: {e}")))?;
|
||||
let escaped = text.replace('\'', "''");
|
||||
Ok(format!("'{escaped}'"))
|
||||
}
|
||||
ValueRef::Blob(bytes) => {
|
||||
Value::Null => Ok("NULL".to_string()),
|
||||
Value::Integer(i) => Ok(i.to_string()),
|
||||
Value::Real(f) => Ok(f.to_string()),
|
||||
Value::Text(text) => Ok(format!("'{}'", text.replace('\'', "''"))),
|
||||
Value::Blob(bytes) => {
|
||||
let mut s = String::from("X'");
|
||||
for b in bytes {
|
||||
use std::fmt::Write;
|
||||
@@ -512,6 +683,22 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_export_defaults(
|
||||
columns: &[String],
|
||||
values: &mut [Value],
|
||||
transform: &SyncRowTransform,
|
||||
) {
|
||||
for default in transform.export_defaults {
|
||||
if let Some(idx) = columns.iter().position(|column| column == default.column) {
|
||||
values[idx] = default.value.into_sql_value();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn quote_ident(value: &str) -> String {
|
||||
format!("\"{}\"", value.replace('"', "\"\""))
|
||||
}
|
||||
|
||||
/// List all database backup files, sorted by creation time (newest first)
|
||||
pub fn list_backups() -> Result<Vec<BackupEntry>, AppError> {
|
||||
let backup_dir = get_app_config_dir().join("backups");
|
||||
@@ -692,8 +879,75 @@ mod tests {
|
||||
use super::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::settings::{update_settings, AppSettings};
|
||||
use rusqlite::Connection;
|
||||
use serial_test::serial;
|
||||
|
||||
fn seed_provider(conn: &Connection, id: &str) -> Result<(), AppError> {
|
||||
conn.execute(
|
||||
"INSERT INTO providers (id, app_type, name, settings_config, meta)
|
||||
VALUES (?1, 'claude', ?2, '{}', '{}')",
|
||||
rusqlite::params![id, format!("Provider {id}")],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn set_proxy_row(
|
||||
conn: &Connection,
|
||||
app_type: &str,
|
||||
proxy_enabled: bool,
|
||||
listen_address: &str,
|
||||
listen_port: i64,
|
||||
enabled: bool,
|
||||
auto_failover_enabled: bool,
|
||||
max_retries: i64,
|
||||
) -> Result<(), AppError> {
|
||||
conn.execute(
|
||||
"UPDATE proxy_config
|
||||
SET proxy_enabled = ?2,
|
||||
listen_address = ?3,
|
||||
listen_port = ?4,
|
||||
enabled = ?5,
|
||||
auto_failover_enabled = ?6,
|
||||
max_retries = ?7
|
||||
WHERE app_type = ?1",
|
||||
rusqlite::params![
|
||||
app_type,
|
||||
if proxy_enabled { 1 } else { 0 },
|
||||
listen_address,
|
||||
listen_port,
|
||||
if enabled { 1 } else { 0 },
|
||||
if auto_failover_enabled { 1 } else { 0 },
|
||||
max_retries,
|
||||
],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_proxy_row(
|
||||
conn: &Connection,
|
||||
app_type: &str,
|
||||
) -> Result<(bool, String, i64, bool, bool, i64), AppError> {
|
||||
conn.query_row(
|
||||
"SELECT proxy_enabled, listen_address, listen_port, enabled, auto_failover_enabled, max_retries
|
||||
FROM proxy_config WHERE app_type = ?1",
|
||||
[app_type],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)? != 0,
|
||||
row.get(1)?,
|
||||
row.get(2)?,
|
||||
row.get::<_, i64>(3)? != 0,
|
||||
row.get::<_, i64>(4)? != 0,
|
||||
row.get(5)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_import_preserves_local_only_tables() -> Result<(), AppError> {
|
||||
let remote_db = Database::memory()?;
|
||||
@@ -781,6 +1035,97 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_import_preserves_local_proxy_config_local_fields() -> Result<(), AppError> {
|
||||
let remote_db = Database::memory()?;
|
||||
{
|
||||
let conn = crate::database::lock_conn!(remote_db.conn);
|
||||
seed_provider(&conn, "remote-provider")?;
|
||||
set_proxy_row(
|
||||
&conn,
|
||||
"claude",
|
||||
false,
|
||||
"192.168.10.10",
|
||||
31001,
|
||||
false,
|
||||
true,
|
||||
9,
|
||||
)?;
|
||||
set_proxy_row(&conn, "codex", true, "192.168.10.11", 31002, true, false, 8)?;
|
||||
set_proxy_row(
|
||||
&conn,
|
||||
"gemini",
|
||||
false,
|
||||
"192.168.10.12",
|
||||
31003,
|
||||
true,
|
||||
true,
|
||||
7,
|
||||
)?;
|
||||
}
|
||||
let remote_sql = remote_db.export_sql_string()?;
|
||||
|
||||
let local_db = Database::memory()?;
|
||||
{
|
||||
let conn = crate::database::lock_conn!(local_db.conn);
|
||||
seed_provider(&conn, "local-provider")?;
|
||||
set_proxy_row(&conn, "claude", true, "10.0.0.1", 21001, true, false, 1)?;
|
||||
set_proxy_row(&conn, "codex", false, "10.0.0.2", 21002, false, true, 2)?;
|
||||
set_proxy_row(&conn, "gemini", true, "10.0.0.3", 21003, false, false, 3)?;
|
||||
}
|
||||
|
||||
local_db.import_sql_string_for_sync(&remote_sql)?;
|
||||
|
||||
let conn = crate::database::lock_conn!(local_db.conn);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "claude")?,
|
||||
(true, "10.0.0.1".to_string(), 21001, true, true, 9)
|
||||
);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "codex")?,
|
||||
(false, "10.0.0.2".to_string(), 21002, false, false, 8)
|
||||
);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "gemini")?,
|
||||
(true, "10.0.0.3".to_string(), 21003, false, true, 7)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_export_scrubs_proxy_config_local_fields_but_keeps_strategy_fields(
|
||||
) -> Result<(), AppError> {
|
||||
let db = Database::memory()?;
|
||||
{
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
seed_provider(&conn, "portable-provider")?;
|
||||
set_proxy_row(&conn, "claude", true, "10.1.0.1", 41001, true, true, 6)?;
|
||||
set_proxy_row(&conn, "codex", true, "10.1.0.2", 41002, true, false, 5)?;
|
||||
set_proxy_row(&conn, "gemini", true, "10.1.0.3", 41003, true, true, 4)?;
|
||||
}
|
||||
|
||||
let sync_sql = db.export_sql_string_for_sync()?;
|
||||
let old_client_db = Database::memory()?;
|
||||
old_client_db.import_sql_string(&sync_sql)?;
|
||||
|
||||
let conn = crate::database::lock_conn!(old_client_db.conn);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "claude")?,
|
||||
(false, "127.0.0.1".to_string(), 15721, false, true, 6)
|
||||
);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "codex")?,
|
||||
(false, "127.0.0.1".to_string(), 15721, false, false, 5)
|
||||
);
|
||||
assert_eq!(
|
||||
read_proxy_row(&conn, "gemini")?,
|
||||
(false, "127.0.0.1".to_string(), 15721, false, true, 4)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn periodic_maintenance_runs_even_when_auto_backup_disabled() -> Result<(), AppError> {
|
||||
|
||||
@@ -14,8 +14,8 @@ use crate::services::provider::{
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
/// 用于接管 Live 配置时的占位符(避免客户端提示缺少 key,同时不泄露真实 Token)
|
||||
const PROXY_TOKEN_PLACEHOLDER: &str = "PROXY_MANAGED";
|
||||
@@ -34,6 +34,11 @@ const CLAUDE_MODEL_OVERRIDE_ENV_KEYS: [&str; 6] = [
|
||||
"ANTHROPIC_SMALL_FAST_MODEL",
|
||||
];
|
||||
|
||||
pub(crate) fn restore_mutation_guard() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyService {
|
||||
db: Arc<Database>,
|
||||
@@ -154,6 +159,11 @@ impl ProxyService {
|
||||
|
||||
/// 启动代理服务器
|
||||
pub async fn start(&self) -> Result<ProxyServerInfo, String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.start_unlocked().await
|
||||
}
|
||||
|
||||
async fn start_unlocked(&self) -> Result<ProxyServerInfo, String> {
|
||||
// 1. 启动时自动设置 proxy_enabled = true
|
||||
let mut global_config = self
|
||||
.db
|
||||
@@ -204,6 +214,8 @@ impl ProxyService {
|
||||
|
||||
/// 启动代理服务器(带 Live 配置接管)
|
||||
pub async fn start_with_takeover(&self) -> Result<ProxyServerInfo, String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
|
||||
// 1. 备份各应用的 Live 配置
|
||||
self.backup_live_configs().await?;
|
||||
|
||||
@@ -242,7 +254,7 @@ impl ProxyService {
|
||||
}
|
||||
|
||||
// 5. 启动代理服务器
|
||||
match self.start().await {
|
||||
match self.start_unlocked().await {
|
||||
Ok(info) => Ok(info),
|
||||
Err(e) => {
|
||||
// 启动失败,恢复原始配置
|
||||
@@ -300,13 +312,22 @@ impl ProxyService {
|
||||
/// - 开启:自动启动代理服务,仅接管当前 app 的 Live 配置
|
||||
/// - 关闭:仅恢复当前 app 的 Live 配置;若无其它接管,则自动停止代理服务
|
||||
pub async fn set_takeover_for_app(&self, app_type: &str, enabled: bool) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.set_takeover_for_app_unlocked(app_type, enabled).await
|
||||
}
|
||||
|
||||
async fn set_takeover_for_app_unlocked(
|
||||
&self,
|
||||
app_type: &str,
|
||||
enabled: bool,
|
||||
) -> Result<(), String> {
|
||||
let app = AppType::from_str(app_type).map_err(|e| format!("无效的应用类型: {e}"))?;
|
||||
let app_type_str = app.as_str();
|
||||
|
||||
if enabled {
|
||||
// 1) 代理服务未运行则自动启动
|
||||
if !self.is_running().await {
|
||||
self.start().await?;
|
||||
self.start_unlocked().await?;
|
||||
}
|
||||
|
||||
// 2) 已接管则直接返回(幂等);但如果缺少备份或占位符残留,需要重建接管
|
||||
@@ -429,7 +450,7 @@ impl ProxyService {
|
||||
|
||||
if self.is_running().await {
|
||||
// 此时没有任何 app 处于接管状态,停止服务即可
|
||||
let _ = self.stop().await;
|
||||
let _ = self.stop_unlocked().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -699,6 +720,11 @@ impl ProxyService {
|
||||
|
||||
/// 停止代理服务器
|
||||
pub async fn stop(&self) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.stop_unlocked().await
|
||||
}
|
||||
|
||||
async fn stop_unlocked(&self) -> Result<(), String> {
|
||||
if let Some(server) = self.server.write().await.take() {
|
||||
server
|
||||
.stop()
|
||||
@@ -730,8 +756,9 @@ impl ProxyService {
|
||||
///
|
||||
/// 会清除 settings 表中的代理状态,下次启动不会自动恢复。
|
||||
pub async fn stop_with_restore(&self) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
// 1. 停止代理服务器(即使未运行也继续执行恢复逻辑)
|
||||
if let Err(e) = self.stop().await {
|
||||
if let Err(e) = self.stop_unlocked().await {
|
||||
log::warn!("停止代理服务器失败(将继续恢复 Live 配置): {e}");
|
||||
}
|
||||
|
||||
@@ -777,8 +804,9 @@ impl ProxyService {
|
||||
///
|
||||
/// 用于程序正常退出时,保留代理状态以便下次启动时自动恢复
|
||||
pub async fn stop_with_restore_keep_state(&self) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
// 1. 停止代理服务器(即使未运行也继续执行恢复逻辑)
|
||||
if let Err(e) = self.stop().await {
|
||||
if let Err(e) = self.stop_unlocked().await {
|
||||
log::warn!("停止代理服务器失败(将继续恢复 Live 配置): {e}");
|
||||
}
|
||||
|
||||
@@ -1815,8 +1843,54 @@ impl ProxyService {
|
||||
.map_err(|e| format!("获取代理配置失败: {e}"))
|
||||
}
|
||||
|
||||
/// 更新全局代理配置(统一字段)
|
||||
pub async fn update_global_proxy_config(
|
||||
&self,
|
||||
config: GlobalProxyConfig,
|
||||
) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.db
|
||||
.update_global_proxy_config(config)
|
||||
.await
|
||||
.map_err(|e| format!("保存全局代理配置失败: {e}"))
|
||||
}
|
||||
|
||||
/// 更新指定应用的代理配置(应用级字段)
|
||||
pub async fn update_proxy_config_for_app(&self, config: AppProxyConfig) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.db
|
||||
.update_proxy_config_for_app(config)
|
||||
.await
|
||||
.map_err(|e| format!("保存应用代理配置失败: {e}"))
|
||||
}
|
||||
|
||||
pub async fn set_default_cost_multiplier(
|
||||
&self,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.db
|
||||
.set_default_cost_multiplier(app_type, value)
|
||||
.await
|
||||
.map_err(|e| format!("保存默认成本倍率失败: {e}"))
|
||||
}
|
||||
|
||||
pub async fn set_pricing_model_source(
|
||||
&self,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
self.db
|
||||
.set_pricing_model_source(app_type, value)
|
||||
.await
|
||||
.map_err(|e| format!("保存计费模式来源失败: {e}"))
|
||||
}
|
||||
|
||||
/// 更新代理配置
|
||||
pub async fn update_config(&self, config: &ProxyConfig) -> Result<(), String> {
|
||||
let _guard = restore_mutation_guard().lock().await;
|
||||
// 记录旧配置用于判定是否需要重启
|
||||
let previous = self
|
||||
.db
|
||||
@@ -1901,6 +1975,34 @@ impl ProxyService {
|
||||
self.server.read().await.is_some()
|
||||
}
|
||||
|
||||
/// 检查当前是否存在会让 WebDAV restore 不安全的本地代理状态
|
||||
pub async fn has_restore_blocking_proxy_state(&self) -> Result<bool, String> {
|
||||
if self.is_running().await {
|
||||
return Ok(true);
|
||||
}
|
||||
if self
|
||||
.db
|
||||
.has_any_live_backup()
|
||||
.await
|
||||
.map_err(|e| format!("读取 live 备份状态失败: {e}"))?
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
if self.detect_takeover_in_live_configs() {
|
||||
return Ok(true);
|
||||
}
|
||||
if self
|
||||
.db
|
||||
.is_live_takeover_active()
|
||||
.await
|
||||
.map_err(|e| format!("读取代理接管状态失败: {e}"))?
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// 热更新熔断器配置
|
||||
///
|
||||
/// 如果代理服务器正在运行,将新配置应用到所有已创建的熔断器实例
|
||||
@@ -1941,6 +2043,9 @@ mod tests {
|
||||
use crate::provider::ProviderMeta;
|
||||
use serial_test::serial;
|
||||
use std::env;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tempfile::TempDir;
|
||||
|
||||
struct TempHome {
|
||||
@@ -2108,6 +2213,191 @@ model = "gpt-5.1-codex"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn has_restore_blocking_proxy_state_is_true_when_live_backup_exists_without_enabled_flag()
|
||||
{
|
||||
let _home = TempHome::new();
|
||||
crate::settings::reload_settings().expect("reload settings");
|
||||
|
||||
let db = Arc::new(Database::memory().expect("init db"));
|
||||
let service = ProxyService::new(db.clone());
|
||||
|
||||
db.save_live_backup("claude", "{\"env\":{}}")
|
||||
.await
|
||||
.expect("seed live backup");
|
||||
|
||||
let config = db
|
||||
.get_proxy_config_for_app("claude")
|
||||
.await
|
||||
.expect("get proxy config");
|
||||
assert!(
|
||||
!config.enabled,
|
||||
"enabled flag should remain false for the stronger-artefact test"
|
||||
);
|
||||
assert!(
|
||||
service
|
||||
.has_restore_blocking_proxy_state()
|
||||
.await
|
||||
.expect("check restore blocking proxy state"),
|
||||
"live backup should block restore even when enabled is false"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn has_restore_blocking_proxy_state_is_true_when_live_config_residue_exists_without_enabled_flag(
|
||||
) {
|
||||
let _home = TempHome::new();
|
||||
crate::settings::reload_settings().expect("reload settings");
|
||||
|
||||
let db = Arc::new(Database::memory().expect("init db"));
|
||||
let service = ProxyService::new(db.clone());
|
||||
|
||||
service
|
||||
.write_claude_live(&json!({
|
||||
"env": {
|
||||
"ANTHROPIC_API_KEY": PROXY_TOKEN_PLACEHOLDER
|
||||
}
|
||||
}))
|
||||
.expect("seed taken-over claude live config");
|
||||
|
||||
let config = db
|
||||
.get_proxy_config_for_app("claude")
|
||||
.await
|
||||
.expect("get claude proxy config");
|
||||
assert!(
|
||||
!config.enabled,
|
||||
"enabled flag should remain false for the live-residue test"
|
||||
);
|
||||
assert!(
|
||||
service
|
||||
.has_restore_blocking_proxy_state()
|
||||
.await
|
||||
.expect("check restore blocking proxy state"),
|
||||
"live config residue should block restore even when enabled is false"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn proxy_config_update_waits_for_restore_mutation_guard() {
|
||||
let db = Arc::new(Database::memory().expect("init db"));
|
||||
let service = ProxyService::new(db.clone());
|
||||
|
||||
let initial = service.get_config().await.expect("read initial config");
|
||||
let mut updated = initial.clone();
|
||||
updated.listen_port = if initial.listen_port == 15721 {
|
||||
15722
|
||||
} else {
|
||||
initial.listen_port + 1
|
||||
};
|
||||
let expected_port = updated.listen_port;
|
||||
|
||||
let guard = restore_mutation_guard().lock().await;
|
||||
let completed = Arc::new(AtomicBool::new(false));
|
||||
let completed_bg = Arc::clone(&completed);
|
||||
let service_bg = service.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
service_bg
|
||||
.update_config(&updated)
|
||||
.await
|
||||
.expect("update config after guard release");
|
||||
completed_bg.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
assert!(
|
||||
!completed.load(Ordering::SeqCst),
|
||||
"config update should wait behind the restore/mutation guard"
|
||||
);
|
||||
assert!(
|
||||
!handle.is_finished(),
|
||||
"config update task should still be blocked by the guard"
|
||||
);
|
||||
|
||||
drop(guard);
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), handle)
|
||||
.await
|
||||
.expect("config update task should finish after guard release")
|
||||
.expect("config update task should succeed");
|
||||
|
||||
assert_eq!(
|
||||
service
|
||||
.get_config()
|
||||
.await
|
||||
.expect("read config after guard release")
|
||||
.listen_port,
|
||||
expected_port
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn start_with_takeover_waits_for_restore_mutation_guard() {
|
||||
let _home = TempHome::new();
|
||||
crate::settings::reload_settings().expect("reload settings");
|
||||
|
||||
let db = Arc::new(Database::memory().expect("init db"));
|
||||
let service = ProxyService::new(db.clone());
|
||||
let provider = Provider::with_id(
|
||||
"claude-provider".to_string(),
|
||||
"Claude Provider".to_string(),
|
||||
json!({
|
||||
"env": {
|
||||
"ANTHROPIC_API_KEY": "db-key"
|
||||
}
|
||||
}),
|
||||
Some("claude".to_string()),
|
||||
);
|
||||
db.save_provider("claude", &provider)
|
||||
.expect("save claude provider");
|
||||
db.set_current_provider("claude", &provider.id)
|
||||
.expect("set current provider");
|
||||
service
|
||||
.write_claude_live(&json!({
|
||||
"env": {
|
||||
"ANTHROPIC_API_KEY": "live-key"
|
||||
}
|
||||
}))
|
||||
.expect("seed claude live config");
|
||||
|
||||
let guard = restore_mutation_guard().lock().await;
|
||||
let completed = Arc::new(AtomicBool::new(false));
|
||||
let completed_bg = Arc::clone(&completed);
|
||||
let service_bg = service.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
service_bg
|
||||
.start_with_takeover()
|
||||
.await
|
||||
.expect("start with takeover after guard release");
|
||||
completed_bg.store(true, Ordering::SeqCst);
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
assert!(
|
||||
!completed.load(Ordering::SeqCst),
|
||||
"start_with_takeover should wait behind the restore/mutation guard"
|
||||
);
|
||||
assert!(
|
||||
!handle.is_finished(),
|
||||
"start_with_takeover task should still be blocked by the guard"
|
||||
);
|
||||
|
||||
drop(guard);
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), handle)
|
||||
.await
|
||||
.expect("start_with_takeover task should complete after guard release")
|
||||
.expect("start_with_takeover task should succeed");
|
||||
|
||||
service
|
||||
.stop_with_restore()
|
||||
.await
|
||||
.expect("cleanup started proxy");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn sync_claude_token_respects_existing_api_key_field() {
|
||||
|
||||
@@ -182,6 +182,7 @@ pub async fn upload(
|
||||
/// Download remote snapshot and apply to local database + skills.
|
||||
pub async fn download(
|
||||
db: &crate::database::Database,
|
||||
proxy_service: &crate::services::ProxyService,
|
||||
settings: &mut WebDavSyncSettings,
|
||||
) -> Result<Value, AppError> {
|
||||
settings.validate()?;
|
||||
@@ -216,6 +217,11 @@ pub async fn download(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _guard = crate::services::proxy::restore_mutation_guard()
|
||||
.lock()
|
||||
.await;
|
||||
ensure_restore_allowed(proxy_service).await?;
|
||||
|
||||
// Apply snapshot
|
||||
apply_snapshot(db, &db_sql, &skills_zip)?;
|
||||
|
||||
@@ -233,6 +239,24 @@ pub async fn download(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn ensure_restore_allowed(
|
||||
proxy_service: &crate::services::ProxyService,
|
||||
) -> Result<(), AppError> {
|
||||
if proxy_service
|
||||
.has_restore_blocking_proxy_state()
|
||||
.await
|
||||
.map_err(AppError::Config)?
|
||||
{
|
||||
return Err(localized(
|
||||
"webdav.sync.restore_blocked_proxy_active",
|
||||
"当前本地代理或接管状态仍然活跃,请先恢复本地代理状态后再执行 WebDAV 恢复",
|
||||
"Local proxy or takeover state is still active. Restore local proxy state before running WebDAV restore.",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch remote manifest info without downloading artifacts.
|
||||
pub async fn fetch_remote_info(settings: &WebDavSyncSettings) -> Result<Option<Value>, AppError> {
|
||||
settings.validate()?;
|
||||
@@ -669,6 +693,11 @@ fn validate_artifact_size_limit(artifact_name: &str, size: u64) -> Result<(), Ap
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::database::Database;
|
||||
use crate::provider::Provider;
|
||||
use crate::services::ProxyService;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn artifact(sha256: &str, size: u64) -> ArtifactMeta {
|
||||
ArtifactMeta {
|
||||
@@ -881,4 +910,52 @@ mod tests {
|
||||
fn validate_artifact_size_limit_accepts_limit_boundary() {
|
||||
assert!(validate_artifact_size_limit("skills.zip", MAX_SYNC_ARTIFACT_BYTES).is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_restore_allowed_rejects_takeover_artifacts_even_when_enabled_flag_is_false() {
|
||||
let temp_home = TempDir::new().expect("create temp home");
|
||||
std::env::set_var("HOME", temp_home.path());
|
||||
std::env::set_var("USERPROFILE", temp_home.path());
|
||||
std::env::set_var("CC_SWITCH_TEST_HOME", temp_home.path());
|
||||
|
||||
let db = Arc::new(Database::memory().expect("init db"));
|
||||
let proxy_service = ProxyService::new(db.clone());
|
||||
|
||||
let provider = Provider::with_id(
|
||||
"claude-provider".to_string(),
|
||||
"Claude Provider".to_string(),
|
||||
serde_json::json!({
|
||||
"env": {
|
||||
"ANTHROPIC_API_KEY": "db-key"
|
||||
}
|
||||
}),
|
||||
Some("claude".to_string()),
|
||||
);
|
||||
db.save_provider("claude", &provider)
|
||||
.expect("save claude provider");
|
||||
db.set_current_provider("claude", &provider.id)
|
||||
.expect("set current claude provider");
|
||||
db.save_live_backup("claude", "{\"env\":{}}")
|
||||
.await
|
||||
.expect("seed live backup");
|
||||
|
||||
let mut proxy_config = db
|
||||
.get_proxy_config_for_app("claude")
|
||||
.await
|
||||
.expect("get claude proxy config");
|
||||
proxy_config.enabled = false;
|
||||
db.update_proxy_config_for_app(proxy_config)
|
||||
.await
|
||||
.expect("persist cleared enabled flag");
|
||||
|
||||
let err = ensure_restore_allowed(&proxy_service)
|
||||
.await
|
||||
.expect_err("live backup should still block restore");
|
||||
assert!(
|
||||
err.to_string().contains("restore")
|
||||
|| err.to_string().contains("恢复")
|
||||
|| err.to_string().contains("proxy"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user