Files
cc-switch/src-tauri/src/proxy/forwarder.rs
2026-03-22 22:30:07 +08:00

1443 lines
60 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Request forwarder.
//!
//! Forwards requests to upstream providers, with failover support.
use super::{
body_filter::filter_private_params_with_whitelist,
error::*,
failover_switch::FailoverSwitchManager,
log_codes::fwd as log_fwd,
provider_router::ProviderRouter,
providers::{get_adapter, AuthInfo, AuthStrategy, ProviderAdapter, ProviderType},
thinking_budget_rectifier::{rectify_thinking_budget, should_rectify_thinking_budget},
thinking_rectifier::{
normalize_thinking_type, rectify_anthropic_request, should_rectify_thinking_signature,
},
types::{OptimizerConfig, ProxyStatus, RectifierConfig},
ProxyError,
};
use crate::commands::CopilotAuthState;
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
use crate::{app_config::AppType, provider::Provider};
use reqwest::Response;
use serde_json::Value;
use std::sync::Arc;
use tauri::Manager;
use tokio::sync::RwLock;
/// Header blacklist: request headers that are NOT copied verbatim to the upstream.
///
/// This is a deliberately small list: only headers that must be overridden or that
/// may cause problems are filtered, so that as many of the original client headers
/// as possible are preserved.
///
/// Note: the client-IP headers (`x-forwarded-for`, `x-real-ip`) are listed here only
/// because the forwarder re-adds them separately; they are passed through by default.
const HEADER_BLACKLIST: &[&str] = &[
// Authentication (overridden by the proxy)
"authorization",
"x-api-key",
"x-goog-api-key",
// Connection-level headers (managed by the HTTP client)
"host",
"content-length",
"transfer-encoding",
// Encoding (the forwarder sets `identity` itself for streaming requests)
"accept-encoding",
// Proxy forwarding headers (x-forwarded-for and x-real-ip are kept, see below)
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"forwarded",
// CDN / cloud-vendor specific headers
"cf-connecting-ip",
"cf-ipcountry",
"cf-ray",
"cf-visitor",
"true-client-ip",
"fastly-client-ip",
"x-azure-clientip",
"x-azure-fdid",
"x-azure-ref",
"akamai-origin-hop",
"x-akamai-config-log-detail",
// Request tracing headers
"x-request-id",
"x-correlation-id",
"x-trace-id",
"x-amzn-trace-id",
"x-b3-traceid",
"x-b3-spanid",
"x-b3-parentspanid",
"x-b3-sampled",
"traceparent",
"tracestate",
// Anthropic-specific headers are handled separately to avoid duplicates
"anthropic-beta",
"anthropic-version",
// Client IP headers are handled separately (passed through by default)
"x-forwarded-for",
"x-real-ip",
];
/// Successful outcome of a forwarded request.
pub struct ForwardResult {
/// Upstream HTTP response.
pub response: Response,
/// Provider that produced the response.
pub provider: Provider,
}
/// Failed outcome of a forwarded request.
pub struct ForwardError {
/// Error describing why forwarding failed.
pub error: ProxyError,
/// Provider associated with the failure; `None` when no provider was attempted.
pub provider: Option<Provider>,
}
/// Forwards proxy requests to upstream providers, failing over between them.
pub struct RequestForwarder {
/// Shared `ProviderRouter` (holds the circuit-breaker state).
router: Arc<ProviderRouter>,
/// Shared proxy status (current provider, request counters, last error).
status: Arc<RwLock<ProxyStatus>>,
/// Provider last used successfully per app type: app-type string -> (provider id, provider name).
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
/// Failover switch manager.
failover_manager: Arc<FailoverSwitchManager>,
/// `AppHandle`, used to emit events and update the tray.
app_handle: Option<tauri::AppHandle>,
/// "Current provider ID" at the time the request started (used to decide whether the UI/tray must be synced).
current_provider_id_at_start: String,
/// Rectifier configuration.
rectifier_config: RectifierConfig,
/// Optimizer configuration.
optimizer_config: OptimizerConfig,
/// Non-streaming request timeout (constructed from a value in seconds); zero means no per-request timeout is set.
non_streaming_timeout: std::time::Duration,
}
impl RequestForwarder {
/// Builds a forwarder from the shared proxy state and the per-request configuration.
///
/// `non_streaming_timeout` is given in seconds and stored as a `Duration`.
/// The two streaming timeout parameters are accepted but not stored by this type.
// The constructor wires together many independent collaborators, hence the long argument list.
#[allow(clippy::too_many_arguments)]
pub fn new(
    router: Arc<ProviderRouter>,
    non_streaming_timeout: u64,
    status: Arc<RwLock<ProxyStatus>>,
    current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
    failover_manager: Arc<FailoverSwitchManager>,
    app_handle: Option<tauri::AppHandle>,
    current_provider_id_at_start: String,
    _streaming_first_byte_timeout: u64,
    _streaming_idle_timeout: u64,
    rectifier_config: RectifierConfig,
    optimizer_config: OptimizerConfig,
) -> Self {
    // Convert the raw seconds once, then use field-init shorthand for everything.
    let non_streaming_timeout = std::time::Duration::from_secs(non_streaming_timeout);
    Self {
        non_streaming_timeout,
        optimizer_config,
        rectifier_config,
        current_provider_id_at_start,
        app_handle,
        failover_manager,
        current_providers,
        status,
        router,
    }
}
/// 转发请求(带故障转移)
///
/// # Arguments
/// * `app_type` - 应用类型
/// * `endpoint` - API 端点
/// * `body` - 请求体
/// * `headers` - 请求头
/// * `providers` - 已选择的 Provider 列表(由 RequestContext 提供,避免重复调用 select_providers
pub async fn forward_with_retry(
&self,
app_type: &AppType,
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
providers: Vec<Provider>,
) -> Result<ForwardResult, ForwardError> {
// 获取适配器
let adapter = get_adapter(app_type);
let app_type_str = app_type.as_str();
if providers.is_empty() {
return Err(ForwardError {
error: ProxyError::NoAvailableProvider,
provider: None,
});
}
let mut last_error = None;
let mut last_provider = None;
let mut attempted_providers = 0usize;
// 整流器重试标记:确保整流最多触发一次
let mut rectifier_retried = false;
let mut budget_rectifier_retried = false;
// 单 Provider 场景下跳过熔断器检查(故障转移关闭时)
let bypass_circuit_breaker = providers.len() == 1;
// 依次尝试每个供应商
for provider in providers.iter() {
// 发起请求前先获取熔断器放行许可HalfOpen 会占用探测名额)
// 单 Provider 场景下跳过此检查,避免熔断器阻塞所有请求
let (allowed, used_half_open_permit) = if bypass_circuit_breaker {
(true, false)
} else {
let permit = self
.router
.allow_provider_request(&provider.id, app_type_str)
.await;
(permit.allowed, permit.used_half_open_permit)
};
if !allowed {
continue;
}
// PRE-SEND 优化器:每个 provider 独立决定是否优化
// clone body 以避免 Bedrock 优化字段泄漏到非 Bedrock providerfailover 场景)
let mut provider_body =
if self.optimizer_config.enabled && is_bedrock_provider(provider) {
let mut b = body.clone();
if self.optimizer_config.thinking_optimizer {
super::thinking_optimizer::optimize(&mut b, &self.optimizer_config);
}
if self.optimizer_config.cache_injection {
super::cache_injector::inject(&mut b, &self.optimizer_config);
}
b
} else {
body.clone()
};
attempted_providers += 1;
// 更新状态中的当前Provider信息
{
let mut status = self.status.write().await;
status.current_provider = Some(provider.name.clone());
status.current_provider_id = Some(provider.id.clone());
status.total_requests += 1;
status.last_request_at = Some(chrono::Utc::now().to_rfc3339());
}
// 转发请求(每个 Provider 只尝试一次,重试由客户端控制)
match self
.forward(
provider,
endpoint,
&provider_body,
&headers,
adapter.as_ref(),
)
.await
{
Ok(response) => {
// 成功:记录成功并更新熔断器
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
true,
None,
)
.await;
// 更新当前应用类型使用的 provider
{
let mut current_providers = self.current_providers.write().await;
current_providers.insert(
app_type_str.to_string(),
(provider.id.clone(), provider.name.clone()),
);
}
// 更新成功统计
{
let mut status = self.status.write().await;
status.success_requests += 1;
status.last_error = None;
let should_switch =
self.current_provider_id_at_start.as_str() != provider.id.as_str();
if should_switch {
status.failover_count += 1;
// 异步触发供应商切换,更新 UI/托盘,并把“当前供应商”同步为实际使用的 provider
let fm = self.failover_manager.clone();
let ah = self.app_handle.clone();
let pid = provider.id.clone();
let pname = provider.name.clone();
let at = app_type_str.to_string();
tokio::spawn(async move {
let _ = fm.try_switch(ah.as_ref(), &at, &pid, &pname).await;
});
}
// 重新计算成功率
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Ok(ForwardResult {
response,
provider: provider.clone(),
});
}
Err(e) => {
// 检测是否需要触发整流器(仅 Claude/ClaudeAuth 供应商)
let provider_type = ProviderType::from_app_type_and_config(app_type, provider);
let is_anthropic_provider = matches!(
provider_type,
ProviderType::Claude | ProviderType::ClaudeAuth
);
let mut signature_rectifier_non_retryable_client_error = false;
if is_anthropic_provider {
let error_message = extract_error_message(&e);
if should_rectify_thinking_signature(
error_message.as_deref(),
&self.rectifier_config,
) {
// 已经重试过:直接返回错误(不可重试客户端错误)
if rectifier_retried {
log::warn!("[{app_type_str}] [RECT-005] 整流器已触发过,不再重试");
// 释放 HalfOpen permit不记录熔断器这是客户端兼容性问题
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
// 首次触发:整流请求体
let rectified = rectify_anthropic_request(&mut provider_body);
// 整流未生效:继续尝试 budget 整流路径,避免误判后短路
if !rectified.applied {
log::warn!(
"[{app_type_str}] [RECT-006] thinking 签名整流器触发但无可整流内容,继续检查 budget若 budget 也未命中则按客户端错误返回"
);
signature_rectifier_non_retryable_client_error = true;
} else {
log::info!(
"[{}] [RECT-001] thinking 签名整流器触发, 移除 {} thinking blocks, {} redacted_thinking blocks, {} signature fields",
app_type_str,
rectified.removed_thinking_blocks,
rectified.removed_redacted_thinking_blocks,
rectified.removed_signature_fields
);
// 标记已重试(当前逻辑下重试后必定 return保留标记以备将来扩展
let _ = std::mem::replace(&mut rectifier_retried, true);
// 使用同一供应商重试(不计入熔断器)
match self
.forward(
provider,
endpoint,
&provider_body,
&headers,
adapter.as_ref(),
)
.await
{
Ok(response) => {
log::info!("[{app_type_str}] [RECT-002] 整流重试成功");
// 记录成功
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
true,
None,
)
.await;
// 更新当前应用类型使用的 provider
{
let mut current_providers =
self.current_providers.write().await;
current_providers.insert(
app_type_str.to_string(),
(provider.id.clone(), provider.name.clone()),
);
}
// 更新成功统计
{
let mut status = self.status.write().await;
status.success_requests += 1;
status.last_error = None;
let should_switch =
self.current_provider_id_at_start.as_str()
!= provider.id.as_str();
if should_switch {
status.failover_count += 1;
// 异步触发供应商切换,更新 UI/托盘
let fm = self.failover_manager.clone();
let ah = self.app_handle.clone();
let pid = provider.id.clone();
let pname = provider.name.clone();
let at = app_type_str.to_string();
tokio::spawn(async move {
let _ = fm
.try_switch(ah.as_ref(), &at, &pid, &pname)
.await;
});
}
if status.total_requests > 0 {
status.success_rate = (status.success_requests
as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Ok(ForwardResult {
response,
provider: provider.clone(),
});
}
Err(retry_err) => {
// 整流重试仍失败:区分错误类型决定是否记录熔断器
log::warn!(
"[{app_type_str}] [RECT-003] 整流重试仍失败: {retry_err}"
);
// 区分错误类型Provider 问题记录失败,客户端问题仅释放 permit
let is_provider_error = match &retry_err {
ProxyError::Timeout(_)
| ProxyError::ForwardFailed(_) => true,
ProxyError::UpstreamError { status, .. } => {
*status >= 500
}
_ => false,
};
if is_provider_error {
// Provider 问题:记录失败到熔断器
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
false,
Some(retry_err.to_string()),
)
.await;
} else {
// 客户端问题:仅释放 permit不记录熔断器
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
}
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(retry_err.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: retry_err,
provider: Some(provider.clone()),
});
}
}
}
}
}
// 检测是否需要触发 budget 整流器(仅 Claude/ClaudeAuth 供应商)
if is_anthropic_provider {
let error_message = extract_error_message(&e);
if should_rectify_thinking_budget(
error_message.as_deref(),
&self.rectifier_config,
) {
// 已经重试过:直接返回错误(不可重试客户端错误)
if budget_rectifier_retried {
log::warn!(
"[{app_type_str}] [RECT-013] budget 整流器已触发过,不再重试"
);
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
let budget_rectified = rectify_thinking_budget(&mut provider_body);
if !budget_rectified.applied {
log::warn!(
"[{app_type_str}] [RECT-014] budget 整流器触发但无可整流内容,不做无意义重试"
);
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
log::info!(
"[{}] [RECT-010] thinking budget 整流器触发, before={:?}, after={:?}",
app_type_str,
budget_rectified.before,
budget_rectified.after
);
let _ = std::mem::replace(&mut budget_rectifier_retried, true);
// 使用同一供应商重试(不计入熔断器)
match self
.forward(
provider,
endpoint,
&provider_body,
&headers,
adapter.as_ref(),
)
.await
{
Ok(response) => {
log::info!("[{app_type_str}] [RECT-011] budget 整流重试成功");
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
true,
None,
)
.await;
{
let mut current_providers =
self.current_providers.write().await;
current_providers.insert(
app_type_str.to_string(),
(provider.id.clone(), provider.name.clone()),
);
}
{
let mut status = self.status.write().await;
status.success_requests += 1;
status.last_error = None;
let should_switch =
self.current_provider_id_at_start.as_str()
!= provider.id.as_str();
if should_switch {
status.failover_count += 1;
let fm = self.failover_manager.clone();
let ah = self.app_handle.clone();
let pid = provider.id.clone();
let pname = provider.name.clone();
let at = app_type_str.to_string();
tokio::spawn(async move {
let _ = fm
.try_switch(ah.as_ref(), &at, &pid, &pname)
.await;
});
}
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Ok(ForwardResult {
response,
provider: provider.clone(),
});
}
Err(retry_err) => {
log::warn!(
"[{app_type_str}] [RECT-012] budget 整流重试仍失败: {retry_err}"
);
let is_provider_error = match &retry_err {
ProxyError::Timeout(_) | ProxyError::ForwardFailed(_) => {
true
}
ProxyError::UpstreamError { status, .. } => *status >= 500,
_ => false,
};
if is_provider_error {
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
false,
Some(retry_err.to_string()),
)
.await;
} else {
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
}
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(retry_err.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: retry_err,
provider: Some(provider.clone()),
});
}
}
}
}
if signature_rectifier_non_retryable_client_error {
self.router
.release_permit_neutral(
&provider.id,
app_type_str,
used_half_open_permit,
)
.await;
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
// 失败:记录失败并更新熔断器
let _ = self
.router
.record_result(
&provider.id,
app_type_str,
used_half_open_permit,
false,
Some(e.to_string()),
)
.await;
// 分类错误
let category = self.categorize_proxy_error(&e);
match category {
ErrorCategory::Retryable => {
// 可重试:更新错误信息,继续尝试下一个供应商
{
let mut status = self.status.write().await;
status.last_error =
Some(format!("Provider {} 失败: {}", provider.name, e));
}
let (log_code, log_message) = build_retryable_failure_log(
&provider.name,
attempted_providers,
providers.len(),
&e,
);
log::warn!("[{app_type_str}] [{log_code}] {log_message}");
last_error = Some(e);
last_provider = Some(provider.clone());
// 继续尝试下一个供应商
continue;
}
ErrorCategory::NonRetryable | ErrorCategory::ClientAbort => {
// 不可重试:直接返回错误
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some(e.to_string());
if status.total_requests > 0 {
status.success_rate = (status.success_requests as f32
/ status.total_requests as f32)
* 100.0;
}
}
return Err(ForwardError {
error: e,
provider: Some(provider.clone()),
});
}
}
}
}
}
if attempted_providers == 0 {
// providers 列表非空但全部被熔断器拒绝典型HalfOpen 探测名额被占用)
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some("所有供应商暂时不可用(熔断器限制)".to_string());
if status.total_requests > 0 {
status.success_rate =
(status.success_requests as f32 / status.total_requests as f32) * 100.0;
}
}
return Err(ForwardError {
error: ProxyError::NoAvailableProvider,
provider: None,
});
}
// 所有供应商都失败了
{
let mut status = self.status.write().await;
status.failed_requests += 1;
status.last_error = Some("所有供应商都失败".to_string());
if status.total_requests > 0 {
status.success_rate =
(status.success_requests as f32 / status.total_requests as f32) * 100.0;
}
}
if let Some((log_code, log_message)) =
build_terminal_failure_log(attempted_providers, providers.len(), last_error.as_ref())
{
log::warn!("[{app_type_str}] [{log_code}] {log_message}");
}
Err(ForwardError {
error: last_error.unwrap_or(ProxyError::MaxRetriesExceeded),
provider: last_provider,
})
}
/// Forwards a single request to one provider through its adapter.
///
/// Builds the upstream URL, applies model mapping / request transformation /
/// private-parameter filtering to the body, copies the client headers that are
/// not blacklisted, adds authentication and Anthropic headers, and sends the
/// request as JSON via POST.
///
/// # Errors
/// * Errors from `extract_base_url` / `transform_request` are propagated.
/// * `ProxyError::AuthError` when a GitHub Copilot token cannot be obtained.
/// * `ProxyError::Timeout` / `ProxyError::ForwardFailed` when sending fails.
/// * `ProxyError::UpstreamError` when the upstream answers with a non-success status.
async fn forward(
    &self,
    provider: &Provider,
    endpoint: &str,
    body: &Value,
    headers: &axum::http::HeaderMap,
    adapter: &dyn ProviderAdapter,
) -> Result<Response, ProxyError> {
    // Extract the base URL through the adapter.
    let base_url = adapter.extract_base_url(provider)?;
    // Does this provider need a format transformation?
    let needs_transform = adapter.needs_transform(provider);
    let is_full_url = provider
        .meta
        .as_ref()
        .and_then(|meta| meta.is_full_url)
        .unwrap_or(false);
    // Determine the effective endpoint.
    // The GitHub Copilot API uses /chat/completions (no /v1 prefix).
    let is_copilot = provider
        .meta
        .as_ref()
        .and_then(|m| m.provider_type.as_deref())
        == Some("github_copilot")
        || base_url.contains("githubcopilot.com");
    let (effective_endpoint, passthrough_query) =
        if needs_transform && adapter.name() == "Claude" {
            let api_format = super::providers::get_claude_api_format(provider);
            rewrite_claude_transform_endpoint(endpoint, api_format, is_copilot)
        } else {
            (
                endpoint.to_string(),
                split_endpoint_and_query(endpoint)
                    .1
                    .map(ToString::to_string),
            )
        };
    let url = if is_full_url {
        append_query_to_full_url(&base_url, passthrough_query.as_deref())
    } else {
        adapter.build_url(&base_url, &effective_endpoint)
    };

    // Apply model mapping (independent of format transformation).
    let (mapped_body, _original_model, _mapped_model) =
        super::model_mapper::apply_model_mapping(body.clone(), provider);
    // Aligned with CCH: no proactive thinking rewrite before the request
    // (only the compatibility entry point is kept).
    let mapped_body = normalize_thinking_type(mapped_body);
    // Transform the request body if required.
    let request_body = if needs_transform {
        adapter.transform_request(mapped_body, provider)?
    } else {
        mapped_body
    };
    // Filter private parameters (fields starting with `_`) so internal information
    // does not leak upstream. The whitelist is empty, so all `_` fields are removed.
    let filtered_body = filter_private_params_with_whitelist(request_body, &[]);

    // HTTP client: prefer the provider's own proxy configuration, else the global client.
    let proxy_config = provider.meta.as_ref().and_then(|m| m.proxy_config.as_ref());
    let client = super::http_client::get_for_provider(proxy_config);
    let mut request = client.post(&url);
    // Only set a request timeout when it is > 0: in reqwest, Duration::ZERO means
    // "time out immediately", not "no timeout". With failover disabled 0 is passed in,
    // and the client's default timeout (600 seconds) should apply.
    if !self.non_streaming_timeout.is_zero() {
        request = request.timeout(self.non_streaming_timeout);
    }

    // Copy client headers, skipping blacklisted ones (privacy, conflict avoidance).
    for (key, value) in headers {
        let key_str = key.as_str();
        if HEADER_BLACKLIST
            .iter()
            .any(|h| key_str.eq_ignore_ascii_case(h))
        {
            continue;
        }
        // Copilot requests: drop the fixed fingerprint headers that add_auth_headers
        // injects, so client headers are not duplicated (reqwest `header()` appends).
        if is_copilot
            && (key_str.eq_ignore_ascii_case("user-agent")
                || key_str.eq_ignore_ascii_case("editor-version")
                || key_str.eq_ignore_ascii_case("editor-plugin-version")
                || key_str.eq_ignore_ascii_case("copilot-integration-id")
                || key_str.eq_ignore_ascii_case("x-github-api-version")
                || key_str.eq_ignore_ascii_case("openai-intent"))
        {
            continue;
        }
        request = request.header(key, value);
    }

    // anthropic-beta header (Claude only).
    // The value must contain the claude-code-20250219 marker, which the upstream uses
    // to validate the request origin; it is prepended when the client did not send it.
    if adapter.name() == "Claude" {
        const CLAUDE_CODE_BETA: &str = "claude-code-20250219";
        let client_beta = headers
            .get("anthropic-beta")
            .and_then(|beta| beta.to_str().ok());
        let beta_value = match client_beta {
            // Marker already present: keep the client's value.
            Some(beta_str) if beta_str.contains(CLAUDE_CODE_BETA) => beta_str.to_string(),
            // Marker missing: prepend it.
            Some(beta_str) => format!("{CLAUDE_CODE_BETA},{beta_str}"),
            // Header absent or not valid text: use the default.
            None => CLAUDE_CODE_BETA.to_string(),
        };
        request = request.header("anthropic-beta", &beta_value);
    }

    // Client IP passthrough (enabled by default).
    if let Some(xff) = headers.get("x-forwarded-for") {
        if let Ok(xff_str) = xff.to_str() {
            request = request.header("x-forwarded-for", xff_str);
        }
    }
    if let Some(real_ip) = headers.get("x-real-ip") {
        if let Ok(real_ip_str) = real_ip.to_str() {
            request = request.header("x-real-ip", real_ip_str);
        }
    }

    // Streaming requests conservatively disable compression, so a compressed SSE stream
    // cannot trigger decompression errors when the connection drops. Non-streaming
    // requests leave Accept-Encoding unset and let reqwest negotiate and decompress.
    if should_force_identity_encoding(&effective_endpoint, &filtered_body, headers) {
        request = request.header("accept-encoding", "identity");
    }

    // Add authentication headers through the adapter.
    if let Some(mut auth) = adapter.extract_auth(provider) {
        // GitHub Copilot: fetch the real token from CopilotAuthManager.
        if auth.strategy == AuthStrategy::GitHubCopilot {
            if let Some(app_handle) = &self.app_handle {
                let copilot_state = app_handle.state::<CopilotAuthState>();
                let copilot_auth: tokio::sync::RwLockReadGuard<'_, CopilotAuthManager> =
                    copilot_state.0.read().await;
                // Associated GitHub account ID from provider.meta (multi-account support).
                let account_id = provider
                    .meta
                    .as_ref()
                    .and_then(|m| m.managed_account_id_for("github_copilot"));
                // Fetch the token for that account; without an account ID fall back to
                // the default account (backward compatibility).
                let token_result = match &account_id {
                    Some(id) => {
                        log::debug!("[Copilot] 使用指定账号 {id} 获取 token");
                        copilot_auth.get_valid_token_for_account(id).await
                    }
                    None => {
                        log::debug!("[Copilot] 使用默认账号获取 token");
                        copilot_auth.get_valid_token().await
                    }
                };
                match token_result {
                    Ok(token) => {
                        auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot);
                        log::debug!(
                            "[Copilot] 成功获取 Copilot token (account={})",
                            account_id.as_deref().unwrap_or("default")
                        );
                    }
                    Err(e) => {
                        log::error!(
                            "[Copilot] 获取 Copilot token 失败 (account={}): {e}",
                            account_id.as_deref().unwrap_or("default")
                        );
                        return Err(ProxyError::AuthError(format!(
                            "GitHub Copilot 认证失败: {e}"
                        )));
                    }
                }
            } else {
                log::error!("[Copilot] AppHandle 不可用");
                return Err(ProxyError::AuthError(
                    "GitHub Copilot 认证不可用(无 AppHandle".to_string(),
                ));
            }
        }
        request = adapter.add_auth_headers(request, &auth);
    }

