From 1b73b26c0ed5362bbbd898412b7fcca830c31e0c Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 17 Dec 2025 08:49:47 +0800 Subject: [PATCH] fix(proxy): resolve circuit breaker race condition and error classification This commit addresses two critical issues in the proxy failover logic: 1. Circuit Breaker HalfOpen Concurrency Bug: - Introduced `AllowResult` struct to track half-open permit usage - Added state guard in `transition_to_half_open()` to prevent duplicate resets - Replaced `fetch_sub` with CAS loop in `release_half_open_permit()` to prevent underflow - Separated `is_available()` (routing) from `allow_request()` (permit acquisition) 2. Error Classification Conflation: - Split retry logic into `should_retry_same_provider()` and `categorize_proxy_error()` - Same-provider retry: only for transient errors (timeout, forward failure, 408, 429, 5xx) - Cross-provider failover: now includes ConfigError, TransformError, AuthError, StreamIdleTimeout, MaxRetriesExceeded - 4xx errors (401/403) no longer waste retries on the same provider --- src-tauri/src/proxy/circuit_breaker.rs | 180 ++++++++++++++++++------- src-tauri/src/proxy/forwarder.rs | 51 +++++-- src-tauri/src/proxy/provider_router.rs | 15 ++- 3 files changed, 183 insertions(+), 63 deletions(-) diff --git a/src-tauri/src/proxy/circuit_breaker.rs b/src-tauri/src/proxy/circuit_breaker.rs index b448e38c..52d9eaa1 100644 --- a/src-tauri/src/proxy/circuit_breaker.rs +++ b/src-tauri/src/proxy/circuit_breaker.rs @@ -78,6 +78,16 @@ pub struct CircuitBreaker { half_open_requests: Arc<AtomicU32>, } +/// 熔断器放行结果 +/// +/// `used_half_open_permit` 表示本次放行是否占用了 HalfOpen 探测名额。 +/// 调用方应在请求结束后把该值传回 `record_success` / `record_failure` 用于正确释放名额。 +#[derive(Debug, Clone, Copy)] +pub struct AllowResult { + pub allowed: bool, + pub used_half_open_permit: bool, +} + impl CircuitBreaker { /// 创建新的熔断器 pub fn new(config: CircuitBreakerConfig) -> Self { @@ -130,13 +140,16 @@ impl CircuitBreaker { } /// 检查是否允许请求通过 - pub async fn allow_request(&self) -> bool { + pub async fn allow_request(&self) -> AllowResult 
{ let state = *self.state.read().await; - let config = self.config.read().await; match state { - CircuitState::Closed => true, + CircuitState::Closed => AllowResult { + allowed: true, + used_half_open_permit: false, + }, CircuitState::Open => { + let config = self.config.read().await; // 检查是否应该尝试半开 if let Some(opened_at) = *self.last_opened_at.read().await { if opened_at.elapsed().as_secs() >= config.timeout_seconds { @@ -145,52 +158,47 @@ impl CircuitBreaker { "Circuit breaker transitioning from Open to HalfOpen (timeout reached)" ); self.transition_to_half_open().await; - // 增加计数,确保 record_success/record_failure 减计数时不会下溢 - self.half_open_requests.fetch_add(1, Ordering::SeqCst); - return true; + + // 转换后按当前状态决定是否需要获取 HalfOpen 探测名额 + let current_state = *self.state.read().await; + return match current_state { + CircuitState::Closed => AllowResult { + allowed: true, + used_half_open_permit: false, + }, + CircuitState::HalfOpen => self.allow_half_open_probe(), + CircuitState::Open => AllowResult { + allowed: false, + used_half_open_permit: false, + }, + }; } } - false - } - CircuitState::HalfOpen => { - // 半开状态限流:只允许有限请求通过进行探测 - // 默认最多允许 1 个请求(可在配置中扩展) - let max_half_open_requests = 1u32; - let current = self.half_open_requests.fetch_add(1, Ordering::SeqCst); - if current < max_half_open_requests { - log::debug!( - "Circuit breaker HalfOpen: allowing probe request ({}/{})", - current + 1, - max_half_open_requests - ); - true - } else { - // 超过限额,回退计数,拒绝请求 - self.half_open_requests.fetch_sub(1, Ordering::SeqCst); - log::debug!( - "Circuit breaker HalfOpen: rejecting request (limit reached: {max_half_open_requests})" - ); - false + AllowResult { + allowed: false, + used_half_open_permit: false, } } + CircuitState::HalfOpen => self.allow_half_open_probe(), } } /// 记录成功 - pub async fn record_success(&self) { + pub async fn record_success(&self, used_half_open_permit: bool) { let state = *self.state.read().await; let config = self.config.read().await; + if 
used_half_open_permit { + self.release_half_open_permit(); + } + // 重置失败计数 self.consecutive_failures.store(0, Ordering::SeqCst); self.total_requests.fetch_add(1, Ordering::SeqCst); match state { CircuitState::HalfOpen => { - // 释放 in-flight 名额(探测请求结束) - self.half_open_requests.fetch_sub(1, Ordering::SeqCst); - let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1; log::debug!( "Circuit breaker HalfOpen: {} consecutive successes (threshold: {})", @@ -212,10 +220,14 @@ impl CircuitBreaker { } /// 记录失败 - pub async fn record_failure(&self) { + pub async fn record_failure(&self, used_half_open_permit: bool) { let state = *self.state.read().await; let config = self.config.read().await; + if used_half_open_permit { + self.release_half_open_permit(); + } + // 更新计数器 let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1; self.total_requests.fetch_add(1, Ordering::SeqCst); @@ -234,9 +246,6 @@ impl CircuitBreaker { // 检查是否应该打开熔断器 match state { CircuitState::HalfOpen => { - // 释放 in-flight 名额(探测请求结束) - self.half_open_requests.fetch_sub(1, Ordering::SeqCst); - // HalfOpen 状态下失败,立即转为 Open log::warn!("Circuit breaker HalfOpen probe failed, transitioning to Open"); drop(config); @@ -307,6 +316,56 @@ impl CircuitBreaker { self.transition_to_closed().await; } + fn allow_half_open_probe(&self) -> AllowResult { + // 半开状态限流:只允许有限请求通过进行探测 + // 默认最多允许 1 个请求(可在配置中扩展) + let max_half_open_requests = 1u32; + let current = self.half_open_requests.fetch_add(1, Ordering::SeqCst); + + if current < max_half_open_requests { + log::debug!( + "Circuit breaker HalfOpen: allowing probe request ({}/{})", + current + 1, + max_half_open_requests + ); + AllowResult { + allowed: true, + used_half_open_permit: true, + } + } else { + // 超过限额,回退计数,拒绝请求 + self.half_open_requests.fetch_sub(1, Ordering::SeqCst); + log::debug!( + "Circuit breaker HalfOpen: rejecting request (limit reached: {max_half_open_requests})" + ); + AllowResult { + allowed: false, + 
used_half_open_permit: false, + } + } + } + + fn release_half_open_permit(&self) { + let mut current = self.half_open_requests.load(Ordering::SeqCst); + loop { + if current == 0 { + // 理论上不应该发生:说明调用方传入的 used_half_open_permit 与实际占用不一致 + log::debug!("Circuit breaker HalfOpen permit already released (counter=0)"); + return; + } + + match self.half_open_requests.compare_exchange( + current, + current - 1, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => return, + Err(actual) => current = actual, + } + } + } + /// 转换到打开状态 async fn transition_to_open(&self) { *self.state.write().await = CircuitState::Open; @@ -317,7 +376,12 @@ impl CircuitBreaker { /// 转换到半开状态 async fn transition_to_half_open(&self) { - *self.state.write().await = CircuitState::HalfOpen; + let mut state = self.state.write().await; + if *state != CircuitState::Open { + return; + } + + *state = CircuitState::HalfOpen; self.consecutive_successes.store(0, Ordering::SeqCst); // 重置半开状态的请求限流计数 self.half_open_requests.store(0, Ordering::SeqCst); @@ -359,16 +423,16 @@ mod tests { // 初始状态应该是关闭 assert_eq!(breaker.get_state().await, CircuitState::Closed); - assert!(breaker.allow_request().await); + assert!(breaker.allow_request().await.allowed); // 记录 3 次失败 for _ in 0..3 { - breaker.record_failure().await; + breaker.record_failure(false).await; } // 应该转换到打开状态 assert_eq!(breaker.get_state().await, CircuitState::Open); - assert!(!breaker.allow_request().await); + assert!(!breaker.allow_request().await.allowed); } #[tokio::test] @@ -381,8 +445,8 @@ mod tests { let breaker = CircuitBreaker::new(config); // 打开熔断器 - breaker.record_failure().await; - breaker.record_failure().await; + breaker.record_failure(false).await; + breaker.record_failure(false).await; assert_eq!(breaker.get_state().await, CircuitState::Open); // 手动转换到半开状态 @@ -390,13 +454,37 @@ mod tests { assert_eq!(breaker.get_state().await, CircuitState::HalfOpen); // 记录 2 次成功 - breaker.record_success().await; - breaker.record_success().await; + 
breaker.record_success(false).await; + breaker.record_success(false).await; // 应该转换到关闭状态 assert_eq!(breaker.get_state().await, CircuitState::Closed); } + #[tokio::test] + async fn test_half_open_transition_does_not_reset_inflight_permit() { + let config = CircuitBreakerConfig { + timeout_seconds: 0, + ..Default::default() + }; + let breaker = CircuitBreaker::new(config); + + // 进入 Open,然后由于 timeout_seconds=0,allow_request 会立即切换到 HalfOpen 并占用探测名额 + breaker.transition_to_open().await; + let first = breaker.allow_request().await; + assert!(first.allowed); + assert!(first.used_half_open_permit); + assert_eq!(breaker.get_state().await, CircuitState::HalfOpen); + + // 模拟并发下的“重复 HalfOpen 转换调用”,不应重置 in-flight 计数 + breaker.transition_to_half_open().await; + + // 由于名额仍被占用,第二次请求应被拒绝 + let second = breaker.allow_request().await; + assert!(!second.allowed); + assert!(!second.used_half_open_permit); + } + #[tokio::test] async fn test_circuit_breaker_reset() { let config = CircuitBreakerConfig { @@ -406,13 +494,13 @@ mod tests { let breaker = CircuitBreaker::new(config); // 打开熔断器 - breaker.record_failure().await; - breaker.record_failure().await; + breaker.record_failure(false).await; + breaker.record_failure(false).await; assert_eq!(breaker.get_state().await, CircuitState::Open); // 重置 breaker.reset().await; assert_eq!(breaker.get_state().await, CircuitState::Closed); - assert!(breaker.allow_request().await); + assert!(breaker.allow_request().await.allowed); } } diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index 18a88e3c..fa47ada0 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -94,10 +94,8 @@ impl RequestForwarder { { Ok(response) => return Ok(response), Err(e) => { - let category = self.categorize_proxy_error(&e); - - // 只有可重试的错误才继续重试 - if category == ErrorCategory::NonRetryable { + // 只有“同一 Provider 内可重试”的错误才继续重试 + if !self.should_retry_same_provider(&e) { return Err(e); } @@ -153,11 +151,11 @@ impl 
RequestForwarder { // 依次尝试每个供应商 for provider in providers.iter() { // 发起请求前先获取熔断器放行许可(HalfOpen 会占用探测名额) - if !self + let permit = self .router .allow_provider_request(&provider.id, app_type_str) - .await - { + .await; + if !permit.allowed { log::debug!( "[{}] Provider {} 熔断器拒绝本次请求,跳过", app_type_str, @@ -166,6 +164,8 @@ impl RequestForwarder { continue; } + let used_half_open_permit = permit.used_half_open_permit; + attempted_providers += 1; if attempted_providers > 1 { failover_happened = true; @@ -202,7 +202,13 @@ impl RequestForwarder { // 成功:记录成功并更新熔断器 if let Err(e) = self .router - .record_result(&provider.id, app_type_str, true, None) + .record_result( + &provider.id, + app_type_str, + used_half_open_permit, + true, + None, + ) .await { log::warn!("Failed to record success: {e}"); @@ -268,7 +274,13 @@ impl RequestForwarder { // 失败:记录失败并更新熔断器 if let Err(record_err) = self .router - .record_result(&provider.id, app_type_str, false, Some(e.to_string())) + .record_result( + &provider.id, + app_type_str, + used_half_open_permit, + false, + Some(e.to_string()), + ) .await { log::warn!("Failed to record failure: {record_err}"); @@ -489,6 +501,19 @@ impl RequestForwarder { /// /// 设计原则:既然用户配置了多个供应商,就应该让所有供应商都尝试一遍。 /// 只有明确是客户端中断的情况才不重试。 + fn should_retry_same_provider(&self, error: &ProxyError) -> bool { + match error { + // 网络类错误:短暂抖动时同一 Provider 内重试有意义 + ProxyError::Timeout(_) => true, + ProxyError::ForwardFailed(_) => true, + // 上游 HTTP 错误:只对“可能瞬态”的状态码做同 Provider 重试(其余交给 failover) + ProxyError::UpstreamError { status, .. } => { + *status == 408 || *status == 429 || *status >= 500 + } + _ => false, + } + } + fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory { match error { // 网络和上游错误:都应该尝试下一个供应商 @@ -499,9 +524,15 @@ impl RequestForwarder { // 原因:不同供应商有不同的限制和认证,一个供应商的 4xx 错误 // 不代表其他供应商也会失败 ProxyError::UpstreamError { .. 
} => ErrorCategory::Retryable, + // Provider 级配置/转换问题:换一个 Provider 可能就能成功 + ProxyError::ConfigError(_) => ErrorCategory::Retryable, + ProxyError::TransformError(_) => ErrorCategory::Retryable, + ProxyError::AuthError(_) => ErrorCategory::Retryable, + ProxyError::StreamIdleTimeout(_) => ErrorCategory::Retryable, + ProxyError::MaxRetriesExceeded => ErrorCategory::Retryable, // 无可用供应商:所有供应商都试过了,无法重试 ProxyError::NoAvailableProvider => ErrorCategory::NonRetryable, - // 其他错误(配置错误、数据库错误等):不是供应商问题,无需重试 + // 其他错误(数据库/内部错误等):不是换供应商能解决的问题 _ => ErrorCategory::NonRetryable, } } diff --git a/src-tauri/src/proxy/provider_router.rs b/src-tauri/src/proxy/provider_router.rs index 01d80bec..d6f30226 100644 --- a/src-tauri/src/proxy/provider_router.rs +++ b/src-tauri/src/proxy/provider_router.rs @@ -5,7 +5,7 @@ use crate::database::Database; use crate::error::AppError; use crate::provider::Provider; -use crate::proxy::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; +use crate::proxy::circuit_breaker::{AllowResult, CircuitBreaker, CircuitBreakerConfig}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -123,7 +123,7 @@ impl ProviderRouter { /// /// 注意:调用方必须在请求结束后通过 `record_result()` 释放 HalfOpen 名额, /// 否则会导致该 Provider 长时间无法进入探测状态。 - pub async fn allow_provider_request(&self, provider_id: &str, app_type: &str) -> bool { + pub async fn allow_provider_request(&self, provider_id: &str, app_type: &str) -> AllowResult { let circuit_key = format!("{app_type}:{provider_id}"); let breaker = self.get_or_create_circuit_breaker(&circuit_key).await; breaker.allow_request().await @@ -134,6 +134,7 @@ impl ProviderRouter { &self, provider_id: &str, app_type: &str, + used_half_open_permit: bool, success: bool, error_msg: Option<String>, ) -> Result<(), AppError> { @@ -146,10 +147,10 @@ impl ProviderRouter { let breaker = self.get_or_create_circuit_breaker(&circuit_key).await; if success { - breaker.record_success().await; + 
breaker.record_success(used_half_open_permit).await; log::debug!("Provider {provider_id} request succeeded"); } else { - breaker.record_failure().await; + breaker.record_failure(used_half_open_permit).await; log::warn!( "Provider {} request failed: {}", provider_id, @@ -265,7 +266,7 @@ mod tests { // 测试创建熔断器 let breaker = router.get_or_create_circuit_breaker("claude:test").await; - assert!(breaker.allow_request().await); + assert!(breaker.allow_request().await.allowed); } #[tokio::test] @@ -296,7 +297,7 @@ mod tests { // 让 B 进入 Open 状态(failure_threshold=1) router - .record_result("b", "claude", false, Some("fail".to_string())) + .record_result("b", "claude", false, false, Some("fail".to_string())) .await .unwrap(); @@ -305,6 +306,6 @@ mod tests { assert_eq!(providers.len(), 2); // 如果 select_providers 错误地消耗了 HalfOpen 名额,这里会返回 false(被限流拒绝) - assert!(router.allow_provider_request("b", "claude").await); + assert!(router.allow_provider_request("b", "claude").await.allowed); } }