diff --git a/src-tauri/src/commands/proxy.rs b/src-tauri/src/commands/proxy.rs index 9c9a3421..b74e1afe 100644 --- a/src-tauri/src/commands/proxy.rs +++ b/src-tauri/src/commands/proxy.rs @@ -130,15 +130,17 @@ pub async fn reset_circuit_breaker( provider_id: String, app_type: String, ) -> Result<(), String> { - // 重置数据库健康状态 + // 1. 重置数据库健康状态 let db = &state.db; db.update_provider_health(&provider_id, &app_type, true, None) .await .map_err(|e| e.to_string())?; - // 注意:熔断器状态在内存中,重启代理服务器后会重置 - // 如果代理服务器正在运行,需要通知它重置熔断器 - // 目前先通过数据库重置健康状态,熔断器会在下次超时后自动尝试半开 + // 2. 如果代理正在运行,重置内存中的熔断器状态 + state + .proxy_service + .reset_provider_circuit_breaker(&provider_id, &app_type) + .await?; Ok(()) } @@ -161,9 +163,19 @@ pub async fn update_circuit_breaker_config( config: CircuitBreakerConfig, ) -> Result<(), String> { let db = &state.db; + + // 1. 更新数据库配置 db.update_circuit_breaker_config(&config) .await - .map_err(|e| e.to_string()) + .map_err(|e| e.to_string())?; + + // 2. 如果代理正在运行,热更新内存中的熔断器配置 + state + .proxy_service + .update_circuit_breaker_configs(config) + .await?; + + Ok(()) } /// 获取熔断器统计信息(仅当代理服务器运行时) diff --git a/src-tauri/src/database/dao/proxy.rs b/src-tauri/src/database/dao/proxy.rs index 961760c4..0303481c 100644 --- a/src-tauri/src/database/dao/proxy.rs +++ b/src-tauri/src/database/dao/proxy.rs @@ -129,12 +129,31 @@ impl Database { } /// 更新Provider健康状态 + /// + /// 使用默认阈值(5)判断是否健康,建议使用 `update_provider_health_with_threshold` 传入配置的阈值 pub async fn update_provider_health( &self, provider_id: &str, app_type: &str, success: bool, error_msg: Option, + ) -> Result<(), AppError> { + // 默认阈值与 CircuitBreakerConfig::default() 保持一致 + self.update_provider_health_with_threshold(provider_id, app_type, success, error_msg, 5) + .await + } + + /// 更新Provider健康状态(带阈值参数) + /// + /// # Arguments + /// * `failure_threshold` - 连续失败多少次后标记为不健康 + pub async fn update_provider_health_with_threshold( + &self, + provider_id: &str, + app_type: &str, + success: bool, + error_msg: Option, + failure_threshold: u32, ) -> Result<(), AppError> { let conn = lock_conn!(self.conn); @@ -142,7 +161,7 @@ impl Database { // 先查询当前状态 let current = conn.query_row( - "SELECT consecutive_failures FROM provider_health + "SELECT consecutive_failures FROM provider_health WHERE provider_id = ?1 AND app_type = ?2", rusqlite::params![provider_id, app_type], |row| Ok(row.get::<_, i64>(0)? as u32), @@ -154,7 +173,8 @@ impl Database { } else { // 失败:增加失败计数 let failures = current.unwrap_or(0) + 1; - let healthy = if failures >= 3 { 0 } else { 1 }; + // 使用传入的阈值而非硬编码 + let healthy = if failures >= failure_threshold { 0 } else { 1 }; (healthy, failures) }; @@ -169,10 +189,10 @@ impl Database { "INSERT OR REPLACE INTO provider_health (provider_id, app_type, is_healthy, consecutive_failures, last_success_at, last_failure_at, last_error, updated_at) - VALUES (?1, ?2, ?3, ?4, - COALESCE(?5, (SELECT last_success_at FROM provider_health + VALUES (?1, ?2, ?3, ?4, + COALESCE(?5, (SELECT last_success_at FROM provider_health WHERE provider_id = ?1 AND app_type = ?2)), - COALESCE(?6, (SELECT last_failure_at FROM provider_health + COALESCE(?6, (SELECT last_failure_at FROM provider_health WHERE provider_id = ?1 AND app_type = ?2)), ?7, ?8)", rusqlite::params![ diff --git a/src-tauri/src/proxy/circuit_breaker.rs b/src-tauri/src/proxy/circuit_breaker.rs index acedef87..3bdd9df8 100644 --- a/src-tauri/src/proxy/circuit_breaker.rs +++ b/src-tauri/src/proxy/circuit_breaker.rs @@ -72,8 +72,10 @@ pub struct CircuitBreaker { failed_requests: Arc, /// 上次打开时间 last_opened_at: Arc>>, - /// 配置 - config: CircuitBreakerConfig, + /// 配置(支持热更新) + config: Arc>, + /// 半开状态已放行的请求数(用于限流) + half_open_requests: Arc, } impl CircuitBreaker { @@ -86,20 +88,29 @@ impl CircuitBreaker { total_requests: Arc::new(AtomicU32::new(0)), failed_requests: Arc::new(AtomicU32::new(0)), last_opened_at: Arc::new(RwLock::new(None)), - config, + config: Arc::new(RwLock::new(config)), + half_open_requests: Arc::new(AtomicU32::new(0)), } } + /// 更新熔断器配置(热更新,不重置状态) + pub async fn update_config(&self, new_config: CircuitBreakerConfig) { + *self.config.write().await = new_config; + log::debug!("Circuit breaker config updated"); + } + /// 检查是否允许请求通过 pub async fn allow_request(&self) -> bool { let state = *self.state.read().await; + let config = self.config.read().await; match state { CircuitState::Closed => true, CircuitState::Open => { // 检查是否应该尝试半开 if let Some(opened_at) = *self.last_opened_at.read().await { - if opened_at.elapsed().as_secs() >= self.config.timeout_seconds { + if opened_at.elapsed().as_secs() >= config.timeout_seconds { + drop(config); // 释放读锁再转换状态 log::info!( "Circuit breaker transitioning from Open to HalfOpen (timeout reached)" ); @@ -109,13 +120,36 @@ impl CircuitBreaker { } false } - CircuitState::HalfOpen => true, + CircuitState::HalfOpen => { + // 半开状态限流:只允许有限请求通过进行探测 + // 默认最多允许 1 个请求(可在配置中扩展) + let max_half_open_requests = 1u32; + let current = self.half_open_requests.fetch_add(1, Ordering::SeqCst); + + if current < max_half_open_requests { + log::debug!( + "Circuit breaker HalfOpen: allowing probe request ({}/{})", + current + 1, + max_half_open_requests + ); + true + } else { + // 超过限额,回退计数,拒绝请求 + self.half_open_requests.fetch_sub(1, Ordering::SeqCst); + log::debug!( + "Circuit breaker HalfOpen: rejecting request (limit reached: {})", + max_half_open_requests + ); + false + } + } } } /// 记录成功 pub async fn record_success(&self) { let state = *self.state.read().await; + let config = self.config.read().await; // 重置失败计数 self.consecutive_failures.store(0, Ordering::SeqCst); @@ -123,14 +157,18 @@ impl CircuitBreaker { match state { CircuitState::HalfOpen => { + // 释放 in-flight 名额(探测请求结束) + self.half_open_requests.fetch_sub(1, Ordering::SeqCst); + let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1; log::debug!( "Circuit breaker HalfOpen: {} consecutive successes (threshold: {})", successes, - self.config.success_threshold + config.success_threshold ); - if successes >= self.config.success_threshold { + if successes >= config.success_threshold { + drop(config); // 释放读锁再转换状态 log::info!("Circuit breaker transitioning from HalfOpen to Closed (success threshold reached)"); self.transition_to_closed().await; } @@ -145,6 +183,7 @@ impl CircuitBreaker { /// 记录失败 pub async fn record_failure(&self) { let state = *self.state.read().await; + let config = self.config.read().await; // 更新计数器 let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1; @@ -158,26 +197,38 @@ impl CircuitBreaker { "Circuit breaker {:?}: {} consecutive failures (threshold: {})", state, failures, - self.config.failure_threshold + config.failure_threshold ); // 检查是否应该打开熔断器 match state { - CircuitState::Closed | CircuitState::HalfOpen => { + CircuitState::HalfOpen => { + // 释放 in-flight 名额(探测请求结束) + self.half_open_requests.fetch_sub(1, Ordering::SeqCst); + + // HalfOpen 状态下失败,立即转为 Open + log::warn!( + "Circuit breaker HalfOpen probe failed, transitioning to Open" + ); + drop(config); + self.transition_to_open().await; + } + CircuitState::Closed => { // 检查连续失败次数 - if failures >= self.config.failure_threshold { + if failures >= config.failure_threshold { log::warn!( "Circuit breaker opening due to {} consecutive failures (threshold: {})", failures, - self.config.failure_threshold + config.failure_threshold ); + drop(config); // 释放读锁再转换状态 self.transition_to_open().await; } else { // 检查错误率 let total = self.total_requests.load(Ordering::SeqCst); let failed = self.failed_requests.load(Ordering::SeqCst); - if total >= self.config.min_requests { + if total >= config.min_requests { let error_rate = failed as f64 / total as f64; log::debug!( "Circuit breaker error rate: {:.2}% ({}/{} requests)", @@ -186,12 +237,13 @@ impl CircuitBreaker { total ); - if error_rate >= self.config.error_rate_threshold { + if error_rate >= config.error_rate_threshold { log::warn!( "Circuit breaker opening due to high error rate: {:.2}% (threshold: {:.2}%)", error_rate * 100.0, - self.config.error_rate_threshold * 100.0 + config.error_rate_threshold * 100.0 ); + drop(config); // 释放读锁再转换状态 self.transition_to_open().await; } } @@ -237,6 +289,8 @@ impl CircuitBreaker { async fn transition_to_half_open(&self) { *self.state.write().await = CircuitState::HalfOpen; self.consecutive_successes.store(0, Ordering::SeqCst); + // 重置半开状态的请求限流计数 + self.half_open_requests.store(0, Ordering::SeqCst); } /// 转换到关闭状态 diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index b4226ba0..38555713 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -4,12 +4,12 @@ use super::{ error::*, - provider_router::ProviderRouter as NewProviderRouter, + provider_router::ProviderRouter, providers::{get_adapter, ProviderAdapter}, types::ProxyStatus, ProxyError, }; -use crate::{app_config::AppType, database::Database, provider::Provider}; +use crate::{app_config::AppType, provider::Provider}; use reqwest::{Client, Response}; use serde_json::Value; use std::sync::Arc; @@ -18,8 +18,9 @@ use tokio::sync::RwLock; pub struct RequestForwarder { client: Client, - router: Arc, - #[allow(dead_code)] + /// 共享的 ProviderRouter(持有熔断器状态) + router: Arc, + /// 单个 Provider 内的最大重试次数 max_retries: u8, status: Arc>, current_providers: Arc>>, @@ -27,7 +28,7 @@ pub struct RequestForwarder { impl RequestForwarder { pub fn new( - db: Arc, + router: Arc, timeout_secs: u64, max_retries: u8, status: Arc>, @@ -44,32 +45,85 @@ impl RequestForwarder { Self { client, - router: Arc::new(NewProviderRouter::new(db)), + router, max_retries, status, current_providers, } } + /// 对单个 Provider 执行请求(带重试) + /// + /// 在同一个 Provider 上最多重试 max_retries 次,使用指数退避 + async fn forward_with_provider_retry( + &self, + provider: &Provider, + endpoint: &str, + body: &Value, + headers: &axum::http::HeaderMap, + adapter: &dyn ProviderAdapter, + ) -> Result { + let mut last_error = None; + + for attempt in 0..=self.max_retries { + if attempt > 0 { + // 指数退避:100ms, 200ms, 400ms, ... + let delay_ms = 100 * 2u64.pow(attempt as u32 - 1); + log::info!( + "[{}] 重试第 {}/{} 次(等待 {}ms)", + adapter.name(), + attempt, + self.max_retries, + delay_ms + ); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } + + match self.forward(provider, endpoint, body, headers, adapter).await { + Ok(response) => return Ok(response), + Err(e) => { + let category = self.categorize_proxy_error(&e); + + // 只有可重试的错误才继续重试 + if category == ErrorCategory::NonRetryable { + return Err(e); + } + + log::debug!( + "[{}] Provider {} 第 {} 次请求失败: {}", + adapter.name(), + provider.name, + attempt + 1, + e + ); + last_error = Some(e); + } + } + } + + Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded)) + } + /// 转发请求(带故障转移) + /// + /// # Arguments + /// * `app_type` - 应用类型 + /// * `endpoint` - API 端点 + /// * `body` - 请求体 + /// * `headers` - 请求头 + /// * `providers` - 已选择的 Provider 列表(由 RequestContext 提供,避免重复调用 select_providers) pub async fn forward_with_retry( &self, app_type: &AppType, endpoint: &str, body: Value, headers: axum::http::HeaderMap, + providers: Vec, ) -> Result { // 获取适配器 let adapter = get_adapter(app_type); let app_type_str = app_type.as_str(); - // 使用新的 ProviderRouter 选择所有可用供应商 - let providers = self - .router - .select_providers(app_type_str) - .await - .map_err(|e| ProxyError::DatabaseError(e.to_string()))?; - if providers.is_empty() { return Err(ProxyError::NoAvailableProvider); } @@ -108,9 +162,9 @@ impl RequestForwarder { let start = Instant::now(); - // 转发请求 + // 转发请求(带单 Provider 内重试) match self - .forward(provider, endpoint, &body, &headers, adapter.as_ref()) + .forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref()) .await { Ok(response) => { @@ -373,17 +427,22 @@ impl RequestForwarder { } /// 分类ProxyError + /// + /// 决定哪些错误应该触发故障转移到下一个 Provider fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory { match error { ProxyError::Timeout(_) => ErrorCategory::Retryable, ProxyError::ForwardFailed(_) => ErrorCategory::Retryable, ProxyError::UpstreamError { status, .. } => { - if *status >= 500 { - ErrorCategory::Retryable - } else if *status >= 400 && *status < 500 { - ErrorCategory::NonRetryable - } else { - ErrorCategory::Retryable + match *status { + // 速率限制 - 应该尝试其他 Provider + 429 => ErrorCategory::Retryable, + // 请求超时 + 408 => ErrorCategory::Retryable, + // 服务器错误 + s if s >= 500 => ErrorCategory::Retryable, + // 其他 4xx 错误(认证失败、参数错误等)不应重试 + _ => ErrorCategory::NonRetryable, } } ProxyError::ProviderUnhealthy(_) => ErrorCategory::Retryable, diff --git a/src-tauri/src/proxy/handler_context.rs b/src-tauri/src/proxy/handler_context.rs index 08fbbf26..8748d097 100644 --- a/src-tauri/src/proxy/handler_context.rs +++ b/src-tauri/src/proxy/handler_context.rs @@ -5,8 +5,7 @@ use crate::app_config::AppType; use crate::provider::Provider; use crate::proxy::{ - forwarder::RequestForwarder, router::ProviderRouter, server::ProxyState, types::ProxyConfig, - ProxyError, + forwarder::RequestForwarder, server::ProxyState, types::ProxyConfig, ProxyError, }; use std::time::Instant; @@ -15,7 +14,7 @@ use std::time::Instant; /// 贯穿整个请求生命周期,包含: /// - 计时信息 /// - 代理配置 -/// - 选中的 Provider +/// - 选中的 Provider 列表(用于故障转移) /// - 请求模型名称 /// - 日志标签 pub struct RequestContext { @@ -23,8 +22,10 @@ pub struct RequestContext { pub start_time: Instant, /// 代理配置快照 pub config: ProxyConfig, - /// 选中的 Provider + /// 选中的 Provider(故障转移链的第一个) pub provider: Provider, + /// 完整的 Provider 列表(用于故障转移) + providers: Vec, /// 请求中的模型名称 pub request_model: String, /// 日志标签(如 "Claude"、"Codex"、"Gemini") @@ -65,21 +66,32 @@ impl RequestContext { .unwrap_or("unknown") .to_string(); - // Provider 选择 - let router = ProviderRouter::new(state.db.clone()); - let provider = router.select_provider(&app_type, &[]).await?; + // 使用共享的 ProviderRouter 选择 Provider(熔断器状态跨请求保持) + // 注意:只在这里调用一次,结果传递给 forwarder,避免重复消耗 HalfOpen 名额 + let providers = state + .provider_router + .select_providers(app_type_str) + .await + .map_err(|e| ProxyError::DatabaseError(e.to_string()))?; + + let provider = providers + .first() + .cloned() + .ok_or(ProxyError::NoAvailableProvider)?; log::info!( - "[{}] Provider: {}, model: {}", + "[{}] Provider: {}, model: {}, failover chain: {} providers", tag, provider.name, - request_model + request_model, + providers.len() ); Ok(Self { start_time, config, provider, + providers, request_model, tag, app_type_str, @@ -110,9 +122,11 @@ impl RequestContext { } /// 创建 RequestForwarder + /// + /// 使用共享的 ProviderRouter,确保熔断器状态跨请求保持 pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder { RequestForwarder::new( - state.db.clone(), + state.provider_router.clone(), self.config.request_timeout, self.config.max_retries, state.status.clone(), @@ -120,6 +134,13 @@ impl RequestContext { ) } + /// 获取 Provider 列表(用于故障转移) + /// + /// 返回在创建上下文时已选择的 providers,避免重复调用 select_providers() + pub fn get_providers(&self) -> Vec { + self.providers.clone() + } + /// 计算请求延迟(毫秒) #[inline] pub fn latency_ms(&self) -> u64 { diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 6e4704ca..ccbd1c5d 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -84,7 +84,13 @@ pub async fn handle_messages( // 转发请求 let forwarder = ctx.create_forwarder(&state); let response = forwarder - .forward_with_retry(&AppType::Claude, "/v1/messages", body.clone(), headers) + .forward_with_retry( + &AppType::Claude, + "/v1/messages", + body.clone(), + headers, + ctx.get_providers(), + ) .await?; let status = response.status(); @@ -299,7 +305,13 @@ pub async fn handle_chat_completions( let forwarder = ctx.create_forwarder(&state); let response = forwarder - .forward_with_retry(&AppType::Codex, "/v1/chat/completions", body, headers) + .forward_with_retry( + &AppType::Codex, + "/v1/chat/completions", + body, + headers, + ctx.get_providers(), + ) .await?; log::info!("[Codex] 上游响应状态: {}", response.status()); @@ -317,7 +329,13 @@ pub async fn handle_responses( let forwarder = ctx.create_forwarder(&state); let response = forwarder - .forward_with_retry(&AppType::Codex, "/v1/responses", body, headers) + .forward_with_retry( + &AppType::Codex, + "/v1/responses", + body, + headers, + ctx.get_providers(), + ) .await?; log::info!("[Codex] 上游响应状态: {}", response.status()); @@ -351,7 +369,13 @@ pub async fn handle_gemini( let forwarder = ctx.create_forwarder(&state); let response = forwarder - .forward_with_retry(&AppType::Gemini, endpoint, body, headers) + .forward_with_retry( + &AppType::Gemini, + endpoint, + body, + headers, + ctx.get_providers(), + ) .await?; log::info!("[Gemini] 上游响应状态: {}", response.status()); diff --git a/src-tauri/src/proxy/mod.rs b/src-tauri/src/proxy/mod.rs index 5b052f0f..d068d320 100644 --- a/src-tauri/src/proxy/mod.rs +++ b/src-tauri/src/proxy/mod.rs @@ -13,7 +13,6 @@ pub mod provider_router; pub mod providers; pub mod response_handler; pub mod response_processor; -mod router; pub(crate) mod server; pub mod session; pub(crate) mod types; diff --git a/src-tauri/src/proxy/provider_router.rs b/src-tauri/src/proxy/provider_router.rs index 6866fc1c..c3bb237a 100644 --- a/src-tauri/src/proxy/provider_router.rs +++ b/src-tauri/src/proxy/provider_router.rs @@ -5,7 +5,7 @@ use crate::database::Database; use crate::error::AppError; use crate::provider::Provider; -use crate::proxy::circuit_breaker::CircuitBreaker; +use crate::proxy::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -124,7 +124,11 @@ impl ProviderRouter { success: bool, error_msg: Option, ) -> Result<(), AppError> { - // 1. 更新熔断器状态 + // 1. 获取熔断器配置(用于更新健康状态和判断是否禁用) + let config = self.db.get_circuit_breaker_config().await.ok(); + let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5); + + // 2. 更新熔断器状态 let circuit_key = format!("{app_type}:{provider_id}"); let breaker = self.get_or_create_circuit_breaker(&circuit_key).await; @@ -140,19 +144,21 @@ impl ProviderRouter { ); } - // 2. 更新数据库健康状态 + // 3. 更新数据库健康状态(使用配置的阈值) self.db - .update_provider_health(provider_id, app_type, success, error_msg.clone()) + .update_provider_health_with_threshold( + provider_id, + app_type, + success, + error_msg.clone(), + failure_threshold, + ) .await?; - // 3. 如果连续失败达到熔断阈值,自动禁用代理目标 + // 4. 如果连续失败达到熔断阈值,自动禁用代理目标 if !success { let health = self.db.get_provider_health(provider_id, app_type).await?; - // 获取熔断器配置 - let config = self.db.get_circuit_breaker_config().await.ok(); - let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5); - // 如果连续失败达到阈值,自动关闭该供应商的代理开关 if health.consecutive_failures >= failure_threshold { log::warn!( @@ -171,7 +177,6 @@ impl ProviderRouter { } /// 重置熔断器(手动恢复) - #[allow(dead_code)] pub async fn reset_circuit_breaker(&self, circuit_key: &str) { let breakers = self.circuit_breakers.read().await; if let Some(breaker) = breakers.get(circuit_key) { @@ -180,6 +185,27 @@ impl ProviderRouter { } } + /// 重置指定供应商的熔断器 + pub async fn reset_provider_breaker(&self, provider_id: &str, app_type: &str) { + let circuit_key = format!("{app_type}:{provider_id}"); + self.reset_circuit_breaker(&circuit_key).await; + } + + /// 更新所有熔断器的配置(热更新) + /// + /// 当用户在 UI 中修改熔断器配置后调用此方法, + /// 所有现有的熔断器会立即使用新配置 + pub async fn update_all_configs(&self, config: CircuitBreakerConfig) { + let breakers = self.circuit_breakers.read().await; + let count = breakers.len(); + + for breaker in breakers.values() { + breaker.update_config(config.clone()).await; + } + + log::info!("已更新 {} 个熔断器的配置", count); + } + /// 获取熔断器状态 #[allow(dead_code)] pub async fn get_circuit_breaker_stats( diff --git a/src-tauri/src/proxy/router.rs b/src-tauri/src/proxy/router.rs deleted file mode 100644 index c59c99bf..00000000 --- a/src-tauri/src/proxy/router.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! Provider路由器 -//! -//! 负责选择合适的Provider进行请求转发 - -use super::ProxyError; -use crate::{app_config::AppType, database::Database, provider::Provider}; -use std::sync::Arc; - -pub struct ProviderRouter { - db: Arc, -} - -impl ProviderRouter { - pub fn new(db: Arc) -> Self { - Self { db } - } - - /// 选择Provider(只使用标记为代理目标的 Provider) - pub async fn select_provider( - &self, - app_type: &AppType, - _failed_ids: &[String], - ) -> Result { - // 1. 获取 Proxy Target Provider ID - let proxy_target_id = self - .db - .get_proxy_target_provider(app_type.as_str()) - .map_err(|e| ProxyError::DatabaseError(e.to_string()))?; - - let target_id = proxy_target_id.ok_or_else(|| { - log::warn!("[{}] 未设置代理目标 Provider", app_type.as_str()); - ProxyError::NoAvailableProvider - })?; - - // 2. 获取所有 Provider - let providers = self - .db - .get_all_providers(app_type.as_str()) - .map_err(|e| ProxyError::DatabaseError(e.to_string()))?; - - // 3. 找到目标 Provider - let target = providers.get(&target_id).ok_or_else(|| { - log::warn!( - "[{}] 代理目标 Provider 不存在: {}", - app_type.as_str(), - target_id - ); - ProxyError::NoAvailableProvider - })?; - - log::info!( - "[{}] 使用代理目标 Provider: {}", - app_type.as_str(), - target.name - ); - Ok(target.clone()) - } - - /// 更新Provider健康状态(保留接口但不影响选择) - #[allow(dead_code)] - pub async fn update_health( - &self, - _provider: &Provider, - _app_type: &AppType, - _success: bool, - _error_msg: Option, - ) { - // 不再记录健康状态 - } -} diff --git a/src-tauri/src/proxy/server.rs b/src-tauri/src/proxy/server.rs index 7c60ecf3..93a89099 100644 --- a/src-tauri/src/proxy/server.rs +++ b/src-tauri/src/proxy/server.rs @@ -2,7 +2,7 @@ //! //! 基于Axum的HTTP服务器,处理代理请求 -use super::{handlers, types::*, ProxyError}; +use super::{handlers, provider_router::ProviderRouter, types::*, ProxyError}; use crate::database::Database; use axum::{ routing::{get, post}, @@ -23,6 +23,8 @@ pub struct ProxyState { pub start_time: Arc>>, /// 每个应用类型当前使用的 provider (app_type -> (provider_id, provider_name)) pub current_providers: Arc>>, + /// 共享的 ProviderRouter(持有熔断器状态,跨请求保持) + pub provider_router: Arc, } /// 代理HTTP服务器 @@ -36,12 +38,16 @@ pub struct ProxyServer { impl ProxyServer { pub fn new(config: ProxyConfig, db: Arc) -> Self { + // 创建共享的 ProviderRouter(熔断器状态将跨所有请求保持) + let provider_router = Arc::new(ProviderRouter::new(db.clone())); + let state = ProxyState { db, config: Arc::new(RwLock::new(config.clone())), status: Arc::new(RwLock::new(ProxyStatus::default())), start_time: Arc::new(RwLock::new(None)), current_providers: Arc::new(RwLock::new(std::collections::HashMap::new())), + provider_router, }; Self { @@ -192,4 +198,22 @@ impl ProxyServer { pub async fn apply_runtime_config(&self, config: &ProxyConfig) { *self.state.config.write().await = config.clone(); } + + /// 热更新熔断器配置 + /// + /// 将新配置应用到所有已创建的熔断器实例 + pub async fn update_circuit_breaker_configs( + &self, + config: super::circuit_breaker::CircuitBreakerConfig, + ) { + self.state.provider_router.update_all_configs(config).await; + } + + /// 重置指定 Provider 的熔断器 + pub async fn reset_provider_circuit_breaker(&self, provider_id: &str, app_type: &str) { + self.state + .provider_router + .reset_provider_breaker(provider_id, app_type) + .await; + } } diff --git a/src-tauri/src/services/proxy.rs b/src-tauri/src/services/proxy.rs index 45d55f31..f31bade7 100644 --- a/src-tauri/src/services/proxy.rs +++ b/src-tauri/src/services/proxy.rs @@ -681,4 +681,41 @@ impl ProxyService { pub async fn is_running(&self) -> bool { self.server.read().await.is_some() } + + /// 热更新熔断器配置 + /// + /// 如果代理服务器正在运行,将新配置应用到所有已创建的熔断器实例 + pub async fn update_circuit_breaker_configs( + &self, + config: crate::proxy::CircuitBreakerConfig, + ) -> Result<(), String> { + if let Some(server) = self.server.read().await.as_ref() { + server.update_circuit_breaker_configs(config).await; + log::info!("已热更新运行中的熔断器配置"); + } else { + log::debug!("代理服务器未运行,熔断器配置将在下次启动时生效"); + } + Ok(()) + } + + /// 重置指定 Provider 的熔断器 + /// + /// 如果代理服务器正在运行,立即重置内存中的熔断器状态 + pub async fn reset_provider_circuit_breaker( + &self, + provider_id: &str, + app_type: &str, + ) -> Result<(), String> { + if let Some(server) = self.server.read().await.as_ref() { + server + .reset_provider_circuit_breaker(provider_id, app_type) + .await; + log::info!( + "已重置 Provider {} (app: {}) 的熔断器", + provider_id, + app_type + ); + } + Ok(()) + } }