Files
cc-switch/src-tauri/src/proxy/forwarder.rs
Jason Young a268127f1f Fix/Resolve panic issues in proxy-related code (#560)
* fix(proxy): change default port from 5000 to 15721

Port 5000 conflicts with AirPlay Receiver on macOS 12+.
Also adds error handling for proxy toggle and i18n placeholder updates.

* fix(proxy): replace unwrap/expect with graceful error handling

- Handle HTTP client initialization failure with no_proxy fallback
- Fix potential panic on Unicode slicing in API key preview
- Add proper error handling for response body builder
- Handle edge case where SystemTime is before UNIX_EPOCH

* fix(proxy): handle UTF-8 char boundary when truncating request body log

Rust strings are UTF-8 encoded, slicing at a fixed byte index may cut
in the middle of a multi-byte character (e.g., Chinese, emoji), causing
a panic. Use is_char_boundary() to find the nearest safe cut point.

* fix(proxy): improve robustness and prevent panics

- Add reqwest socks feature to support SOCKS proxy environments
- Fix UTF-8 safety in masked_key/masked_access_token (use chars() instead of byte slicing)
- Fix UTF-8 boundary check in usage_script HTTP response truncation
- Add defensive checks for JSON operations in proxy service
- Remove verbose debug logs that could trigger panic-prone code paths
2026-01-09 13:09:19 +08:00

534 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 请求转发器
//!
//! 负责将请求转发到上游Provider支持故障转移
use super::{
body_filter::filter_private_params_with_whitelist,
error::*,
failover_switch::FailoverSwitchManager,
provider_router::ProviderRouter,
providers::{get_adapter, ProviderAdapter},
types::ProxyStatus,
ProxyError,
};
use crate::{app_config::AppType, provider::Provider};
use reqwest::{Client, Response};
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Headers 黑名单 - 不透传到上游的 Headers
///
/// 精简版黑名单,只过滤必须覆盖或可能导致问题的 header
/// 参考成功透传的请求,保留更多原始 header
///
/// 注意:客户端 IP 类x-forwarded-for, x-real-ip默认透传
const HEADER_BLACKLIST: &[&str] = &[
// 认证类(会被覆盖)
"authorization",
"x-api-key",
// 连接类(由 HTTP 客户端管理)
"host",
"content-length",
"transfer-encoding",
// 编码类(会被覆盖为 identity
"accept-encoding",
// 代理转发类(保留 x-forwarded-for 和 x-real-ip
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"forwarded",
// CDN/云服务商特定头
"cf-connecting-ip",
"cf-ipcountry",
"cf-ray",
"cf-visitor",
"true-client-ip",
"fastly-client-ip",
"x-azure-clientip",
"x-azure-fdid",
"x-azure-ref",
"akamai-origin-hop",
"x-akamai-config-log-detail",
// 请求追踪类
"x-request-id",
"x-correlation-id",
"x-trace-id",
"x-amzn-trace-id",
"x-b3-traceid",
"x-b3-spanid",
"x-b3-parentspanid",
"x-b3-sampled",
"traceparent",
"tracestate",
// anthropic 特定头单独处理,避免重复
"anthropic-beta",
"anthropic-version",
// 客户端 IP 单独处理(默认透传)
"x-forwarded-for",
"x-real-ip",
];
pub struct ForwardResult {
pub response: Response,
pub provider: Provider,
}
pub struct ForwardError {
pub error: ProxyError,
pub provider: Option<Provider>,
}
pub struct RequestForwarder {
client: Option<Client>,
client_init_error: Option<String>,
/// 共享的 ProviderRouter持有熔断器状态
router: Arc<ProviderRouter>,
status: Arc<RwLock<ProxyStatus>>,
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
/// 故障转移切换管理器
failover_manager: Arc<FailoverSwitchManager>,
/// AppHandle用于发射事件和更新托盘
app_handle: Option<tauri::AppHandle>,
/// 请求开始时的"当前供应商 ID"(用于判断是否需要同步 UI/托盘)
current_provider_id_at_start: String,
}
impl RequestForwarder {
#[allow(clippy::too_many_arguments)]
pub fn new(
router: Arc<ProviderRouter>,
non_streaming_timeout: u64,
status: Arc<RwLock<ProxyStatus>>,
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
failover_manager: Arc<FailoverSwitchManager>,
app_handle: Option<tauri::AppHandle>,
current_provider_id_at_start: String,
_streaming_first_byte_timeout: u64,
_streaming_idle_timeout: u64,
) -> Self {
// 全局超时设置为 1800 秒30 分钟),确保业务层超时配置能正常工作
// 参考 Claude Code Hub 的 undici 全局超时设计
const GLOBAL_TIMEOUT_SECS: u64 = 1800;
let timeout_secs = if non_streaming_timeout > 0 {
non_streaming_timeout
} else {
GLOBAL_TIMEOUT_SECS
};
// 注意:这里不能用 expect/unwrap。
// release 配置为 panic=abort一旦 build 失败会导致整个应用闪退。
// 常见原因:用户环境变量里存在不合法/不支持的代理HTTP(S)_PROXY/ALL_PROXY 等)。
let (client, client_init_error) = match Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
{
Ok(client) => (Some(client), None),
Err(e) => {
// 降级:忽略系统/环境代理,避免因代理配置问题导致整个应用崩溃
match Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.no_proxy()
.build()
{
Ok(client) => (Some(client), Some(e.to_string())),
Err(fallback_err) => (
None,
Some(format!(
"Failed to create HTTP client: {e}; no_proxy fallback failed: {fallback_err}"
)),
),
}
}
};
Self {
client,
client_init_error,
router,
status,
current_providers,
failover_manager,
app_handle,
current_provider_id_at_start,
}
}
/// 转发请求(带故障转移)
///
/// # Arguments
/// * `app_type` - 应用类型
/// * `endpoint` - API 端点
/// * `body` - 请求体
/// * `headers` - 请求头
/// * `providers` - 已选择的 Provider 列表(由 RequestContext 提供,避免重复调用 select_providers
pub async fn forward_with_retry(
&self,
app_type: &AppType,
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
providers: Vec<Provider>,
) -> Result<ForwardResult, ForwardError> {
// 获取适配器
let adapter = get_adapter(app_type);
let app_type_str = app_type.as_str();
if providers.is_empty() {
return Err(ForwardError {
error: ProxyError::NoAvailableProvider,
provider: None,
});
}
let mut last_error = None;
let mut last_provider = None;
let mut attempted_providers = 0usize;
// 单 Provider 场景下跳过熔断器检查(故障转移关闭时)
let bypass_circuit_breaker = providers.len() == 1;
// 依次尝试每个供应商
for provider in providers.iter() {
// 发起请求前先获取熔断器放行许可HalfOpen 会占用探测名额)
// 单 Provider 场景下跳过此检查,避免熔断器阻塞所有请求
let (allowed, used_half_open_permit) = if bypass_circuit_breaker {
(true, false)
} else {
let permit = self
.router
.allow_provider_request(&provider.id, app_type_str)
.await;
(permit.allowed, permit.used_half_open_permit)
};
if !allowed {
continue;
}
attempted_providers += 1;
// 更新状态中的当前Provider信息
{
let mut status = self.status.write().await;
status.current_provider = Some(provider.name.clone());
status.current_provider_id = Some(provider.id.clone());
status.total_requests += 1;
status.last_request_at = Some(chrono::Utc::now().to_rfc3339());
}
// 转发请求(每个 Provider 只尝试一次,重试由客户端控制)
match self
.forward(provider, endpoint, &body, &headers, adapter.as_ref())
.await
{
Ok(response) => {
// 成功:记录成功并更新熔断器
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
true,
None,
)
.await;
// 更新当前应用类型使用的 provider
{
let mut current_providers = self.current_providers.write().await;
current_providers.insert(
app_type_str.to_string(),
(provider.id.clone(), provider.name.clone()),
);
}
// 更新成功统计
{
let mut status = self.status.write().await;
status.success_requests += 1;
status.last_error = None;
let should_switch =
self.current_provider_id_at_start.as_str() != provider.id.as_str();
if should_switch {
status.failover_count += 1;
// 异步触发供应商切换,更新 UI/托盘,并把“当前供应商”同步为实际使用的 provider
let fm = self.failover_manager.clone();
let ah = self.app_handle.clone();
let pid = provider.id.clone();
let pname = provider.name.clone();
let at = app_type_str.to_string();
tokio::spawn(async move {
let _ = fm.try_switch(ah.as_ref(), &at, &pid, &pname).await;
});
}
// 重新计算成功率
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Ok(ForwardResult {
response,
provider: provider.clone(),
});
}
Err(e) => {
// 失败:记录失败并更新熔断器
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
false,
Some(e.to_string()),
)
.await;
// 分类错误
let category = self.categorize_proxy_error(&e);
match category {
ErrorCategory::Retryable => {
// 可重试:更新错误信息,继续尝试下一个供应商
{
let mut status = self.status.write().await;
status.last_error =
Some(format!("Provider {} 失败: {}", provider.name, e));
}
last_error = Some(e);
last_provider = Some(provider.clone());
// 继续尝试下一个供应商
continue;
}
ErrorCategory::NonRetryable | ErrorCategory::ClientAbort => {
// 不可重试:直接返回错误
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
}
}
}
}
if attempted_providers == 0 {
// providers 列表非空但全部被熔断器拒绝典型HalfOpen 探测名额被占用)
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some("所有供应商暂时不可用(熔断器限制)".to_string());
if status.total_requests > 0 {
status.success_rate =
(status.success_requests as f32 / status.total_requests as f32) * 100.0;
}
}
return Err(ForwardError {
error: ProxyError::NoAvailableProvider,
provider: None,
});
}
// 所有供应商都失败了
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some("所有供应商都失败".to_string());
if status.total_requests > 0 {
status.success_rate =
(status.success_requests as f32 / status.total_requests as f32) * 100.0;
}
}
Err(ForwardError {
error: last_error.unwrap_or(ProxyError::MaxRetriesExceeded),
provider: last_provider,
})
}
/// 转发单个请求(使用适配器)
async fn forward(
&self,
provider: &Provider,
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
adapter: &dyn ProviderAdapter,
) -> Result<Response, ProxyError> {
// 使用适配器提取 base_url
let base_url = adapter.extract_base_url(provider)?;
// 检查是否需要格式转换
let needs_transform = adapter.needs_transform(provider);
let effective_endpoint =
if needs_transform && adapter.name() == "Claude" && endpoint == "/v1/messages" {
"/v1/chat/completions"
} else {
endpoint
};
// 使用适配器构建 URL
let url = adapter.build_url(&base_url, effective_endpoint);
// 应用模型映射(独立于格式转换)
let (mapped_body, _original_model, _mapped_model) =
super::model_mapper::apply_model_mapping(body.clone(), provider);
// 转换请求体(如果需要)
let request_body = if needs_transform {
adapter.transform_request(mapped_body, provider)?
} else {
mapped_body
};
// 过滤私有参数(以 `_` 开头的字段),防止内部信息泄露到上游
// 默认使用空白名单,过滤所有 _ 前缀字段
let filtered_body = filter_private_params_with_whitelist(request_body, &[]);
// 构建请求
let client = self.client.as_ref().ok_or_else(|| {
ProxyError::ForwardFailed(
self.client_init_error
.clone()
.unwrap_or_else(|| "HTTP client is not initialized".to_string()),
)
})?;
let mut request = client.post(&url);
// 过滤黑名单 Headers保护隐私并避免冲突
for (key, value) in headers {
if HEADER_BLACKLIST
.iter()
.any(|h| key.as_str().eq_ignore_ascii_case(h))
{
continue;
}
request = request.header(key, value);
}
// 处理 anthropic-beta Header仅 Claude
// 关键:确保包含 claude-code-20250219 标记,这是上游服务验证请求来源的依据
// 如果客户端发送的 beta 标记中没有包含 claude-code-20250219需要补充
if adapter.name() == "Claude" {
const CLAUDE_CODE_BETA: &str = "claude-code-20250219";
let beta_value = if let Some(beta) = headers.get("anthropic-beta") {
if let Ok(beta_str) = beta.to_str() {
// 检查是否已包含 claude-code-20250219
if beta_str.contains(CLAUDE_CODE_BETA) {
beta_str.to_string()
} else {
// 补充 claude-code-20250219
format!("{CLAUDE_CODE_BETA},{beta_str}")
}
} else {
CLAUDE_CODE_BETA.to_string()
}
} else {
// 如果客户端没有发送,使用默认值
CLAUDE_CODE_BETA.to_string()
};
request = request.header("anthropic-beta", &beta_value);
}
// 客户端 IP 透传(默认开启)
if let Some(xff) = headers.get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
request = request.header("x-forwarded-for", xff_str);
}
}
if let Some(real_ip) = headers.get("x-real-ip") {
if let Ok(real_ip_str) = real_ip.to_str() {
request = request.header("x-real-ip", real_ip_str);
}
}
// 禁用压缩,避免 gzip 流式响应解析错误
// 参考 CCH: undici 在连接提前关闭时会对不完整的 gzip 流抛出错误
request = request.header("accept-encoding", "identity");
// 使用适配器添加认证头
if let Some(auth) = adapter.extract_auth(provider) {
request = adapter.add_auth_headers(request, &auth);
}
// anthropic-version 统一处理(仅 Claude优先使用客户端的版本号否则使用默认值
// 注意:只设置一次,避免重复
if adapter.name() == "Claude" {
let version_str = headers
.get("anthropic-version")
.and_then(|v| v.to_str().ok())
.unwrap_or("2023-06-01");
request = request.header("anthropic-version", version_str);
}
// 发送请求
let response = request.json(&filtered_body).send().await.map_err(|e| {
if e.is_timeout() {
ProxyError::Timeout(format!("请求超时: {e}"))
} else if e.is_connect() {
ProxyError::ForwardFailed(format!("连接失败: {e}"))
} else {
ProxyError::ForwardFailed(e.to_string())
}
})?;
// 检查响应状态
let status = response.status();
if status.is_success() {
Ok(response)
} else {
let status_code = status.as_u16();
let body_text = response.text().await.ok();
Err(ProxyError::UpstreamError {
status: status_code,
body: body_text,
})
}
}
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
match error {
// 网络和上游错误:都应该尝试下一个供应商
ProxyError::Timeout(_) => ErrorCategory::Retryable,
ProxyError::ForwardFailed(_) => ErrorCategory::Retryable,
ProxyError::ProviderUnhealthy(_) => ErrorCategory::Retryable,
// 上游 HTTP 错误:无论状态码如何,都尝试下一个供应商
// 原因:不同供应商有不同的限制和认证,一个供应商的 4xx 错误
// 不代表其他供应商也会失败
ProxyError::UpstreamError { .. } => ErrorCategory::Retryable,
// Provider 级配置/转换问题:换一个 Provider 可能就能成功
ProxyError::ConfigError(_) => ErrorCategory::Retryable,
ProxyError::TransformError(_) => ErrorCategory::Retryable,
ProxyError::AuthError(_) => ErrorCategory::Retryable,
ProxyError::StreamIdleTimeout(_) => ErrorCategory::Retryable,
// 无可用供应商:所有供应商都试过了,无法重试
ProxyError::NoAvailableProvider => ErrorCategory::NonRetryable,
// 其他错误(数据库/内部错误等):不是换供应商能解决的问题
_ => ErrorCategory::NonRetryable,
}
}
}