mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-03-28 14:13:40 +08:00
fix(proxy): resolve circuit breaker state persistence and HalfOpen deadlock
This commit addresses several critical issues in the failover system:

**Circuit breaker state persistence (previous fix)**
- Promote ProviderRouter to ProxyState for cross-request state sharing
- Remove redundant router.rs module
- Fix 429 errors to be retryable (rate limiting should try other providers)

**Hot-update circuit breaker config**
- Add update_circuit_breaker_configs() to ProxyServer and ProxyService
- Connect update_circuit_breaker_config command to running circuit breakers
- Add reset_provider_circuit_breaker() for manual breaker reset

**Fix HalfOpen deadlock bug**
- Change half_open_requests from cumulative count to in-flight count
- Release quota in record_success()/record_failure() when in HalfOpen state
- Prevents permanent deadlock when success_threshold > 1

**Fix duplicate select_providers() call**
- Store providers list in RequestContext, pass to forward_with_retry()
- Avoid consuming HalfOpen quota twice per request
- Single call to select_providers() per request lifecycle

**Add per-provider retry with exponential backoff**
- Implement forward_with_provider_retry() with configurable max_retries
- Backoff delays: 100ms, 200ms, 400ms, etc.
This commit is contained in:
@@ -130,15 +130,17 @@ pub async fn reset_circuit_breaker(
|
||||
provider_id: String,
|
||||
app_type: String,
|
||||
) -> Result<(), String> {
|
||||
// 重置数据库健康状态
|
||||
// 1. 重置数据库健康状态
|
||||
let db = &state.db;
|
||||
db.update_provider_health(&provider_id, &app_type, true, None)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
// 注意:熔断器状态在内存中,重启代理服务器后会重置
|
||||
// 如果代理服务器正在运行,需要通知它重置熔断器
|
||||
// 目前先通过数据库重置健康状态,熔断器会在下次超时后自动尝试半开
|
||||
// 2. 如果代理正在运行,重置内存中的熔断器状态
|
||||
state
|
||||
.proxy_service
|
||||
.reset_provider_circuit_breaker(&provider_id, &app_type)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -161,9 +163,19 @@ pub async fn update_circuit_breaker_config(
|
||||
config: CircuitBreakerConfig,
|
||||
) -> Result<(), String> {
|
||||
let db = &state.db;
|
||||
|
||||
// 1. 更新数据库配置
|
||||
db.update_circuit_breaker_config(&config)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
// 2. 如果代理正在运行,热更新内存中的熔断器配置
|
||||
state
|
||||
.proxy_service
|
||||
.update_circuit_breaker_configs(config)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取熔断器统计信息(仅当代理服务器运行时)
|
||||
|
||||
@@ -129,12 +129,31 @@ impl Database {
|
||||
}
|
||||
|
||||
/// 更新Provider健康状态
|
||||
///
|
||||
/// 使用默认阈值(5)判断是否健康,建议使用 `update_provider_health_with_threshold` 传入配置的阈值
|
||||
pub async fn update_provider_health(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
success: bool,
|
||||
error_msg: Option<String>,
|
||||
) -> Result<(), AppError> {
|
||||
// 默认阈值与 CircuitBreakerConfig::default() 保持一致
|
||||
self.update_provider_health_with_threshold(provider_id, app_type, success, error_msg, 5)
|
||||
.await
|
||||
}
|
||||
|
||||
/// 更新Provider健康状态(带阈值参数)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `failure_threshold` - 连续失败多少次后标记为不健康
|
||||
pub async fn update_provider_health_with_threshold(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
success: bool,
|
||||
error_msg: Option<String>,
|
||||
failure_threshold: u32,
|
||||
) -> Result<(), AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
@@ -142,7 +161,7 @@ impl Database {
|
||||
|
||||
// 先查询当前状态
|
||||
let current = conn.query_row(
|
||||
"SELECT consecutive_failures FROM provider_health
|
||||
"SELECT consecutive_failures FROM provider_health
|
||||
WHERE provider_id = ?1 AND app_type = ?2",
|
||||
rusqlite::params![provider_id, app_type],
|
||||
|row| Ok(row.get::<_, i64>(0)? as u32),
|
||||
@@ -154,7 +173,8 @@ impl Database {
|
||||
} else {
|
||||
// 失败:增加失败计数
|
||||
let failures = current.unwrap_or(0) + 1;
|
||||
let healthy = if failures >= 3 { 0 } else { 1 };
|
||||
// 使用传入的阈值而非硬编码
|
||||
let healthy = if failures >= failure_threshold { 0 } else { 1 };
|
||||
(healthy, failures)
|
||||
};
|
||||
|
||||
@@ -169,10 +189,10 @@ impl Database {
|
||||
"INSERT OR REPLACE INTO provider_health
|
||||
(provider_id, app_type, is_healthy, consecutive_failures,
|
||||
last_success_at, last_failure_at, last_error, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4,
|
||||
COALESCE(?5, (SELECT last_success_at FROM provider_health
|
||||
VALUES (?1, ?2, ?3, ?4,
|
||||
COALESCE(?5, (SELECT last_success_at FROM provider_health
|
||||
WHERE provider_id = ?1 AND app_type = ?2)),
|
||||
COALESCE(?6, (SELECT last_failure_at FROM provider_health
|
||||
COALESCE(?6, (SELECT last_failure_at FROM provider_health
|
||||
WHERE provider_id = ?1 AND app_type = ?2)),
|
||||
?7, ?8)",
|
||||
rusqlite::params![
|
||||
|
||||
@@ -72,8 +72,10 @@ pub struct CircuitBreaker {
|
||||
failed_requests: Arc<AtomicU32>,
|
||||
/// 上次打开时间
|
||||
last_opened_at: Arc<RwLock<Option<Instant>>>,
|
||||
/// 配置
|
||||
config: CircuitBreakerConfig,
|
||||
/// 配置(支持热更新)
|
||||
config: Arc<RwLock<CircuitBreakerConfig>>,
|
||||
/// 半开状态已放行的请求数(用于限流)
|
||||
half_open_requests: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
@@ -86,20 +88,29 @@ impl CircuitBreaker {
|
||||
total_requests: Arc::new(AtomicU32::new(0)),
|
||||
failed_requests: Arc::new(AtomicU32::new(0)),
|
||||
last_opened_at: Arc::new(RwLock::new(None)),
|
||||
config,
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
half_open_requests: Arc::new(AtomicU32::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// 更新熔断器配置(热更新,不重置状态)
|
||||
pub async fn update_config(&self, new_config: CircuitBreakerConfig) {
|
||||
*self.config.write().await = new_config;
|
||||
log::debug!("Circuit breaker config updated");
|
||||
}
|
||||
|
||||
/// 检查是否允许请求通过
|
||||
pub async fn allow_request(&self) -> bool {
|
||||
let state = *self.state.read().await;
|
||||
let config = self.config.read().await;
|
||||
|
||||
match state {
|
||||
CircuitState::Closed => true,
|
||||
CircuitState::Open => {
|
||||
// 检查是否应该尝试半开
|
||||
if let Some(opened_at) = *self.last_opened_at.read().await {
|
||||
if opened_at.elapsed().as_secs() >= self.config.timeout_seconds {
|
||||
if opened_at.elapsed().as_secs() >= config.timeout_seconds {
|
||||
drop(config); // 释放读锁再转换状态
|
||||
log::info!(
|
||||
"Circuit breaker transitioning from Open to HalfOpen (timeout reached)"
|
||||
);
|
||||
@@ -109,13 +120,36 @@ impl CircuitBreaker {
|
||||
}
|
||||
false
|
||||
}
|
||||
CircuitState::HalfOpen => true,
|
||||
CircuitState::HalfOpen => {
|
||||
// 半开状态限流:只允许有限请求通过进行探测
|
||||
// 默认最多允许 1 个请求(可在配置中扩展)
|
||||
let max_half_open_requests = 1u32;
|
||||
let current = self.half_open_requests.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
if current < max_half_open_requests {
|
||||
log::debug!(
|
||||
"Circuit breaker HalfOpen: allowing probe request ({}/{})",
|
||||
current + 1,
|
||||
max_half_open_requests
|
||||
);
|
||||
true
|
||||
} else {
|
||||
// 超过限额,回退计数,拒绝请求
|
||||
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
|
||||
log::debug!(
|
||||
"Circuit breaker HalfOpen: rejecting request (limit reached: {})",
|
||||
max_half_open_requests
|
||||
);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录成功
|
||||
pub async fn record_success(&self) {
|
||||
let state = *self.state.read().await;
|
||||
let config = self.config.read().await;
|
||||
|
||||
// 重置失败计数
|
||||
self.consecutive_failures.store(0, Ordering::SeqCst);
|
||||
@@ -123,14 +157,18 @@ impl CircuitBreaker {
|
||||
|
||||
match state {
|
||||
CircuitState::HalfOpen => {
|
||||
// 释放 in-flight 名额(探测请求结束)
|
||||
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
log::debug!(
|
||||
"Circuit breaker HalfOpen: {} consecutive successes (threshold: {})",
|
||||
successes,
|
||||
self.config.success_threshold
|
||||
config.success_threshold
|
||||
);
|
||||
|
||||
if successes >= self.config.success_threshold {
|
||||
if successes >= config.success_threshold {
|
||||
drop(config); // 释放读锁再转换状态
|
||||
log::info!("Circuit breaker transitioning from HalfOpen to Closed (success threshold reached)");
|
||||
self.transition_to_closed().await;
|
||||
}
|
||||
@@ -145,6 +183,7 @@ impl CircuitBreaker {
|
||||
/// 记录失败
|
||||
pub async fn record_failure(&self) {
|
||||
let state = *self.state.read().await;
|
||||
let config = self.config.read().await;
|
||||
|
||||
// 更新计数器
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
@@ -158,26 +197,38 @@ impl CircuitBreaker {
|
||||
"Circuit breaker {:?}: {} consecutive failures (threshold: {})",
|
||||
state,
|
||||
failures,
|
||||
self.config.failure_threshold
|
||||
config.failure_threshold
|
||||
);
|
||||
|
||||
// 检查是否应该打开熔断器
|
||||
match state {
|
||||
CircuitState::Closed | CircuitState::HalfOpen => {
|
||||
CircuitState::HalfOpen => {
|
||||
// 释放 in-flight 名额(探测请求结束)
|
||||
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
// HalfOpen 状态下失败,立即转为 Open
|
||||
log::warn!(
|
||||
"Circuit breaker HalfOpen probe failed, transitioning to Open"
|
||||
);
|
||||
drop(config);
|
||||
self.transition_to_open().await;
|
||||
}
|
||||
CircuitState::Closed => {
|
||||
// 检查连续失败次数
|
||||
if failures >= self.config.failure_threshold {
|
||||
if failures >= config.failure_threshold {
|
||||
log::warn!(
|
||||
"Circuit breaker opening due to {} consecutive failures (threshold: {})",
|
||||
failures,
|
||||
self.config.failure_threshold
|
||||
config.failure_threshold
|
||||
);
|
||||
drop(config); // 释放读锁再转换状态
|
||||
self.transition_to_open().await;
|
||||
} else {
|
||||
// 检查错误率
|
||||
let total = self.total_requests.load(Ordering::SeqCst);
|
||||
let failed = self.failed_requests.load(Ordering::SeqCst);
|
||||
|
||||
if total >= self.config.min_requests {
|
||||
if total >= config.min_requests {
|
||||
let error_rate = failed as f64 / total as f64;
|
||||
log::debug!(
|
||||
"Circuit breaker error rate: {:.2}% ({}/{} requests)",
|
||||
@@ -186,12 +237,13 @@ impl CircuitBreaker {
|
||||
total
|
||||
);
|
||||
|
||||
if error_rate >= self.config.error_rate_threshold {
|
||||
if error_rate >= config.error_rate_threshold {
|
||||
log::warn!(
|
||||
"Circuit breaker opening due to high error rate: {:.2}% (threshold: {:.2}%)",
|
||||
error_rate * 100.0,
|
||||
self.config.error_rate_threshold * 100.0
|
||||
config.error_rate_threshold * 100.0
|
||||
);
|
||||
drop(config); // 释放读锁再转换状态
|
||||
self.transition_to_open().await;
|
||||
}
|
||||
}
|
||||
@@ -237,6 +289,8 @@ impl CircuitBreaker {
|
||||
async fn transition_to_half_open(&self) {
|
||||
*self.state.write().await = CircuitState::HalfOpen;
|
||||
self.consecutive_successes.store(0, Ordering::SeqCst);
|
||||
// 重置半开状态的请求限流计数
|
||||
self.half_open_requests.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// 转换到关闭状态
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
|
||||
use super::{
|
||||
error::*,
|
||||
provider_router::ProviderRouter as NewProviderRouter,
|
||||
provider_router::ProviderRouter,
|
||||
providers::{get_adapter, ProviderAdapter},
|
||||
types::ProxyStatus,
|
||||
ProxyError,
|
||||
};
|
||||
use crate::{app_config::AppType, database::Database, provider::Provider};
|
||||
use crate::{app_config::AppType, provider::Provider};
|
||||
use reqwest::{Client, Response};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
@@ -18,8 +18,9 @@ use tokio::sync::RwLock;
|
||||
|
||||
pub struct RequestForwarder {
|
||||
client: Client,
|
||||
router: Arc<NewProviderRouter>,
|
||||
#[allow(dead_code)]
|
||||
/// 共享的 ProviderRouter(持有熔断器状态)
|
||||
router: Arc<ProviderRouter>,
|
||||
/// 单个 Provider 内的最大重试次数
|
||||
max_retries: u8,
|
||||
status: Arc<RwLock<ProxyStatus>>,
|
||||
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
|
||||
@@ -27,7 +28,7 @@ pub struct RequestForwarder {
|
||||
|
||||
impl RequestForwarder {
|
||||
pub fn new(
|
||||
db: Arc<Database>,
|
||||
router: Arc<ProviderRouter>,
|
||||
timeout_secs: u64,
|
||||
max_retries: u8,
|
||||
status: Arc<RwLock<ProxyStatus>>,
|
||||
@@ -44,32 +45,85 @@ impl RequestForwarder {
|
||||
|
||||
Self {
|
||||
client,
|
||||
router: Arc::new(NewProviderRouter::new(db)),
|
||||
router,
|
||||
max_retries,
|
||||
status,
|
||||
current_providers,
|
||||
}
|
||||
}
|
||||
|
||||
/// 对单个 Provider 执行请求(带重试)
|
||||
///
|
||||
/// 在同一个 Provider 上最多重试 max_retries 次,使用指数退避
|
||||
async fn forward_with_provider_retry(
|
||||
&self,
|
||||
provider: &Provider,
|
||||
endpoint: &str,
|
||||
body: &Value,
|
||||
headers: &axum::http::HeaderMap,
|
||||
adapter: &dyn ProviderAdapter,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
if attempt > 0 {
|
||||
// 指数退避:100ms, 200ms, 400ms, ...
|
||||
let delay_ms = 100 * 2u64.pow(attempt as u32 - 1);
|
||||
log::info!(
|
||||
"[{}] 重试第 {}/{} 次(等待 {}ms)",
|
||||
adapter.name(),
|
||||
attempt,
|
||||
self.max_retries,
|
||||
delay_ms
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
match self.forward(provider, endpoint, body, headers, adapter).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
let category = self.categorize_proxy_error(&e);
|
||||
|
||||
// 只有可重试的错误才继续重试
|
||||
if category == ErrorCategory::NonRetryable {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
log::debug!(
|
||||
"[{}] Provider {} 第 {} 次请求失败: {}",
|
||||
adapter.name(),
|
||||
provider.name,
|
||||
attempt + 1,
|
||||
e
|
||||
);
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded))
|
||||
}
|
||||
|
||||
/// 转发请求(带故障转移)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `app_type` - 应用类型
|
||||
/// * `endpoint` - API 端点
|
||||
/// * `body` - 请求体
|
||||
/// * `headers` - 请求头
|
||||
/// * `providers` - 已选择的 Provider 列表(由 RequestContext 提供,避免重复调用 select_providers)
|
||||
pub async fn forward_with_retry(
|
||||
&self,
|
||||
app_type: &AppType,
|
||||
endpoint: &str,
|
||||
body: Value,
|
||||
headers: axum::http::HeaderMap,
|
||||
providers: Vec<Provider>,
|
||||
) -> Result<Response, ProxyError> {
|
||||
// 获取适配器
|
||||
let adapter = get_adapter(app_type);
|
||||
let app_type_str = app_type.as_str();
|
||||
|
||||
// 使用新的 ProviderRouter 选择所有可用供应商
|
||||
let providers = self
|
||||
.router
|
||||
.select_providers(app_type_str)
|
||||
.await
|
||||
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
|
||||
|
||||
if providers.is_empty() {
|
||||
return Err(ProxyError::NoAvailableProvider);
|
||||
}
|
||||
@@ -108,9 +162,9 @@ impl RequestForwarder {
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// 转发请求
|
||||
// 转发请求(带单 Provider 内重试)
|
||||
match self
|
||||
.forward(provider, endpoint, &body, &headers, adapter.as_ref())
|
||||
.forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
@@ -373,17 +427,22 @@ impl RequestForwarder {
|
||||
}
|
||||
|
||||
/// 分类ProxyError
|
||||
///
|
||||
/// 决定哪些错误应该触发故障转移到下一个 Provider
|
||||
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
|
||||
match error {
|
||||
ProxyError::Timeout(_) => ErrorCategory::Retryable,
|
||||
ProxyError::ForwardFailed(_) => ErrorCategory::Retryable,
|
||||
ProxyError::UpstreamError { status, .. } => {
|
||||
if *status >= 500 {
|
||||
ErrorCategory::Retryable
|
||||
} else if *status >= 400 && *status < 500 {
|
||||
ErrorCategory::NonRetryable
|
||||
} else {
|
||||
ErrorCategory::Retryable
|
||||
match *status {
|
||||
// 速率限制 - 应该尝试其他 Provider
|
||||
429 => ErrorCategory::Retryable,
|
||||
// 请求超时
|
||||
408 => ErrorCategory::Retryable,
|
||||
// 服务器错误
|
||||
s if s >= 500 => ErrorCategory::Retryable,
|
||||
// 其他 4xx 错误(认证失败、参数错误等)不应重试
|
||||
_ => ErrorCategory::NonRetryable,
|
||||
}
|
||||
}
|
||||
ProxyError::ProviderUnhealthy(_) => ErrorCategory::Retryable,
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
use crate::app_config::AppType;
|
||||
use crate::provider::Provider;
|
||||
use crate::proxy::{
|
||||
forwarder::RequestForwarder, router::ProviderRouter, server::ProxyState, types::ProxyConfig,
|
||||
ProxyError,
|
||||
forwarder::RequestForwarder, server::ProxyState, types::ProxyConfig, ProxyError,
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -15,7 +14,7 @@ use std::time::Instant;
|
||||
/// 贯穿整个请求生命周期,包含:
|
||||
/// - 计时信息
|
||||
/// - 代理配置
|
||||
/// - 选中的 Provider
|
||||
/// - 选中的 Provider 列表(用于故障转移)
|
||||
/// - 请求模型名称
|
||||
/// - 日志标签
|
||||
pub struct RequestContext {
|
||||
@@ -23,8 +22,10 @@ pub struct RequestContext {
|
||||
pub start_time: Instant,
|
||||
/// 代理配置快照
|
||||
pub config: ProxyConfig,
|
||||
/// 选中的 Provider
|
||||
/// 选中的 Provider(故障转移链的第一个)
|
||||
pub provider: Provider,
|
||||
/// 完整的 Provider 列表(用于故障转移)
|
||||
providers: Vec<Provider>,
|
||||
/// 请求中的模型名称
|
||||
pub request_model: String,
|
||||
/// 日志标签(如 "Claude"、"Codex"、"Gemini")
|
||||
@@ -65,21 +66,32 @@ impl RequestContext {
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Provider 选择
|
||||
let router = ProviderRouter::new(state.db.clone());
|
||||
let provider = router.select_provider(&app_type, &[]).await?;
|
||||
// 使用共享的 ProviderRouter 选择 Provider(熔断器状态跨请求保持)
|
||||
// 注意:只在这里调用一次,结果传递给 forwarder,避免重复消耗 HalfOpen 名额
|
||||
let providers = state
|
||||
.provider_router
|
||||
.select_providers(app_type_str)
|
||||
.await
|
||||
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
|
||||
|
||||
let provider = providers
|
||||
.first()
|
||||
.cloned()
|
||||
.ok_or(ProxyError::NoAvailableProvider)?;
|
||||
|
||||
log::info!(
|
||||
"[{}] Provider: {}, model: {}",
|
||||
"[{}] Provider: {}, model: {}, failover chain: {} providers",
|
||||
tag,
|
||||
provider.name,
|
||||
request_model
|
||||
request_model,
|
||||
providers.len()
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
start_time,
|
||||
config,
|
||||
provider,
|
||||
providers,
|
||||
request_model,
|
||||
tag,
|
||||
app_type_str,
|
||||
@@ -110,9 +122,11 @@ impl RequestContext {
|
||||
}
|
||||
|
||||
/// 创建 RequestForwarder
|
||||
///
|
||||
/// 使用共享的 ProviderRouter,确保熔断器状态跨请求保持
|
||||
pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder {
|
||||
RequestForwarder::new(
|
||||
state.db.clone(),
|
||||
state.provider_router.clone(),
|
||||
self.config.request_timeout,
|
||||
self.config.max_retries,
|
||||
state.status.clone(),
|
||||
@@ -120,6 +134,13 @@ impl RequestContext {
|
||||
)
|
||||
}
|
||||
|
||||
/// 获取 Provider 列表(用于故障转移)
|
||||
///
|
||||
/// 返回在创建上下文时已选择的 providers,避免重复调用 select_providers()
|
||||
pub fn get_providers(&self) -> Vec<Provider> {
|
||||
self.providers.clone()
|
||||
}
|
||||
|
||||
/// 计算请求延迟(毫秒)
|
||||
#[inline]
|
||||
pub fn latency_ms(&self) -> u64 {
|
||||
|
||||
@@ -84,7 +84,13 @@ pub async fn handle_messages(
|
||||
// 转发请求
|
||||
let forwarder = ctx.create_forwarder(&state);
|
||||
let response = forwarder
|
||||
.forward_with_retry(&AppType::Claude, "/v1/messages", body.clone(), headers)
|
||||
.forward_with_retry(
|
||||
&AppType::Claude,
|
||||
"/v1/messages",
|
||||
body.clone(),
|
||||
headers,
|
||||
ctx.get_providers(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
@@ -299,7 +305,13 @@ pub async fn handle_chat_completions(
|
||||
|
||||
let forwarder = ctx.create_forwarder(&state);
|
||||
let response = forwarder
|
||||
.forward_with_retry(&AppType::Codex, "/v1/chat/completions", body, headers)
|
||||
.forward_with_retry(
|
||||
&AppType::Codex,
|
||||
"/v1/chat/completions",
|
||||
body,
|
||||
headers,
|
||||
ctx.get_providers(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
log::info!("[Codex] 上游响应状态: {}", response.status());
|
||||
@@ -317,7 +329,13 @@ pub async fn handle_responses(
|
||||
|
||||
let forwarder = ctx.create_forwarder(&state);
|
||||
let response = forwarder
|
||||
.forward_with_retry(&AppType::Codex, "/v1/responses", body, headers)
|
||||
.forward_with_retry(
|
||||
&AppType::Codex,
|
||||
"/v1/responses",
|
||||
body,
|
||||
headers,
|
||||
ctx.get_providers(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
log::info!("[Codex] 上游响应状态: {}", response.status());
|
||||
@@ -351,7 +369,13 @@ pub async fn handle_gemini(
|
||||
|
||||
let forwarder = ctx.create_forwarder(&state);
|
||||
let response = forwarder
|
||||
.forward_with_retry(&AppType::Gemini, endpoint, body, headers)
|
||||
.forward_with_retry(
|
||||
&AppType::Gemini,
|
||||
endpoint,
|
||||
body,
|
||||
headers,
|
||||
ctx.get_providers(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
log::info!("[Gemini] 上游响应状态: {}", response.status());
|
||||
|
||||
@@ -13,7 +13,6 @@ pub mod provider_router;
|
||||
pub mod providers;
|
||||
pub mod response_handler;
|
||||
pub mod response_processor;
|
||||
mod router;
|
||||
pub(crate) mod server;
|
||||
pub mod session;
|
||||
pub(crate) mod types;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::provider::Provider;
|
||||
use crate::proxy::circuit_breaker::CircuitBreaker;
|
||||
use crate::proxy::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
@@ -124,7 +124,11 @@ impl ProviderRouter {
|
||||
success: bool,
|
||||
error_msg: Option<String>,
|
||||
) -> Result<(), AppError> {
|
||||
// 1. 更新熔断器状态
|
||||
// 1. 获取熔断器配置(用于更新健康状态和判断是否禁用)
|
||||
let config = self.db.get_circuit_breaker_config().await.ok();
|
||||
let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5);
|
||||
|
||||
// 2. 更新熔断器状态
|
||||
let circuit_key = format!("{app_type}:{provider_id}");
|
||||
let breaker = self.get_or_create_circuit_breaker(&circuit_key).await;
|
||||
|
||||
@@ -140,19 +144,21 @@ impl ProviderRouter {
|
||||
);
|
||||
}
|
||||
|
||||
// 2. 更新数据库健康状态
|
||||
// 3. 更新数据库健康状态(使用配置的阈值)
|
||||
self.db
|
||||
.update_provider_health(provider_id, app_type, success, error_msg.clone())
|
||||
.update_provider_health_with_threshold(
|
||||
provider_id,
|
||||
app_type,
|
||||
success,
|
||||
error_msg.clone(),
|
||||
failure_threshold,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// 3. 如果连续失败达到熔断阈值,自动禁用代理目标
|
||||
// 4. 如果连续失败达到熔断阈值,自动禁用代理目标
|
||||
if !success {
|
||||
let health = self.db.get_provider_health(provider_id, app_type).await?;
|
||||
|
||||
// 获取熔断器配置
|
||||
let config = self.db.get_circuit_breaker_config().await.ok();
|
||||
let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5);
|
||||
|
||||
// 如果连续失败达到阈值,自动关闭该供应商的代理开关
|
||||
if health.consecutive_failures >= failure_threshold {
|
||||
log::warn!(
|
||||
@@ -171,7 +177,6 @@ impl ProviderRouter {
|
||||
}
|
||||
|
||||
/// 重置熔断器(手动恢复)
|
||||
#[allow(dead_code)]
|
||||
pub async fn reset_circuit_breaker(&self, circuit_key: &str) {
|
||||
let breakers = self.circuit_breakers.read().await;
|
||||
if let Some(breaker) = breakers.get(circuit_key) {
|
||||
@@ -180,6 +185,27 @@ impl ProviderRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// 重置指定供应商的熔断器
|
||||
pub async fn reset_provider_breaker(&self, provider_id: &str, app_type: &str) {
|
||||
let circuit_key = format!("{app_type}:{provider_id}");
|
||||
self.reset_circuit_breaker(&circuit_key).await;
|
||||
}
|
||||
|
||||
/// 更新所有熔断器的配置(热更新)
|
||||
///
|
||||
/// 当用户在 UI 中修改熔断器配置后调用此方法,
|
||||
/// 所有现有的熔断器会立即使用新配置
|
||||
pub async fn update_all_configs(&self, config: CircuitBreakerConfig) {
|
||||
let breakers = self.circuit_breakers.read().await;
|
||||
let count = breakers.len();
|
||||
|
||||
for breaker in breakers.values() {
|
||||
breaker.update_config(config.clone()).await;
|
||||
}
|
||||
|
||||
log::info!("已更新 {} 个熔断器的配置", count);
|
||||
}
|
||||
|
||||
/// 获取熔断器状态
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_circuit_breaker_stats(
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
//! Provider路由器
|
||||
//!
|
||||
//! 负责选择合适的Provider进行请求转发
|
||||
|
||||
use super::ProxyError;
|
||||
use crate::{app_config::AppType, database::Database, provider::Provider};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ProviderRouter {
|
||||
db: Arc<Database>,
|
||||
}
|
||||
|
||||
impl ProviderRouter {
|
||||
pub fn new(db: Arc<Database>) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// 选择Provider(只使用标记为代理目标的 Provider)
|
||||
pub async fn select_provider(
|
||||
&self,
|
||||
app_type: &AppType,
|
||||
_failed_ids: &[String],
|
||||
) -> Result<Provider, ProxyError> {
|
||||
// 1. 获取 Proxy Target Provider ID
|
||||
let proxy_target_id = self
|
||||
.db
|
||||
.get_proxy_target_provider(app_type.as_str())
|
||||
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
|
||||
|
||||
let target_id = proxy_target_id.ok_or_else(|| {
|
||||
log::warn!("[{}] 未设置代理目标 Provider", app_type.as_str());
|
||||
ProxyError::NoAvailableProvider
|
||||
})?;
|
||||
|
||||
// 2. 获取所有 Provider
|
||||
let providers = self
|
||||
.db
|
||||
.get_all_providers(app_type.as_str())
|
||||
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
|
||||
|
||||
// 3. 找到目标 Provider
|
||||
let target = providers.get(&target_id).ok_or_else(|| {
|
||||
log::warn!(
|
||||
"[{}] 代理目标 Provider 不存在: {}",
|
||||
app_type.as_str(),
|
||||
target_id
|
||||
);
|
||||
ProxyError::NoAvailableProvider
|
||||
})?;
|
||||
|
||||
log::info!(
|
||||
"[{}] 使用代理目标 Provider: {}",
|
||||
app_type.as_str(),
|
||||
target.name
|
||||
);
|
||||
Ok(target.clone())
|
||||
}
|
||||
|
||||
/// 更新Provider健康状态(保留接口但不影响选择)
|
||||
#[allow(dead_code)]
|
||||
pub async fn update_health(
|
||||
&self,
|
||||
_provider: &Provider,
|
||||
_app_type: &AppType,
|
||||
_success: bool,
|
||||
_error_msg: Option<String>,
|
||||
) {
|
||||
// 不再记录健康状态
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
//!
|
||||
//! 基于Axum的HTTP服务器,处理代理请求
|
||||
|
||||
use super::{handlers, types::*, ProxyError};
|
||||
use super::{handlers, provider_router::ProviderRouter, types::*, ProxyError};
|
||||
use crate::database::Database;
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
@@ -23,6 +23,8 @@ pub struct ProxyState {
|
||||
pub start_time: Arc<RwLock<Option<std::time::Instant>>>,
|
||||
/// 每个应用类型当前使用的 provider (app_type -> (provider_id, provider_name))
|
||||
pub current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
|
||||
/// 共享的 ProviderRouter(持有熔断器状态,跨请求保持)
|
||||
pub provider_router: Arc<ProviderRouter>,
|
||||
}
|
||||
|
||||
/// 代理HTTP服务器
|
||||
@@ -36,12 +38,16 @@ pub struct ProxyServer {
|
||||
|
||||
impl ProxyServer {
|
||||
pub fn new(config: ProxyConfig, db: Arc<Database>) -> Self {
|
||||
// 创建共享的 ProviderRouter(熔断器状态将跨所有请求保持)
|
||||
let provider_router = Arc::new(ProviderRouter::new(db.clone()));
|
||||
|
||||
let state = ProxyState {
|
||||
db,
|
||||
config: Arc::new(RwLock::new(config.clone())),
|
||||
status: Arc::new(RwLock::new(ProxyStatus::default())),
|
||||
start_time: Arc::new(RwLock::new(None)),
|
||||
current_providers: Arc::new(RwLock::new(std::collections::HashMap::new())),
|
||||
provider_router,
|
||||
};
|
||||
|
||||
Self {
|
||||
@@ -192,4 +198,22 @@ impl ProxyServer {
|
||||
pub async fn apply_runtime_config(&self, config: &ProxyConfig) {
|
||||
*self.state.config.write().await = config.clone();
|
||||
}
|
||||
|
||||
/// 热更新熔断器配置
|
||||
///
|
||||
/// 将新配置应用到所有已创建的熔断器实例
|
||||
pub async fn update_circuit_breaker_configs(
|
||||
&self,
|
||||
config: super::circuit_breaker::CircuitBreakerConfig,
|
||||
) {
|
||||
self.state.provider_router.update_all_configs(config).await;
|
||||
}
|
||||
|
||||
/// 重置指定 Provider 的熔断器
|
||||
pub async fn reset_provider_circuit_breaker(&self, provider_id: &str, app_type: &str) {
|
||||
self.state
|
||||
.provider_router
|
||||
.reset_provider_breaker(provider_id, app_type)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -681,4 +681,41 @@ impl ProxyService {
|
||||
pub async fn is_running(&self) -> bool {
|
||||
self.server.read().await.is_some()
|
||||
}
|
||||
|
||||
/// 热更新熔断器配置
|
||||
///
|
||||
/// 如果代理服务器正在运行,将新配置应用到所有已创建的熔断器实例
|
||||
pub async fn update_circuit_breaker_configs(
|
||||
&self,
|
||||
config: crate::proxy::CircuitBreakerConfig,
|
||||
) -> Result<(), String> {
|
||||
if let Some(server) = self.server.read().await.as_ref() {
|
||||
server.update_circuit_breaker_configs(config).await;
|
||||
log::info!("已热更新运行中的熔断器配置");
|
||||
} else {
|
||||
log::debug!("代理服务器未运行,熔断器配置将在下次启动时生效");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 重置指定 Provider 的熔断器
|
||||
///
|
||||
/// 如果代理服务器正在运行,立即重置内存中的熔断器状态
|
||||
pub async fn reset_provider_circuit_breaker(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
) -> Result<(), String> {
|
||||
if let Some(server) = self.server.read().await.as_ref() {
|
||||
server
|
||||
.reset_provider_circuit_breaker(provider_id, app_type)
|
||||
.await;
|
||||
log::info!(
|
||||
"已重置 Provider {} (app: {}) 的熔断器",
|
||||
provider_id,
|
||||
app_type
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user