fix(proxy): resolve circuit breaker state persistence and HalfOpen deadlock

This commit addresses several critical issues in the failover system:

**Circuit breaker state persistence (previous fix)**
- Promote ProviderRouter to ProxyState for cross-request state sharing
- Remove redundant router.rs module
- Fix 429 errors to be retryable (rate limiting should try other providers)

**Hot-update circuit breaker config**
- Add update_circuit_breaker_configs() to ProxyServer and ProxyService
- Connect update_circuit_breaker_config command to running circuit breakers
- Add reset_provider_circuit_breaker() for manual breaker reset

**Fix HalfOpen deadlock bug**
- Change half_open_requests from cumulative count to in-flight count
- Release quota in record_success()/record_failure() when in HalfOpen state
- Prevents permanent deadlock when success_threshold > 1

**Fix duplicate select_providers() call**
- Store providers list in RequestContext, pass to forward_with_retry()
- Avoid consuming HalfOpen quota twice per request
- Single call to select_providers() per request lifecycle

**Add per-provider retry with exponential backoff**
- Implement forward_with_provider_retry() with configurable max_retries
- Backoff delays: 100ms, 200ms, 400ms, etc.
This commit is contained in:
Jason
2025-12-13 22:47:49 +08:00
parent 5d424b1383
commit 5a5ca2a989
11 changed files with 347 additions and 141 deletions

View File

@@ -130,15 +130,17 @@ pub async fn reset_circuit_breaker(
provider_id: String,
app_type: String,
) -> Result<(), String> {
// 重置数据库健康状态
// 1. 重置数据库健康状态
let db = &state.db;
db.update_provider_health(&provider_id, &app_type, true, None)
.await
.map_err(|e| e.to_string())?;
// 注意:熔断器状态在内存中,重启代理服务器后会重置
// 如果代理服务器正在运行,需要通知它重置熔断器
// 目前先通过数据库重置健康状态,熔断器会在下次超时后自动尝试半开
// 2. 如果代理正在运行,重置内存中的熔断器状态
state
.proxy_service
.reset_provider_circuit_breaker(&provider_id, &app_type)
.await?;
Ok(())
}
@@ -161,9 +163,19 @@ pub async fn update_circuit_breaker_config(
config: CircuitBreakerConfig,
) -> Result<(), String> {
let db = &state.db;
// 1. 更新数据库配置
db.update_circuit_breaker_config(&config)
.await
.map_err(|e| e.to_string())
.map_err(|e| e.to_string())?;
// 2. 如果代理正在运行,热更新内存中的熔断器配置
state
.proxy_service
.update_circuit_breaker_configs(config)
.await?;
Ok(())
}
/// 获取熔断器统计信息(仅当代理服务器运行时)

View File