    // anthropic-version (Claude only): prefer the client's version, else the default.
    // Set exactly once to avoid duplicates.
    if adapter.name() == "Claude" {
        let version_str = headers
            .get("anthropic-version")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("2023-06-01");
        request = request.header("anthropic-version", version_str);
    }

    // Log the outgoing request.
    let tag = adapter.name();
    let request_model = filtered_body
        .get("model")
        .and_then(|v| v.as_str())
        .unwrap_or("<none>");
    log::info!("[{tag}] >>> 请求 URL: {url} (model={request_model})");
    // Serializing the whole body is only worth it when the debug log is actually emitted.
    if log::log_enabled!(log::Level::Debug) {
        if let Ok(body_str) = serde_json::to_string(&filtered_body) {
            log::debug!(
                "[{tag}] >>> 请求体内容 ({}字节): {}",
                body_str.len(),
                body_str
            );
        }
    }

    // Send the request.
    let response = request.json(&filtered_body).send().await.map_err(|e| {
        if e.is_timeout() {
            ProxyError::Timeout(format!("请求超时: {e}"))
        } else if e.is_connect() {
            ProxyError::ForwardFailed(format!("连接失败: {e}"))
        } else {
            ProxyError::ForwardFailed(e.to_string())
        }
    })?;

    // Check the response status.
    let status = response.status();
    if status.is_success() {
        Ok(response)
    } else {
        let status_code = status.as_u16();
        let body_text = response.text().await.ok();
        Err(ProxyError::UpstreamError {
            status: status_code,
            body: body_text,
        })
    }
}
/// Decides whether a failed attempt should fall through to the next provider.
///
/// Network failures, upstream HTTP errors of any status (one provider's 4xx says
/// nothing about another provider's limits or credentials) and provider-level
/// config / transform / auth problems are retryable. Everything else — including
/// `NoAvailableProvider` and internal errors — is not something switching
/// providers can fix.
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
    let worth_trying_next_provider = matches!(
        error,
        ProxyError::Timeout(_)
            | ProxyError::ForwardFailed(_)
            | ProxyError::ProviderUnhealthy(_)
            | ProxyError::UpstreamError { .. }
            | ProxyError::ConfigError(_)
            | ProxyError::TransformError(_)
            | ProxyError::AuthError(_)
            | ProxyError::StreamIdleTimeout(_)
    );
    if worth_trying_next_provider {
        ErrorCategory::Retryable
    } else {
        ErrorCategory::NonRetryable
    }
}
}
/// 从 ProxyError 中提取错误消息
fn extract_error_message(error: &ProxyError) -> Option<String> {
match error {
ProxyError::UpstreamError { body, .. } => body.clone(),
_ => Some(error.to_string()),
}
}
/// Returns `true` when the provider's settings contain
/// `env.CLAUDE_CODE_USE_BEDROCK` set to the string `"1"`.
fn is_bedrock_provider(provider: &Provider) -> bool {
    let bedrock_flag = provider
        .settings_config
        .get("env")
        .and_then(|env| env.get("CLAUDE_CODE_USE_BEDROCK"))
        .and_then(|flag| flag.as_str());
    bedrock_flag == Some("1")
}
/// Builds the log code and message for a retryable provider failure.
///
/// With more than one provider the message announces the failover and the
/// attempt counter; with a single provider it only reports the failure.
fn build_retryable_failure_log(
    provider_name: &str,
    attempted_providers: usize,
    total_providers: usize,
    error: &ProxyError,
) -> (&'static str, String) {
    let error_summary = summarize_proxy_error(error);
    if total_providers > 1 {
        let message = format!(
            "Provider {provider_name} 失败,继续尝试下一个 ({attempted_providers}/{total_providers}): {error_summary}"
        );
        return (log_fwd::PROVIDER_FAILED_RETRY, message);
    }
    let message = format!("Provider {provider_name} 请求失败: {error_summary}");
    (log_fwd::SINGLE_PROVIDER_FAILED, message)
}
fn build_terminal_failure_log(
attempted_providers: usize,
total_providers: usize,
last_error: Option<&ProxyError>,
) -> Option<(&'static str, String)> {
if total_providers <= 1 {
return None;
}
let error_summary = last_error
.map(summarize_proxy_error)
.unwrap_or_else(|| "未知错误".to_string());
Some((
log_fwd::ALL_PROVIDERS_FAILED,
format!(
"已尝试 {attempted_providers}/{total_providers} 个 Provider均失败。最后错误: {error_summary}"
),
))
}
/// Produces a short, single-line summary of a `ProxyError` for log output.
///
/// Upstream errors show the HTTP status plus a summary of the body (when it is
/// non-empty); the other handled variants get a category label followed by the
/// message shortened to 180 characters.
fn summarize_proxy_error(error: &ProxyError) -> String {
    // Shared "<label>: <shortened message>" rendering for the string-carrying variants.
    let labelled = |label: &str, message: &str| {
        format!("{}: {}", label, summarize_text_for_log(message, 180))
    };
    match error {
        ProxyError::UpstreamError { status, body } => {
            let body_summary = body
                .as_deref()
                .map(summarize_upstream_body)
                .filter(|summary| !summary.is_empty());
            if let Some(summary) = body_summary {
                format!("上游 HTTP {status}: {summary}")
            } else {
                format!("上游 HTTP {status}")
            }
        }
        ProxyError::Timeout(message) => labelled("请求超时", message),
        ProxyError::ForwardFailed(message) => labelled("请求转发失败", message),
        ProxyError::TransformError(message) => labelled("响应转换失败", message),
        ProxyError::ConfigError(message) => labelled("配置错误", message),
        ProxyError::AuthError(message) => labelled("认证失败", message),
        _ => summarize_text_for_log(&error.to_string(), 180),
    }
}
/// Summarizes an upstream response body for logging.
///
/// When the body is JSON, a recognised error-message field is preferred, then the
/// compact re-serialized JSON; otherwise the raw text is used. The result is
/// shortened to 180 characters.
fn summarize_upstream_body(body: &str) -> String {
    let parsed = serde_json::from_str::<Value>(body).ok();
    let json_text = parsed.as_ref().and_then(|json_body| {
        extract_json_error_message(json_body).or_else(|| serde_json::to_string(json_body).ok())
    });
    match json_text {
        Some(text) => summarize_text_for_log(&text, 180),
        None => summarize_text_for_log(body, 180),
    }
}
/// Returns the first string found at `/error/message`, `/message`, `/detail`
/// or `/error` (checked in that order), if any.
fn extract_json_error_message(body: &Value) -> Option<String> {
    const MESSAGE_POINTERS: [&str; 4] = ["/error/message", "/message", "/detail", "/error"];
    MESSAGE_POINTERS
        .iter()
        .find_map(|pointer| body.pointer(pointer).and_then(|value| value.as_str()))
        .map(|message| message.to_string())
}
/// Splits an endpoint at its first `?` into the path and the optional query string.
///
/// The query is `None` when the endpoint contains no `?`, and `Some("")` when the
/// `?` is the last character.
fn split_endpoint_and_query(endpoint: &str) -> (&str, Option<&str>) {
    match endpoint.split_once('?') {
        Some((path, query)) => (path, Some(query)),
        None => (endpoint, None),
    }
}
/// Removes the `beta` parameter from a query string.
///
/// A pair is dropped when its key is exactly `beta`, with or without a value
/// (`beta=true`, `beta=`, `beta`); keys that merely start with `beta`
/// (e.g. `betamax=1`) are kept. Empty pairs produced by stray `&` are dropped too.
///
/// Returns `None` when there is no query or nothing remains after filtering.
fn strip_beta_query(query: Option<&str>) -> Option<String> {
    let remaining = query?
        .split('&')
        .filter(|pair| {
            // Compare the key only, so a valueless `beta` flag is stripped as well.
            let key = pair.split_once('=').map_or(*pair, |(key, _)| key);
            !pair.is_empty() && key != "beta"
        })
        .collect::<Vec<_>>()
        .join("&");
    if remaining.is_empty() {
        None
    } else {
        Some(remaining)
    }
}
/// Returns `true` when `path` is exactly one of the Claude messages endpoints.
fn is_claude_messages_path(path: &str) -> bool {
    const MESSAGES_PATHS: [&str; 2] = ["/v1/messages", "/claude/v1/messages"];
    MESSAGES_PATHS.contains(&path)
}
/// Rewrites a Claude messages endpoint to the endpoint of the target API format.
///
/// Endpoints that are not a Claude messages path are returned unchanged together
/// with their original query. For messages paths the `beta` query parameter is
/// stripped and the path becomes `/chat/completions` (Copilot), `/v1/responses`
/// (`openai_responses` format) or `/v1/chat/completions` (anything else).
///
/// Returns `(rewritten endpoint, query to pass through)`.
fn rewrite_claude_transform_endpoint(
    endpoint: &str,
    api_format: &str,
    is_copilot: bool,
) -> (String, Option<String>) {
    let (path, query) = split_endpoint_and_query(endpoint);
    // Anything other than the messages endpoints passes through untouched.
    if !is_claude_messages_path(path) {
        return (endpoint.to_string(), query.map(|q| q.to_string()));
    }
    let passthrough_query = strip_beta_query(query);
    let target_path = match (is_copilot, api_format) {
        (true, _) => "/chat/completions",
        (false, "openai_responses") => "/v1/responses",
        (false, _) => "/v1/chat/completions",
    };
    let rewritten = passthrough_query
        .as_deref()
        .filter(|remaining| !remaining.is_empty())
        .map_or_else(
            || target_path.to_string(),
            |remaining| format!("{target_path}?{remaining}"),
        );
    (rewritten, passthrough_query)
}
/// Appends `query` to a complete URL, choosing the correct separator.
///
/// * No query (or an empty one): the URL is returned unchanged.
/// * URL already ends with `?` or `&`: the query is appended directly, so no
///   `?&` / `&&` sequence is produced.
/// * URL already contains a `?`: the query is joined with `&`.
/// * Otherwise the query is joined with `?`.
fn append_query_to_full_url(base_url: &str, query: Option<&str>) -> String {
    let query = match query {
        Some(query) if !query.is_empty() => query,
        _ => return base_url.to_string(),
    };
    let separator = if base_url.ends_with('?') || base_url.ends_with('&') {
        // A separator is already in place.
        ""
    } else if base_url.contains('?') {
        "&"
    } else {
        "?"
    };
    format!("{base_url}{separator}{query}")
}
/// Decides whether the upstream request must ask for `identity` encoding.
///
/// Returns `true` for requests that look like streaming requests: the body has
/// `"stream": true`, the endpoint contains `streamGenerateContent` or `alt=sse`,
/// or the `Accept` header mentions `text/event-stream`.
fn should_force_identity_encoding(
    endpoint: &str,
    body: &Value,
    headers: &axum::http::HeaderMap,
) -> bool {
    let body_requests_stream = body.get("stream").and_then(|flag| flag.as_bool()) == Some(true);
    let endpoint_is_streaming =
        endpoint.contains("streamGenerateContent") || endpoint.contains("alt=sse");
    if body_requests_stream || endpoint_is_streaming {
        return true;
    }
    let accept_header = headers
        .get(axum::http::header::ACCEPT)
        .and_then(|value| value.to_str().ok());
    matches!(accept_header, Some(accept) if accept.contains("text/event-stream"))
}
/// Collapses all whitespace runs in `text` to single spaces and shortens the
/// result to at most `max_chars` characters.
///
/// When the text is cut, trailing whitespace at the cut is removed and `...` is
/// appended (the ellipsis is not counted against `max_chars`).
fn summarize_text_for_log(text: &str, max_chars: usize) -> String {
    let collapsed = text.split_whitespace().collect::<Vec<_>>().join(" ");
    let collapsed = collapsed.trim();
    // Byte offset of the first character beyond the limit, if the text is that long.
    match collapsed.char_indices().nth(max_chars) {
        None => collapsed.to_string(),
        Some((cut, _)) => {
            let kept = collapsed[..cut].trim_end();
            format!("{kept}...")
        }
    }
}
// Unit tests for the pure helpers of this module (log builders, summaries,
// endpoint rewriting, encoding decision).
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{header::ACCEPT, HeaderMap, HeaderValue};
use serde_json::json;
// With one provider, the retryable-failure log uses the single-provider code and
// carries no failover wording.
#[test]
fn single_provider_retryable_log_uses_single_provider_code() {
let error = ProxyError::UpstreamError {
status: 429,
body: Some(r#"{"error":{"message":"rate limit exceeded"}}"#.to_string()),
};
let (code, message) = build_retryable_failure_log("PackyCode-response", 1, 1, &error);
assert_eq!(code, log_fwd::SINGLE_PROVIDER_FAILED);
assert!(message.contains("Provider PackyCode-response 请求失败"));
assert!(message.contains("上游 HTTP 429"));
assert!(message.contains("rate limit exceeded"));
assert!(!message.contains("切换下一个"));
}
// With several providers, the log announces the failover and the attempt counter.
#[test]
fn multi_provider_retryable_log_keeps_failover_wording() {
let error = ProxyError::Timeout("upstream timed out after 30s".to_string());
let (code, message) = build_retryable_failure_log("primary", 1, 3, &error);
assert_eq!(code, log_fwd::PROVIDER_FAILED_RETRY);
assert!(message.contains("继续尝试下一个 (1/3)"));
assert!(message.contains("请求超时"));
}
// A single provider never produces the "all providers failed" summary.
#[test]
fn single_provider_has_no_terminal_all_failed_log() {
assert!(build_terminal_failure_log(1, 1, None).is_none());
}
// The terminal summary includes the attempt counts and the last error.
#[test]
fn multi_provider_terminal_log_contains_last_error_summary() {
let error = ProxyError::ForwardFailed("connection reset by peer".to_string());
let (code, message) =
build_terminal_failure_log(2, 2, Some(&error)).expect("expected terminal log");
assert_eq!(code, log_fwd::ALL_PROVIDERS_FAILED);
assert!(message.contains("已尝试 2/2 个 Provider均失败"));
assert!(message.contains("connection reset by peer"));
}
// A JSON body with `error.message` is summarized by that message alone.
#[test]
fn summarize_upstream_body_prefers_json_message() {
let body = json!({
"error": {
"message": "invalid_request_error: unsupported field"
},
"request_id": "req_123"
});
let summary = summarize_upstream_body(&body.to_string());
assert_eq!(summary, "invalid_request_error: unsupported field");
}
// Whitespace runs collapse to one space; the cut drops trailing whitespace before `...`.
#[test]
fn summarize_text_for_log_collapses_whitespace_and_truncates() {
let summary = summarize_text_for_log("line1\n\n line2 line3", 12);
assert_eq!(summary, "line1 line2...");
}
// Chat-completions format: path rewritten, `beta` removed, other parameters kept.
#[test]
fn rewrite_claude_transform_endpoint_strips_beta_for_chat_completions() {
let (endpoint, passthrough_query) = rewrite_claude_transform_endpoint(
"/v1/messages?beta=true&foo=bar",
"openai_chat",
false,
);
assert_eq!(endpoint, "/v1/chat/completions?foo=bar");
assert_eq!(passthrough_query.as_deref(), Some("foo=bar"));
}
// Responses format: the `/claude` prefixed messages path is rewritten too.
#[test]
fn rewrite_claude_transform_endpoint_strips_beta_for_responses() {
let (endpoint, passthrough_query) = rewrite_claude_transform_endpoint(
"/claude/v1/messages?beta=true&x-id=1",
"openai_responses",
false,
);
assert_eq!(endpoint, "/v1/responses?x-id=1");
assert_eq!(passthrough_query.as_deref(), Some("x-id=1"));
}
// Copilot takes precedence over the API format and uses the path without `/v1`.
#[test]
fn rewrite_claude_transform_endpoint_uses_copilot_path() {
let (endpoint, passthrough_query) =
rewrite_claude_transform_endpoint("/v1/messages?beta=true&x-id=1", "anthropic", true);
assert_eq!(endpoint, "/chat/completions?x-id=1");
assert_eq!(passthrough_query.as_deref(), Some("x-id=1"));
}
// A full URL that already has a query gets the extra query joined with `&`.
#[test]
fn append_query_to_full_url_preserves_existing_query_string() {
let url = append_query_to_full_url("https://relay.example/api?foo=bar", Some("x-id=1"));
assert_eq!(url, "https://relay.example/api?foo=bar&x-id=1");
}
// `"stream": true` in the body forces identity encoding.
#[test]
fn force_identity_for_stream_flag_requests() {
let headers = HeaderMap::new();
assert!(should_force_identity_encoding(
"/v1/responses",
&json!({ "stream": true }),
&headers
));
}
// Gemini streaming endpoints force identity encoding without a body flag.
#[test]
fn force_identity_for_gemini_stream_endpoints() {
let headers = HeaderMap::new();
assert!(should_force_identity_encoding(
"/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse",
&json!({ "model": "gemini-2.5-pro" }),
&headers
));
}
// An SSE `Accept` header forces identity encoding.
#[test]
fn force_identity_for_sse_accept_header() {
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
assert!(should_force_identity_encoding(
"/v1/responses",
&json!({ "model": "gpt-5" }),
&headers
));
}
// Without any streaming signal the encoding is left to the HTTP client.
#[test]
fn non_streaming_requests_allow_automatic_compression() {
let headers = HeaderMap::new();
assert!(!should_force_identity_encoding(
"/v1/responses",
&json!({ "model": "gpt-5" }),
&headers
));
}
}