fix(proxy): serialize per-app provider switches to prevent state corruption

Concurrent failover switches for the same app could cause is_current,
local settings, and Live backup to point at different providers.

- Add SwitchLockManager with per-app mutexes (different apps still parallel)
- Unify scattered switch logic into ProxyService::hot_switch_provider
- Fix TOCTOU in set_current_provider via mutate_settings
- Add logical_target_changed to skip redundant UI refreshes
- Add tests for serialization and restore-waits-for-switch scenarios
This commit is contained in:
Jason
2026-03-25 16:02:35 +08:00
parent 6a083cdd1c
commit af8f907467
6 changed files with 344 additions and 95 deletions
+11 -23
View File
@@ -2,15 +2,12 @@
//!
//! 处理故障转移成功后的供应商切换逻辑,包括:
//! - 去重控制(避免多个请求同时触发)
//! - 数据库更新
//! - 托盘菜单更新
//! - 前端事件发射
//! - Live 备份更新
use crate::database::Database;
use crate::error::AppError;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
use tauri::{Emitter, Manager};
use tokio::sync::RwLock;
@@ -98,30 +95,21 @@ impl FailoverSwitchManager {
log::info!("[FO-001] 切换: {app_type} → {provider_name}");
// 1. 更新数据库 is_current
self.db.set_current_provider(app_type, provider_id)?;
let mut switched = false;
// 2. 更新本地 settings(设备级)
let app_type_enum = crate::app_config::AppType::from_str(app_type)
.map_err(|_| AppError::Message(format!("无效的应用类型: {app_type}")))?;
crate::settings::set_current_provider(&app_type_enum, Some(provider_id))?;
// 3. 更新托盘菜单和发射事件
if let Some(app) = app_handle {
// 更新托盘菜单
if let Some(app_state) = app.try_state::<crate::store::AppState>() {
// 更新 Live 备份(确保代理停止时恢复正确配置)
if let Ok(Some(provider)) = self.db.get_provider_by_id(provider_id, app_type) {
if let Err(e) = app_state
.proxy_service
.update_live_backup_from_provider(app_type, &provider)
.await
{
log::warn!("[FO-003] Live 备份更新失败: {e}");
}
switched = app_state
.proxy_service
.hot_switch_provider(app_type, provider_id)
.await
.map_err(AppError::Message)?
.logical_target_changed;
if !switched {
return Ok(false);
}
// 重建托盘菜单
if let Ok(new_menu) = crate::tray::create_tray_menu(app, app_state.inner()) {
if let Some(tray) = app.tray_by_id("main") {
if let Err(e) = tray.set_menu(Some(new_menu)) {
@@ -142,6 +130,6 @@ impl FailoverSwitchManager {
}
}
Ok(true)
Ok(switched)
}
}
+1
View File
@@ -24,6 +24,7 @@ pub mod response_processor;
pub(crate) mod server;
pub mod session;
pub(crate) mod sse;
pub(crate) mod switch_lock;
pub mod thinking_budget_rectifier;
pub mod thinking_optimizer;
pub mod thinking_rectifier;
+42
View File
@@ -0,0 +1,42 @@
//! Per-app switch lock
//!
//! 确保同一应用同时只有一个供应商切换操作在执行,
//! 防止并发切换导致 is_current 与 Live 备份不一致。
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
/// One mutex per application type, guaranteeing that switch operations for
/// the same app execute strictly serially.
///
/// Different apps (e.g. Claude and Codex) can still switch in parallel.
#[derive(Clone, Default)]
pub struct SwitchLockManager {
    /// Lazily-populated map from app-type key to that app's switch mutex.
    /// `Arc<Mutex<()>>` lets a guard be owned independently of the map lock.
    locks: Arc<RwLock<HashMap<String, Arc<Mutex<()>>>>>,
}
impl SwitchLockManager {
    /// Creates an empty manager with no per-app locks allocated yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Acquires the switch lock for the given app.
    ///
    /// Returns an `OwnedMutexGuard`; while it is held, other switch attempts
    /// for the same `app_type` queue up and wait.
    pub async fn lock_for_app(&self, app_type: &str) -> OwnedMutexGuard<()> {
        // Fast path: most calls find an already-created lock under a read guard.
        let existing = self.locks.read().await.get(app_type).cloned();
        let lock = match existing {
            Some(found) => found,
            // Slow path: take the write guard and insert the lock, unless a
            // concurrent caller beat us to it (entry() handles that race).
            None => self
                .locks
                .write()
                .await
                .entry(app_type.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(())))
                .clone(),
        };
        // Await the mutex outside both RwLock guards, so map access is never
        // blocked behind a long-held switch lock.
        lock.lock_owned().await
    }
}
+2 -22
View File
@@ -480,32 +480,12 @@ impl ProviderService {
id
);
// 获取新供应商的完整配置(用于更新备份)
let provider = providers
.get(id)
.ok_or_else(|| AppError::Message(format!("供应商 {id} 不存在")))?;
// Update database is_current
state.db.set_current_provider(app_type.as_str(), id)?;
// Update local settings for consistency
crate::settings::set_current_provider(&app_type, Some(id))?;
// 更新 Live 备份(确保代理关闭时恢复正确的供应商配置)
futures::executor::block_on(
state
.proxy_service
.update_live_backup_from_provider(app_type.as_str(), provider),
.hot_switch_provider(app_type.as_str(), id),
)
.map_err(|e| AppError::Message(format!("更新 Live 备份失败: {e}")))?;
// 关键修复:接管模式下切换供应商不会写回 Live 配置,
// 需要主动清理 Claude Live 中的“模型覆盖”字段,避免仍以旧模型名发起请求。
if matches!(app_type, AppType::Claude) {
if let Err(e) = state.proxy_service.cleanup_claude_model_overrides_in_live() {
log::warn!("清理 Claude Live 模型字段失败(不影响切换结果): {e}");
}
}
.map_err(|e| AppError::Message(format!("热切换失败: {e}")))?;
// Note: No Live config write, no MCP sync
// The proxy server will route requests to the new provider via is_current
+280 -39
View File
@@ -7,6 +7,7 @@ use crate::config::{get_claude_settings_path, read_json_file, write_json_file};
use crate::database::Database;
use crate::provider::Provider;
use crate::proxy::server::ProxyServer;
use crate::proxy::switch_lock::SwitchLockManager;
use crate::proxy::types::*;
use crate::services::provider::{
build_effective_settings_with_common_config, write_live_with_common_config,
@@ -39,6 +40,12 @@ pub struct ProxyService {
server: Arc<RwLock<Option<ProxyServer>>>,
/// AppHandle,用于传递给 ProxyServer 以支持故障转移时的 UI 更新
app_handle: Arc<RwLock<Option<tauri::AppHandle>>>,
switch_locks: SwitchLockManager,
}
/// Outcome of a hot provider switch.
#[derive(Debug, Clone, Copy, Default)]
pub struct HotSwitchOutcome {
    /// True when the switch changed the effective current provider; false
    /// when the app was already pointing at the requested provider (callers
    /// may use this to skip redundant UI refreshes).
    pub logical_target_changed: bool,
}
impl ProxyService {
@@ -47,6 +54,7 @@ impl ProxyService {
db,
server: Arc::new(RwLock::new(None)),
app_handle: Arc::new(RwLock::new(None)),
switch_locks: SwitchLockManager::new(),
}
}
@@ -1100,6 +1108,11 @@ impl ProxyService {
/// Restores the Live config for the given app (does nothing when no backup exists).
async fn restore_live_config_for_app(&self, app_type: &AppType) -> Result<(), String> {
    // Serialize with provider switches: a restore must never interleave with
    // an in-flight hot switch for the same app.
    let app_key = app_type.as_str();
    let _guard = self.switch_locks.lock_for_app(app_key).await;
    self.restore_live_config_for_app_inner(app_type).await
}
async fn restore_live_config_for_app_inner(&self, app_type: &AppType) -> Result<(), String> {
match app_type {
AppType::Claude => {
if let Ok(Some(backup)) = self.db.get_live_backup("claude").await {
@@ -1159,6 +1172,15 @@ impl ProxyService {
/// Lock-taking wrapper around `restore_live_config_for_app_with_fallback_inner`:
/// queues behind any in-flight switch for the same app before restoring.
async fn restore_live_config_for_app_with_fallback(
    &self,
    app_type: &AppType,
) -> Result<(), String> {
    let app_key = app_type.as_str();
    let _guard = self.switch_locks.lock_for_app(app_key).await;
    self.restore_live_config_for_app_with_fallback_inner(app_type)
        .await
}
async fn restore_live_config_for_app_with_fallback_inner(
&self,
app_type: &AppType,
) -> Result<(), String> {
let app_type_str = app_type.as_str();
@@ -1487,6 +1509,17 @@ impl ProxyService {
&self,
app_type: &str,
provider: &Provider,
) -> Result<(), String> {
let _guard = self.switch_locks.lock_for_app(app_type).await;
self.update_live_backup_from_provider_inner(app_type, provider)
.await
}
/// 仅供已持有 per-app 切换锁的调用方使用。
async fn update_live_backup_from_provider_inner(
&self,
app_type: &str,
provider: &Provider,
) -> Result<(), String> {
let app_type_enum =
AppType::from_str(app_type).map_err(|_| format!("未知的应用类型: {app_type}"))?;
@@ -1540,6 +1573,69 @@ impl ProxyService {
Ok(())
}
/// Hot-switches the given app to `provider_id` under the per-app switch lock.
///
/// Updates, in order: database `is_current`, device-local settings, the Live
/// backup (only when a backup exists or the Live config is detected as taken
/// over), and the proxy server's active target.
///
/// Returns whether the logical target actually changed, so callers can skip
/// redundant follow-up work (e.g. UI refreshes) when it did not.
pub async fn hot_switch_provider(
    &self,
    app_type: &str,
    provider_id: &str,
) -> Result<HotSwitchOutcome, String> {
    // Held for the whole switch: serializes with other switches and with
    // Live-config restores for this app.
    let _guard = self.switch_locks.lock_for_app(app_type).await;

    let app_type_enum =
        AppType::from_str(app_type).map_err(|_| format!("无效的应用类型: {app_type}"))?;

    // Resolve the target provider up front so a bad id fails before any
    // state is mutated.
    let provider = self
        .db
        .get_provider_by_id(provider_id, app_type)
        .map_err(|e| format!("读取供应商失败: {e}"))?
        .ok_or_else(|| format!("供应商不存在: {provider_id}"))?;

    // Capture "did anything actually change?" BEFORE mutating state.
    let logical_target_changed =
        crate::settings::get_effective_current_provider(&self.db, &app_type_enum)
            .map_err(|e| format!("读取当前供应商失败: {e}"))?
            .as_deref()
            != Some(provider_id);

    // Only sync the Live backup when one exists or the Live config looks
    // taken over — otherwise we would create/overwrite a backup spuriously.
    let has_backup = self
        .db
        .get_live_backup(app_type_enum.as_str())
        .await
        .map_err(|e| format!("读取 {app_type} 备份失败: {e}"))?
        .is_some();
    let live_taken_over = self.detect_takeover_in_live_config_for_app(&app_type_enum);
    let should_sync_backup = has_backup || live_taken_over;

    // Database is the source of truth; local settings mirror it.
    self.db
        .set_current_provider(app_type_enum.as_str(), provider_id)
        .map_err(|e| format!("更新当前供应商失败: {e}"))?;
    crate::settings::set_current_provider(&app_type_enum, Some(provider_id))
        .map_err(|e| format!("更新本地当前供应商失败: {e}"))?;

    if should_sync_backup {
        // Call the `_inner` variant: we already hold this app's switch lock,
        // the public wrapper would deadlock.
        self.update_live_backup_from_provider_inner(app_type, &provider)
            .await?;
        if matches!(app_type_enum, AppType::Claude) {
            // Best-effort cleanup of stale model overrides in Claude Live;
            // failure is logged but does not fail the switch.
            if let Err(e) = self.cleanup_claude_model_overrides_in_live() {
                log::warn!("清理 Claude Live 模型字段失败(不影响热切换结果): {e}");
            }
        }
    }

    // Keep the running proxy's routing target in sync (if a server is up).
    if let Some(server) = self.server.read().await.as_ref() {
        server
            .set_active_target(app_type_enum.as_str(), &provider.id, &provider.name)
            .await;
    }

    Ok(HotSwitchOutcome {
        logical_target_changed,
    })
}
/// Test-only access to the per-app switch lock, used by tests to hold the
/// lock and verify that concurrent switches/restores queue behind it.
#[cfg(test)]
async fn lock_switch_for_test(&self, app_type: &str) -> tokio::sync::OwnedMutexGuard<()> {
    self.switch_locks.lock_for_app(app_type).await
}
fn preserve_codex_mcp_servers_in_backup(
target_settings: &mut Value,
existing_backup: &Value,
@@ -1607,47 +1703,13 @@ impl ProxyService {
app_type: &str,
provider_id: &str,
) -> Result<(), String> {
// 代理模式切换供应商(热切换):
// - 更新 SSOT(数据库 is_current
// - 同步本地 settings(设备级 current_provider_*
// - 若该应用正处于接管模式,则同步更新 Live 备份(用于停止代理时恢复)
let app_type_enum =
AppType::from_str(app_type).map_err(|_| format!("无效的应用类型: {app_type}"))?;
let outcome = self.hot_switch_provider(app_type, provider_id).await?;
self.db
.set_current_provider(app_type_enum.as_str(), provider_id)
.map_err(|e| format!("更新当前供应商失败: {e}"))?;
// 同步本地 settings(设备级优先)
crate::settings::set_current_provider(&app_type_enum, Some(provider_id))
.map_err(|e| format!("更新本地当前供应商失败: {e}"))?;
// 仅在确实处于接管状态时才更新 Live 备份,避免无接管时误写覆盖 Live
let has_backup = self
.db
.get_live_backup(app_type_enum.as_str())
.await
.ok()
.flatten()
.is_some();
let live_taken_over = self.detect_takeover_in_live_config_for_app(&app_type_enum);
if let Ok(Some(provider)) = self.db.get_provider_by_id(provider_id, app_type) {
// 同步更新 Live 备份(用于 stop_with_restore 恢复)
if has_backup || live_taken_over {
self.update_live_backup_from_provider(app_type, &provider)
.await?;
}
// 同步更新 ProxyStatus.active_targets(用于 UI 立即反映切换目标)
if let Some(server) = self.server.read().await.as_ref() {
server
.set_active_target(app_type_enum.as_str(), &provider.id, &provider.name)
.await;
}
if outcome.logical_target_changed {
log::info!("代理模式:已切换 {app_type} 的目标供应商为 {provider_id}");
} else {
log::debug!("代理模式:{app_type} 已对齐到目标供应商 {provider_id}");
}
log::info!("代理模式:已切换 {app_type} 的目标供应商为 {provider_id}");
Ok(())
}
@@ -2193,6 +2255,185 @@ model = "gpt-5.1-codex"
assert_eq!(backup.original_config, expected);
}
// Verifies that two concurrent hot switches for the SAME app are serialized
// by the per-app lock: both run to completion (in submission order) and the
// final DB / local-settings / Live-backup state all agree on the last target.
#[tokio::test]
#[serial]
async fn hot_switch_provider_serializes_same_app_switches() {
    use tokio::time::{sleep, Duration};
    // Isolated HOME + fresh settings so local current-provider state starts clean.
    let _home = TempHome::new();
    crate::settings::reload_settings().expect("reload settings");
    let db = Arc::new(Database::memory().expect("init db"));
    let service = ProxyService::new(db.clone());
    // Three Claude providers: start on "a", then race switches to "b" and "c".
    let provider_a = Provider::with_id(
        "a".to_string(),
        "A".to_string(),
        json!({ "env": { "ANTHROPIC_API_KEY": "a-key" } }),
        None,
    );
    let provider_b = Provider::with_id(
        "b".to_string(),
        "B".to_string(),
        json!({ "env": { "ANTHROPIC_API_KEY": "b-key" } }),
        None,
    );
    let provider_c = Provider::with_id(
        "c".to_string(),
        "C".to_string(),
        json!({ "env": { "ANTHROPIC_API_KEY": "c-key" } }),
        None,
    );
    db.save_provider("claude", &provider_a)
        .expect("save provider a");
    db.save_provider("claude", &provider_b)
        .expect("save provider b");
    db.save_provider("claude", &provider_c)
        .expect("save provider c");
    db.set_current_provider("claude", "a")
        .expect("set current provider");
    crate::settings::set_current_provider(&AppType::Claude, Some("a"))
        .expect("set local current provider");
    // Seed a Live backup so hot_switch_provider takes the backup-sync path.
    db.save_live_backup("claude", "{\"env\":{}}")
        .await
        .expect("seed live backup");
    // Hold the switch lock so both spawned switches must queue behind it.
    let guard = service.lock_switch_for_test("claude").await;
    let service_for_b = service.clone();
    let service_for_c = service.clone();
    let switch_b = tokio::spawn(async move {
        service_for_b
            .hot_switch_provider("claude", "b")
            .await
            .expect("switch to b")
    });
    // Sleeps stagger the spawns so "b" queues before "c" (best-effort ordering).
    sleep(Duration::from_millis(20)).await;
    let switch_c = tokio::spawn(async move {
        service_for_c
            .hot_switch_provider("claude", "c")
            .await
            .expect("switch to c")
    });
    sleep(Duration::from_millis(20)).await;
    // Releasing the guard lets the queued switches run serially: b, then c.
    drop(guard);
    let outcome_b = switch_b.await.expect("join switch b");
    let outcome_c = switch_c.await.expect("join switch c");
    // Each switch saw a different previous target, so both report a change.
    assert!(outcome_b.logical_target_changed);
    assert!(outcome_c.logical_target_changed);
    // All three stores must converge on the LAST switch ("c") — no torn state.
    assert_eq!(
        crate::settings::get_effective_current_provider(&db, &AppType::Claude)
            .expect("effective current"),
        Some("c".to_string())
    );
    assert_eq!(
        crate::settings::get_current_provider(&AppType::Claude).as_deref(),
        Some("c")
    );
    assert_eq!(
        db.get_current_provider("claude").expect("db current"),
        Some("c".to_string())
    );
    // The Live backup must also reflect provider "c", not an intermediate state.
    let backup = db
        .get_live_backup("claude")
        .await
        .expect("get live backup")
        .expect("backup exists");
    let expected = serde_json::to_string(&provider_c.settings_config).expect("serialize");
    assert_eq!(backup.original_config, expected);
}
// Verifies that a Live-config restore queues behind an in-flight hot switch
// for the same app, so the restore writes the POST-switch backup (provider
// "b") rather than a stale pre-switch one.
#[tokio::test]
#[serial]
async fn restore_waits_for_hot_switch_and_restores_latest_backup() {
    use tokio::time::{sleep, Duration};
    // Isolated HOME + fresh settings so local current-provider state starts clean.
    let _home = TempHome::new();
    crate::settings::reload_settings().expect("reload settings");
    let db = Arc::new(Database::memory().expect("init db"));
    let service = ProxyService::new(db.clone());
    let provider_a = Provider::with_id(
        "a".to_string(),
        "A".to_string(),
        json!({ "env": { "ANTHROPIC_API_KEY": "a-key" } }),
        None,
    );
    let provider_b = Provider::with_id(
        "b".to_string(),
        "B".to_string(),
        json!({ "env": { "ANTHROPIC_API_KEY": "b-key" } }),
        None,
    );
    db.save_provider("claude", &provider_a)
        .expect("save provider a");
    db.save_provider("claude", &provider_b)
        .expect("save provider b");
    db.set_current_provider("claude", "a")
        .expect("set current provider");
    crate::settings::set_current_provider(&AppType::Claude, Some("a"))
        .expect("set local current provider");
    // Seed the backup with provider "a" — the switch should replace it with "b".
    db.save_live_backup(
        "claude",
        &serde_json::to_string(&provider_a.settings_config).expect("serialize provider a"),
    )
    .await
    .expect("seed live backup");
    // Seed a stale Live file so we can tell whether the restore ran at all.
    service
        .write_claude_live(&json!({ "env": { "ANTHROPIC_API_KEY": "stale" } }))
        .expect("seed live file");
    // Hold the switch lock so both the switch and the restore must queue.
    let guard = service.lock_switch_for_test("claude").await;
    let service_for_switch = service.clone();
    let service_for_restore = service.clone();
    let switch_to_b = tokio::spawn(async move {
        service_for_switch
            .hot_switch_provider("claude", "b")
            .await
            .expect("switch to b")
    });
    // Stagger spawns so the switch queues before the restore (best-effort).
    sleep(Duration::from_millis(20)).await;
    let restore = tokio::spawn(async move {
        service_for_restore
            .restore_live_config_for_app_with_fallback(&AppType::Claude)
            .await
            .expect("restore claude live")
    });
    sleep(Duration::from_millis(20)).await;
    // Release the lock: the switch runs first, then the restore.
    drop(guard);
    let outcome = switch_to_b.await.expect("join switch");
    restore.await.expect("join restore");
    assert!(outcome.logical_target_changed);
    assert_eq!(
        crate::settings::get_effective_current_provider(&db, &AppType::Claude)
            .expect("effective current"),
        Some("b".to_string())
    );
    // The backup must be provider "b"'s config — updated by the switch before
    // the restore consumed it.
    let backup = db
        .get_live_backup("claude")
        .await
        .expect("get live backup")
        .expect("backup exists");
    let expected = serde_json::to_string(&provider_b.settings_config).expect("serialize");
    assert_eq!(backup.original_config, expected);
    // And the Live file must now hold "b"'s config, not the stale seed.
    assert_eq!(
        service.read_claude_live().expect("read live"),
        provider_b.settings_config
    );
}
#[tokio::test]
#[serial]
async fn update_live_backup_from_provider_applies_claude_common_config() {
+8 -11
View File
@@ -584,17 +584,14 @@ pub fn get_current_provider(app_type: &AppType) -> Option<String> {
/// 这是设备级别的设置,不随数据库同步。
/// 传入 `None` 会清除当前供应商设置。
pub fn set_current_provider(app_type: &AppType, id: Option<&str>) -> Result<(), AppError> {
let mut settings = get_settings();
match app_type {
AppType::Claude => settings.current_provider_claude = id.map(|s| s.to_string()),
AppType::Codex => settings.current_provider_codex = id.map(|s| s.to_string()),
AppType::Gemini => settings.current_provider_gemini = id.map(|s| s.to_string()),
AppType::OpenCode => settings.current_provider_opencode = id.map(|s| s.to_string()),
AppType::OpenClaw => settings.current_provider_openclaw = id.map(|s| s.to_string()),
}
update_settings(settings)
let id_owned = id.map(|s| s.to_string());
mutate_settings(|settings| match app_type {
AppType::Claude => settings.current_provider_claude = id_owned.clone(),
AppType::Codex => settings.current_provider_codex = id_owned.clone(),
AppType::Gemini => settings.current_provider_gemini = id_owned.clone(),
AppType::OpenCode => settings.current_provider_opencode = id_owned.clone(),
AppType::OpenClaw => settings.current_provider_openclaw = id_owned.clone(),
})
}
/// 获取有效的当前供应商 ID(验证存在性)