@@ -129,12 +129,31 @@ impl Database {
}
/// Update a provider's health record using the default failure threshold.
///
/// The default threshold of 5 is kept in sync with
/// `CircuitBreakerConfig::default()`; callers that know the configured
/// threshold should prefer `update_provider_health_with_threshold`.
pub async fn update_provider_health(
    &self,
    provider_id: &str,
    app_type: &str,
    success: bool,
    error_msg: Option<String>,
) -> Result<(), AppError> {
    // Default matches CircuitBreakerConfig::default() — keep them in sync.
    const DEFAULT_FAILURE_THRESHOLD: u32 = 5;
    self.update_provider_health_with_threshold(
        provider_id,
        app_type,
        success,
        error_msg,
        DEFAULT_FAILURE_THRESHOLD,
    )
    .await
}
/// 更新Provider健康状态带阈值参数
///
/// # Arguments
/// * `failure_threshold` - 连续失败多少次后标记为不健康
pub async fn update_provider_health_with_threshold(
&self,
provider_id: &str,
app_type: &str,
success: bool,
error_msg: Option<String>,
failure_threshold: u32,
) -> Result<(), AppError> {
let conn = lock_conn!(self.conn);
@@ -142,7 +161,7 @@ impl Database {
// 先查询当前状态
let current = conn.query_row(
"SELECT consecutive_failures FROM provider_health
"SELECT consecutive_failures FROM provider_health
WHERE provider_id = ?1 AND app_type = ?2",
rusqlite::params![provider_id, app_type],
|row| Ok(row.get::<_, i64>(0)? as u32),
@@ -154,7 +173,8 @@ impl Database {
} else {
// 失败:增加失败计数
let failures = current.unwrap_or(0) + 1;
let healthy = if failures >= 3 { 0 } else { 1 };
// 使用传入的阈值而非硬编码
let healthy = if failures >= failure_threshold { 0 } else { 1 };
(healthy, failures)
};
@@ -169,10 +189,10 @@ impl Database {
"INSERT OR REPLACE INTO provider_health
(provider_id, app_type, is_healthy, consecutive_failures,
last_success_at, last_failure_at, last_error, updated_at)
VALUES (?1, ?2, ?3, ?4,
COALESCE(?5, (SELECT last_success_at FROM provider_health
VALUES (?1, ?2, ?3, ?4,
COALESCE(?5, (SELECT last_success_at FROM provider_health
WHERE provider_id = ?1 AND app_type = ?2)),
COALESCE(?6, (SELECT last_failure_at FROM provider_health
COALESCE(?6, (SELECT last_failure_at FROM provider_health
WHERE provider_id = ?1 AND app_type = ?2)),
?7, ?8)",
rusqlite::params![

View File

@@ -72,8 +72,10 @@ pub struct CircuitBreaker {
failed_requests: Arc<AtomicU32>,
/// 上次打开时间
last_opened_at: Arc<RwLock<Option<Instant>>>,
/// 配置
config: CircuitBreakerConfig,
/// 配置(支持热更新)
config: Arc<RwLock<CircuitBreakerConfig>>,
/// 半开状态已放行的请求数(用于限流)
half_open_requests: Arc<AtomicU32>,
}
impl CircuitBreaker {
@@ -86,20 +88,29 @@ impl CircuitBreaker {
total_requests: Arc::new(AtomicU32::new(0)),
failed_requests: Arc::new(AtomicU32::new(0)),
last_opened_at: Arc::new(RwLock::new(None)),
config,
config: Arc::new(RwLock::new(config)),
half_open_requests: Arc::new(AtomicU32::new(0)),
}
}
/// Hot-swap the breaker's configuration without resetting any runtime state
/// (counters, current `CircuitState`, timestamps are all preserved).
pub async fn update_config(&self, new_config: CircuitBreakerConfig) {
    let mut cfg = self.config.write().await;
    *cfg = new_config;
    drop(cfg); // release the write lock before logging
    log::debug!("Circuit breaker config updated");
}
/// 检查是否允许请求通过
pub async fn allow_request(&self) -> bool {
let state = *self.state.read().await;
let config = self.config.read().await;
match state {
CircuitState::Closed => true,
CircuitState::Open => {
// 检查是否应该尝试半开
if let Some(opened_at) = *self.last_opened_at.read().await {
if opened_at.elapsed().as_secs() >= self.config.timeout_seconds {
if opened_at.elapsed().as_secs() >= config.timeout_seconds {
drop(config); // 释放读锁再转换状态
log::info!(
"Circuit breaker transitioning from Open to HalfOpen (timeout reached)"
);
@@ -109,13 +120,36 @@ impl CircuitBreaker {
}
false
}
CircuitState::HalfOpen => true,
CircuitState::HalfOpen => {
// 半开状态限流:只允许有限请求通过进行探测
// 默认最多允许 1 个请求(可在配置中扩展)
let max_half_open_requests = 1u32;
let current = self.half_open_requests.fetch_add(1, Ordering::SeqCst);
if current < max_half_open_requests {
log::debug!(
"Circuit breaker HalfOpen: allowing probe request ({}/{})",
current + 1,
max_half_open_requests
);
true
} else {
// 超过限额,回退计数,拒绝请求
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
log::debug!(
"Circuit breaker HalfOpen: rejecting request (limit reached: {})",
max_half_open_requests
);
false
}
}
}
}
/// 记录成功
pub async fn record_success(&self) {
let state = *self.state.read().await;
let config = self.config.read().await;
// 重置失败计数
self.consecutive_failures.store(0, Ordering::SeqCst);
@@ -123,14 +157,18 @@ impl CircuitBreaker {
match state {
CircuitState::HalfOpen => {
// 释放 in-flight 名额(探测请求结束)
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1;
log::debug!(
"Circuit breaker HalfOpen: {} consecutive successes (threshold: {})",
successes,
self.config.success_threshold
config.success_threshold
);
if successes >= self.config.success_threshold {
if successes >= config.success_threshold {
drop(config); // 释放读锁再转换状态
log::info!("Circuit breaker transitioning from HalfOpen to Closed (success threshold reached)");
self.transition_to_closed().await;
}
@@ -145,6 +183,7 @@ impl CircuitBreaker {
/// 记录失败
pub async fn record_failure(&self) {
let state = *self.state.read().await;
let config = self.config.read().await;
// 更新计数器
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
@@ -158,26 +197,38 @@ impl CircuitBreaker {
"Circuit breaker {:?}: {} consecutive failures (threshold: {})",
state,
failures,
self.config.failure_threshold
config.failure_threshold
);
// 检查是否应该打开熔断器
match state {
CircuitState::Closed | CircuitState::HalfOpen => {
CircuitState::HalfOpen => {
// 释放 in-flight 名额(探测请求结束)
self.half_open_requests.fetch_sub(1, Ordering::SeqCst);
// HalfOpen 状态下失败,立即转为 Open
log::warn!(
"Circuit breaker HalfOpen probe failed, transitioning to Open"
);
drop(config);
self.transition_to_open().await;
}
CircuitState::Closed => {
// 检查连续失败次数
if failures >= self.config.failure_threshold {
if failures >= config.failure_threshold {
log::warn!(
"Circuit breaker opening due to {} consecutive failures (threshold: {})",
failures,
self.config.failure_threshold
config.failure_threshold
);
drop(config); // 释放读锁再转换状态
self.transition_to_open().await;
} else {
// 检查错误率
let total = self.total_requests.load(Ordering::SeqCst);
let failed = self.failed_requests.load(Ordering::SeqCst);
if total >= self.config.min_requests {
if total >= config.min_requests {
let error_rate = failed as f64 / total as f64;
log::debug!(
"Circuit breaker error rate: {:.2}% ({}/{} requests)",
@@ -186,12 +237,13 @@ impl CircuitBreaker {
total
);
if error_rate >= self.config.error_rate_threshold {
if error_rate >= config.error_rate_threshold {
log::warn!(
"Circuit breaker opening due to high error rate: {:.2}% (threshold: {:.2}%)",
error_rate * 100.0,
self.config.error_rate_threshold * 100.0
config.error_rate_threshold * 100.0
);
drop(config); // 释放读锁再转换状态
self.transition_to_open().await;
}
}
@@ -237,6 +289,8 @@ impl CircuitBreaker {
/// Move the breaker into `HalfOpen` and reset the probe bookkeeping.
async fn transition_to_half_open(&self) {
    {
        let mut state = self.state.write().await;
        *state = CircuitState::HalfOpen;
    }
    // Fresh probe window: no successes recorded yet, no in-flight probes.
    self.consecutive_successes.store(0, Ordering::SeqCst);
    self.half_open_requests.store(0, Ordering::SeqCst);
}
/// 转换到关闭状态

View File

@@ -4,12 +4,12 @@
use super::{
error::*,
provider_router::ProviderRouter as NewProviderRouter,
provider_router::ProviderRouter,
providers::{get_adapter, ProviderAdapter},
types::ProxyStatus,
ProxyError,
};
use crate::{app_config::AppType, database::Database, provider::Provider};
use crate::{app_config::AppType, provider::Provider};
use reqwest::{Client, Response};
use serde_json::Value;
use std::sync::Arc;
@@ -18,8 +18,9 @@ use tokio::sync::RwLock;
pub struct RequestForwarder {
client: Client,
router: Arc<NewProviderRouter>,
#[allow(dead_code)]
/// 共享的 ProviderRouter(持有熔断器状态)
router: Arc<ProviderRouter>,
/// 单个 Provider 内的最大重试次数
max_retries: u8,
status: Arc<RwLock<ProxyStatus>>,
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
@@ -27,7 +28,7 @@ pub struct RequestForwarder {
impl RequestForwarder {
pub fn new(
db: Arc<Database>,
router: Arc<ProviderRouter>,
timeout_secs: u64,
max_retries: u8,
status: Arc<RwLock<ProxyStatus>>,
@@ -44,32 +45,85 @@ impl RequestForwarder {
Self {
client,
router: Arc::new(NewProviderRouter::new(db)),
router,
max_retries,
status,
current_providers,
}
}
/// Execute a request against a single provider, retrying on failure.
///
/// Retries the same provider up to `max_retries` additional times with
/// exponential backoff (100ms, 200ms, 400ms, ...). Errors classified as
/// `NonRetryable` (e.g. auth failures, bad requests) abort immediately and
/// are returned to the caller so failover logic can decide what to do.
async fn forward_with_provider_retry(
    &self,
    provider: &Provider,
    endpoint: &str,
    body: &Value,
    headers: &axum::http::HeaderMap,
    adapter: &dyn ProviderAdapter,
) -> Result<Response, ProxyError> {
    let mut last_error = None;

    // attempt 0 is the initial request; attempts 1..=max_retries are retries.
    for attempt in 0..=self.max_retries {
        if attempt > 0 {
            // Exponential backoff: 100ms, 200ms, 400ms, ...
            // Saturating arithmetic avoids an overflow panic for very large
            // configured retry counts (max_retries is a u8, so `attempt`
            // can reach 255 and 2^254 would overflow u64).
            let delay_ms =
                100u64.saturating_mul(2u64.saturating_pow(u32::from(attempt) - 1));
            log::info!(
                "[{}] 重试第 {}/{} 次(等待 {}ms",
                adapter.name(),
                attempt,
                self.max_retries,
                delay_ms
            );
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        }

        match self.forward(provider, endpoint, body, headers, adapter).await {
            Ok(response) => return Ok(response),
            Err(e) => {
                // Only retryable errors are worth another attempt on the
                // same provider; everything else is returned at once.
                if self.categorize_proxy_error(&e) == ErrorCategory::NonRetryable {
                    return Err(e);
                }
                log::debug!(
                    "[{}] Provider {} 第 {} 次请求失败: {}",
                    adapter.name(),
                    provider.name,
                    attempt + 1,
                    e
                );
                last_error = Some(e);
            }
        }
    }

    // All attempts failed with retryable errors; surface the last one.
    Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded))
}
/// 转发请求(带故障转移)
///
/// # Arguments
/// * `app_type` - 应用类型
/// * `endpoint` - API 端点
/// * `body` - 请求体
/// * `headers` - 请求头
/// * `providers` - 已选择的 Provider 列表(由 RequestContext 提供,避免重复调用 select_providers
pub async fn forward_with_retry(
&self,
app_type: &AppType,
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
providers: Vec<Provider>,
) -> Result<Response, ProxyError> {
// 获取适配器
let adapter = get_adapter(app_type);
let app_type_str = app_type.as_str();
// 使用新的 ProviderRouter 选择所有可用供应商
let providers = self
.router
.select_providers(app_type_str)
.await
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
if providers.is_empty() {
return Err(ProxyError::NoAvailableProvider);
}
@@ -108,9 +162,9 @@ impl RequestForwarder {
let start = Instant::now();
// 转发请求
// 转发请求(带单 Provider 内重试)
match self
.forward(provider, endpoint, &body, &headers, adapter.as_ref())
.forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref())
.await
{
Ok(response) => {
@@ -373,17 +427,22 @@ impl RequestForwarder {
}
/// 分类ProxyError
///
/// 决定哪些错误应该触发故障转移到下一个 Provider
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
match error {
ProxyError::Timeout(_) => ErrorCategory::Retryable,
ProxyError::ForwardFailed(_) => ErrorCategory::Retryable,
ProxyError::UpstreamError { status, .. } => {
if *status >= 500 {
ErrorCategory::Retryable
} else if *status >= 400 && *status < 500 {
ErrorCategory::NonRetryable
} else {
ErrorCategory::Retryable
match *status {
// 速率限制 - 应该尝试其他 Provider
429 => ErrorCategory::Retryable,
// 请求超时
408 => ErrorCategory::Retryable,
// 服务器错误
s if s >= 500 => ErrorCategory::Retryable,
// 其他 4xx 错误(认证失败、参数错误等)不应重试
_ => ErrorCategory::NonRetryable,
}
}
ProxyError::ProviderUnhealthy(_) => ErrorCategory::Retryable,

View File

@@ -5,8 +5,7 @@
use crate::app_config::AppType;
use crate::provider::Provider;
use crate::proxy::{
forwarder::RequestForwarder, router::ProviderRouter, server::ProxyState, types::ProxyConfig,
ProxyError,
forwarder::RequestForwarder, server::ProxyState, types::ProxyConfig, ProxyError,
};
use std::time::Instant;
@@ -15,7 +14,7 @@ use std::time::Instant;
/// 贯穿整个请求生命周期,包含:
/// - 计时信息
/// - 代理配置
/// - 选中的 Provider
/// - 选中的 Provider 列表(用于故障转移)
/// - 请求模型名称
/// - 日志标签
pub struct RequestContext {
@@ -23,8 +22,10 @@ pub struct RequestContext {
pub start_time: Instant,
/// 代理配置快照
pub config: ProxyConfig,
/// 选中的 Provider
/// 选中的 Provider(故障转移链的第一个)
pub provider: Provider,
/// 完整的 Provider 列表(用于故障转移)
providers: Vec<Provider>,
/// 请求中的模型名称
pub request_model: String,
/// 日志标签(如 "Claude"、"Codex"、"Gemini"
@@ -65,21 +66,32 @@ impl RequestContext {
.unwrap_or("unknown")
.to_string();
// Provider 选择
let router = ProviderRouter::new(state.db.clone());
let provider = router.select_provider(&app_type, &[]).await?;
// 使用共享的 ProviderRouter 选择 Provider熔断器状态跨请求保持
// 注意:只在这里调用一次,结果传递给 forwarder避免重复消耗 HalfOpen 名额
let providers = state
.provider_router
.select_providers(app_type_str)
.await
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
let provider = providers
.first()
.cloned()
.ok_or(ProxyError::NoAvailableProvider)?;
log::info!(
"[{}] Provider: {}, model: {}",
"[{}] Provider: {}, model: {}, failover chain: {} providers",
tag,
provider.name,
request_model
request_model,
providers.len()
);
Ok(Self {
start_time,
config,
provider,
providers,
request_model,
tag,
app_type_str,
@@ -110,9 +122,11 @@ impl RequestContext {
}
/// 创建 RequestForwarder
///
/// 使用共享的 ProviderRouter确保熔断器状态跨请求保持
pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder {
RequestForwarder::new(
state.db.clone(),
state.provider_router.clone(),
self.config.request_timeout,
self.config.max_retries,
state.status.clone(),
@@ -120,6 +134,13 @@ impl RequestContext {
)
}
/// Providers selected when this context was created (the failover chain).
///
/// Returning the stored list avoids a second call to `select_providers()`,
/// which would otherwise consume HalfOpen probe quota twice per request.
pub fn get_providers(&self) -> Vec<Provider> {
    self.providers.to_vec()
}
/// 计算请求延迟(毫秒)
#[inline]
pub fn latency_ms(&self) -> u64 {

View File

@@ -84,7 +84,13 @@ pub async fn handle_messages(
// 转发请求
let forwarder = ctx.create_forwarder(&state);
let response = forwarder
.forward_with_retry(&AppType::Claude, "/v1/messages", body.clone(), headers)
.forward_with_retry(
&AppType::Claude,
"/v1/messages",
body.clone(),
headers,
ctx.get_providers(),
)
.await?;
let status = response.status();
@@ -299,7 +305,13 @@ pub async fn handle_chat_completions(
let forwarder = ctx.create_forwarder(&state);
let response = forwarder
.forward_with_retry(&AppType::Codex, "/v1/chat/completions", body, headers)
.forward_with_retry(
&AppType::Codex,
"/v1/chat/completions",
body,
headers,
ctx.get_providers(),
)
.await?;
log::info!("[Codex] 上游响应状态: {}", response.status());
@@ -317,7 +329,13 @@ pub async fn handle_responses(
let forwarder = ctx.create_forwarder(&state);
let response = forwarder
.forward_with_retry(&AppType::Codex, "/v1/responses", body, headers)
.forward_with_retry(
&AppType::Codex,
"/v1/responses",
body,
headers,
ctx.get_providers(),
)
.await?;
log::info!("[Codex] 上游响应状态: {}", response.status());
@@ -351,7 +369,13 @@ pub async fn handle_gemini(
let forwarder = ctx.create_forwarder(&state);
let response = forwarder
.forward_with_retry(&AppType::Gemini, endpoint, body, headers)
.forward_with_retry(
&AppType::Gemini,
endpoint,
body,
headers,
ctx.get_providers(),
)
.await?;
log::info!("[Gemini] 上游响应状态: {}", response.status());

View File

@@ -13,7 +13,6 @@ pub mod provider_router;
pub mod providers;
pub mod response_handler;
pub mod response_processor;
mod router;
pub(crate) mod server;
pub mod session;
pub(crate) mod types;

View File

@@ -5,7 +5,7 @@
use crate::database::Database;
use crate::error::AppError;
use crate::provider::Provider;
use crate::proxy::circuit_breaker::CircuitBreaker;
use crate::proxy::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
@@ -124,7 +124,11 @@ impl ProviderRouter {
success: bool,
error_msg: Option<String>,
) -> Result<(), AppError> {
// 1. 更新熔断器状态
// 1. 获取熔断器配置(用于更新健康状态和判断是否禁用)
let config = self.db.get_circuit_breaker_config().await.ok();
let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5);
// 2. 更新熔断器状态
let circuit_key = format!("{app_type}:{provider_id}");
let breaker = self.get_or_create_circuit_breaker(&circuit_key).await;
@@ -140,19 +144,21 @@ impl ProviderRouter {
);
}
// 2. 更新数据库健康状态
// 3. 更新数据库健康状态(使用配置的阈值)
self.db
.update_provider_health(provider_id, app_type, success, error_msg.clone())
.update_provider_health_with_threshold(
provider_id,
app_type,
success,
error_msg.clone(),
failure_threshold,
)
.await?;
// 3. 如果连续失败达到熔断阈值,自动禁用代理目标
// 4. 如果连续失败达到熔断阈值,自动禁用代理目标
if !success {
let health = self.db.get_provider_health(provider_id, app_type).await?;
// 获取熔断器配置
let config = self.db.get_circuit_breaker_config().await.ok();
let failure_threshold = config.map(|c| c.failure_threshold).unwrap_or(5);
// 如果连续失败达到阈值,自动关闭该供应商的代理开关
if health.consecutive_failures >= failure_threshold {
log::warn!(
@@ -171,7 +177,6 @@ impl ProviderRouter {
}
/// 重置熔断器(手动恢复)
#[allow(dead_code)]
pub async fn reset_circuit_breaker(&self, circuit_key: &str) {
let breakers = self.circuit_breakers.read().await;
if let Some(breaker) = breakers.get(circuit_key) {
@@ -180,6 +185,27 @@ impl ProviderRouter {
}
}
/// Reset the circuit breaker for a single provider, addressed by
/// provider id and app type.
pub async fn reset_provider_breaker(&self, provider_id: &str, app_type: &str) {
    // Breakers are keyed as "<app_type>:<provider_id>".
    let key = format!("{}:{}", app_type, provider_id);
    self.reset_circuit_breaker(&key).await;
}
/// Hot-update the configuration of every existing circuit breaker.
///
/// Called after the user edits the breaker settings in the UI; all
/// breakers created so far pick up the new configuration immediately.
pub async fn update_all_configs(&self, config: CircuitBreakerConfig) {
    let guard = self.circuit_breakers.read().await;
    for breaker in guard.values() {
        breaker.update_config(config.clone()).await;
    }
    log::info!("已更新 {} 个熔断器的配置", guard.len());
}
/// 获取熔断器状态
#[allow(dead_code)]
pub async fn get_circuit_breaker_stats(

View File

@@ -1,70 +0,0 @@
//! Provider路由器
//!
//! 负责选择合适的Provider进行请求转发
use super::ProxyError;
use crate::{app_config::AppType, database::Database, provider::Provider};
use std::sync::Arc;
/// Selects the provider a proxied request should be forwarded to.
pub struct ProviderRouter {
    // Database handle used to look up the configured proxy-target provider.
    db: Arc<Database>,
}
impl ProviderRouter {
    /// Create a router backed by the given database handle.
    pub fn new(db: Arc<Database>) -> Self {
        Self { db }
    }

    /// Pick the provider explicitly marked as the proxy target.
    ///
    /// `_failed_ids` is accepted for interface compatibility but ignored:
    /// only the configured proxy-target provider is ever returned.
    pub async fn select_provider(
        &self,
        app_type: &AppType,
        _failed_ids: &[String],
    ) -> Result<Provider, ProxyError> {
        let app = app_type.as_str();

        // 1. Resolve the configured proxy-target provider id.
        let target_id = self
            .db
            .get_proxy_target_provider(app)
            .map_err(|e| ProxyError::DatabaseError(e.to_string()))?
            .ok_or_else(|| {
                log::warn!("[{}] 未设置代理目标 Provider", app);
                ProxyError::NoAvailableProvider
            })?;

        // 2. Fetch all providers for this app type.
        let providers = self
            .db
            .get_all_providers(app)
            .map_err(|e| ProxyError::DatabaseError(e.to_string()))?;

        // 3. Look the target up by id.
        match providers.get(&target_id) {
            Some(target) => {
                log::info!("[{}] 使用代理目标 Provider: {}", app, target.name);
                Ok(target.clone())
            }
            None => {
                log::warn!("[{}] 代理目标 Provider 不存在: {}", app, target_id);
                Err(ProxyError::NoAvailableProvider)
            }
        }
    }

    /// Health-update hook kept for interface compatibility; intentionally a no-op.
    #[allow(dead_code)]
    pub async fn update_health(
        &self,
        _provider: &Provider,
        _app_type: &AppType,
        _success: bool,
        _error_msg: Option<String>,
    ) {
        // Health is no longer tracked by this router.
    }
}

View File

@@ -2,7 +2,7 @@
//!
//! 基于Axum的HTTP服务器处理代理请求
use super::{handlers, types::*, ProxyError};
use super::{handlers, provider_router::ProviderRouter, types::*, ProxyError};
use crate::database::Database;
use axum::{
routing::{get, post},
@@ -23,6 +23,8 @@ pub struct ProxyState {
pub start_time: Arc<RwLock<Option<std::time::Instant>>>,
/// 每个应用类型当前使用的 provider (app_type -> (provider_id, provider_name))
pub current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
/// 共享的 ProviderRouter持有熔断器状态跨请求保持
pub provider_router: Arc<ProviderRouter>,
}
/// 代理HTTP服务器
@@ -36,12 +38,16 @@ pub struct ProxyServer {
impl ProxyServer {
pub fn new(config: ProxyConfig, db: Arc<Database>) -> Self {
// 创建共享的 ProviderRouter熔断器状态将跨所有请求保持
let provider_router = Arc::new(ProviderRouter::new(db.clone()));
let state = ProxyState {
db,
config: Arc::new(RwLock::new(config.clone())),
status: Arc::new(RwLock::new(ProxyStatus::default())),
start_time: Arc::new(RwLock::new(None)),
current_providers: Arc::new(RwLock::new(std::collections::HashMap::new())),
provider_router,
};
Self {
@@ -192,4 +198,22 @@ impl ProxyServer {
pub async fn apply_runtime_config(&self, config: &ProxyConfig) {
*self.state.config.write().await = config.clone();
}
/// Hot-reload the circuit breaker configuration.
///
/// Pushes the new configuration to every breaker instance that the shared
/// router has created so far.
pub async fn update_circuit_breaker_configs(
    &self,
    config: super::circuit_breaker::CircuitBreakerConfig,
) {
    let router = &self.state.provider_router;
    router.update_all_configs(config).await;
}
/// Reset the in-memory circuit breaker for one provider/app-type pair.
pub async fn reset_provider_circuit_breaker(&self, provider_id: &str, app_type: &str) {
    let router = &self.state.provider_router;
    router.reset_provider_breaker(provider_id, app_type).await;
}
}

View File

@@ -681,4 +681,41 @@ impl ProxyService {
/// Whether a proxy server instance is currently held (i.e. running).
pub async fn is_running(&self) -> bool {
    let guard = self.server.read().await;
    guard.is_some()
}
/// Hot-update the circuit breaker configuration.
///
/// If the proxy server is running, the new configuration is applied to all
/// existing breaker instances immediately; otherwise it takes effect the
/// next time the server starts (it is already persisted by the caller).
pub async fn update_circuit_breaker_configs(
    &self,
    config: crate::proxy::CircuitBreakerConfig,
) -> Result<(), String> {
    match self.server.read().await.as_ref() {
        Some(server) => {
            server.update_circuit_breaker_configs(config).await;
            log::info!("已热更新运行中的熔断器配置");
        }
        None => {
            log::debug!("代理服务器未运行,熔断器配置将在下次启动时生效");
        }
    }
    Ok(())
}
/// Reset one provider's circuit breaker.
///
/// When the proxy server is running, the in-memory breaker state is reset
/// immediately; when it is not running there is nothing to do (breaker
/// state only lives in the running server).
pub async fn reset_provider_circuit_breaker(
    &self,
    provider_id: &str,
    app_type: &str,
) -> Result<(), String> {
    let guard = self.server.read().await;
    if let Some(server) = guard.as_ref() {
        server
            .reset_provider_circuit_breaker(provider_id, app_type)
            .await;
        log::info!(
            "已重置 Provider {} (app: {}) 的熔断器",
            provider_id,
            app_type
        );
    }
    Ok(())
}
}