refactor(proxy): use hyper client for header-case preserving forwarding

Previously the proxy used reqwest for all upstream requests. reqwest
normalizes header names to lowercase and reorders them internally,
making proxied requests distinguishable from direct CLI requests.
Some upstream providers are sensitive to these differences.

This commit replaces reqwest with a hyper-based HTTP client on the
default (non-proxy) path, achieving wire-level header fidelity:

Server layer (server.rs):
- Replace axum::serve with a manual hyper HTTP/1.1 accept loop
- Enable preserve_header_case(true) so incoming header casing is
  captured in a HeaderCaseMap extension on each request
- Bridge hyper requests to axum Router via tower::Service

New hyper client module (hyper_client.rs):
- Lazy-initialized hyper-util Client with preserve_header_case
- ProxyResponse enum wrapping both hyper::Response and reqwest::Response
  behind a unified interface (status, headers, bytes, bytes_stream)
- send_request() builds requests with ordered HeaderMap + case map

Request handlers (handlers.rs):
- Switch from (HeaderMap, Json<Value>) extractors to raw
  axum::extract::Request to preserve Extensions (containing the
  HeaderCaseMap from the accept loop)
- Pass extensions through the forwarding chain

Forwarder (forwarder.rs):
- Remove HEADER_BLACKLIST array; replace with ordered header iteration
  that preserves original header sequence and casing
- Build ordered_headers by iterating client headers, skipping only
  auth/host/content-length, and inserting auth headers at the original
  authorization position to maintain order
- Handle anthropic-beta (ensure claude-code-20250219 tag) and
  anthropic-version (passthrough or default) inline during iteration
- Remove should_force_identity_encoding() — accept-encoding is now
  transparently forwarded to upstream
- Use hyper client by default; fall back to reqwest only when an
  HTTP/SOCKS5 proxy tunnel is configured

Provider adapters (adapter.rs, claude.rs, codex.rs, gemini.rs):
- Replace add_auth_headers(RequestBuilder) -> RequestBuilder with
  get_auth_headers(AuthInfo) -> Vec<(HeaderName, HeaderValue)>
- Adapters now return header pairs instead of mutating a reqwest builder
- Claude adapter: merge Anthropic/ClaudeAuth/Bearer into single branch;
  move Copilot fingerprint headers into get_auth_headers

Response processing (response_processor.rs):
- Add manual decompression (gzip/deflate/brotli via flate2 + brotli)
  for non-streaming responses, since reqwest auto-decompression is now
  disabled to allow accept-encoding passthrough
- Add compressed-SSE warning log for streaming responses
- Accept ProxyResponse instead of reqwest::Response

HTTP client (http_client.rs):
- Disable reqwest auto-decompression (.no_gzip/.no_brotli/.no_deflate)
  on both global and per-provider clients

Streaming adapters (streaming.rs, streaming_responses.rs):
- Generalize stream error type from reqwest::Error to generic E: Error

Misc:
- log_codes.rs: add SRV-005 (ACCEPT_ERR) and SRV-006 (CONN_ERR)
- stream_check.rs: reformat copilot header lines
- transform.rs: fix trailing whitespace alignment
This commit is contained in:
YoVinchen
2026-03-27 15:34:25 +08:00
parent 2c2c72271a
commit 4084b53834
16 changed files with 722 additions and 398 deletions
+243 -224
View File
@@ -2,6 +2,7 @@
//!
//! 负责将请求转发到上游Provider,支持故障转移
use super::hyper_client::ProxyResponse;
use super::{
body_filter::filter_private_params_with_whitelist,
error::*,
@@ -19,67 +20,14 @@ use super::{
use crate::commands::CopilotAuthState;
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
use crate::{app_config::AppType, provider::Provider};
use reqwest::Response;
use http::Extensions;
use serde_json::Value;
use std::sync::Arc;
use tauri::Manager;
use tokio::sync::RwLock;
/// Headers 黑名单 - 不透传到上游的 Headers
///
/// 精简版黑名单,只过滤必须覆盖或可能导致问题的 header
/// 参考成功透传的请求,保留更多原始 header
///
/// 注意:客户端 IP 类(x-forwarded-for, x-real-ip)默认透传
const HEADER_BLACKLIST: &[&str] = &[
// 认证类(会被覆盖)
"authorization",
"x-api-key",
"x-goog-api-key",
// 连接类(由 HTTP 客户端管理)
"host",
"content-length",
"transfer-encoding",
// 编码类(会被覆盖为 identity)
"accept-encoding",
// 代理转发类(保留 x-forwarded-for 和 x-real-ip
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"forwarded",
// CDN/云服务商特定头
"cf-connecting-ip",
"cf-ipcountry",
"cf-ray",
"cf-visitor",
"true-client-ip",
"fastly-client-ip",
"x-azure-clientip",
"x-azure-fdid",
"x-azure-ref",
"akamai-origin-hop",
"x-akamai-config-log-detail",
// 请求追踪类
"x-request-id",
"x-correlation-id",
"x-trace-id",
"x-amzn-trace-id",
"x-b3-traceid",
"x-b3-spanid",
"x-b3-parentspanid",
"x-b3-sampled",
"traceparent",
"tracestate",
// anthropic 特定头单独处理,避免重复
"anthropic-beta",
"anthropic-version",
// 客户端 IP 单独处理(默认透传)
"x-forwarded-for",
"x-real-ip",
];
pub struct ForwardResult {
pub response: Response,
pub response: ProxyResponse,
pub provider: Provider,
}
@@ -149,6 +97,7 @@ impl RequestForwarder {
endpoint: &str,
body: Value,
headers: axum::http::HeaderMap,
extensions: Extensions,
providers: Vec<Provider>,
) -> Result<ForwardResult, ForwardError> {
// 获取适配器
@@ -225,6 +174,7 @@ impl RequestForwarder {
endpoint,
&provider_body,
&headers,
&extensions,
adapter.as_ref(),
)
.await
@@ -353,6 +303,7 @@ impl RequestForwarder {
endpoint,
&provider_body,
&headers,
&extensions,
adapter.as_ref(),
)
.await
@@ -550,6 +501,7 @@ impl RequestForwarder {
endpoint,
&provider_body,
&headers,
&extensions,
adapter.as_ref(),
)
.await
@@ -787,8 +739,9 @@ impl RequestForwarder {
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
extensions: &Extensions,
adapter: &dyn ProviderAdapter,
) -> Result<Response, ProxyError> {
) -> Result<ProxyResponse, ProxyError> {
// 使用适配器提取 base_url
let base_url = adapter.extract_base_url(provider)?;
@@ -842,86 +795,8 @@ impl RequestForwarder {
// 默认使用空白名单,过滤所有 _ 前缀字段
let filtered_body = filter_private_params_with_whitelist(request_body, &[]);
// 获取 HTTP 客户端:优先使用供应商单独代理配置,否则使用全局客户端
let proxy_config = provider.meta.as_ref().and_then(|m| m.proxy_config.as_ref());
let client = super::http_client::get_for_provider(proxy_config);
let mut request = client.post(&url);
// 只有当 timeout > 0 时才设置请求超时
// Duration::ZERO 在 reqwest 中表示"立刻超时"而不是"禁用超时"
// 故障转移关闭时会传入 0,此时应该使用 client 的默认超时(600秒)
if !self.non_streaming_timeout.is_zero() {
request = request.timeout(self.non_streaming_timeout);
}
// 过滤黑名单 Headers,保护隐私并避免冲突
for (key, value) in headers {
let key_str = key.as_str();
if HEADER_BLACKLIST
.iter()
.any(|h| key_str.eq_ignore_ascii_case(h))
{
continue;
}
// Copilot 请求:过滤会由 add_auth_headers 注入的固定指纹头,
// 防止客户端原始头与注入头重复(reqwest header() 是追加语义)
if is_copilot
&& (key_str.eq_ignore_ascii_case("user-agent")
|| key_str.eq_ignore_ascii_case("editor-version")
|| key_str.eq_ignore_ascii_case("editor-plugin-version")
|| key_str.eq_ignore_ascii_case("copilot-integration-id")
|| key_str.eq_ignore_ascii_case("x-github-api-version")
|| key_str.eq_ignore_ascii_case("openai-intent"))
{
continue;
}
request = request.header(key, value);
}
// 处理 anthropic-beta Header(仅 Claude
// 关键:确保包含 claude-code-20250219 标记,这是上游服务验证请求来源的依据
// 如果客户端发送的 beta 标记中没有包含 claude-code-20250219,需要补充
if adapter.name() == "Claude" {
const CLAUDE_CODE_BETA: &str = "claude-code-20250219";
let beta_value = if let Some(beta) = headers.get("anthropic-beta") {
if let Ok(beta_str) = beta.to_str() {
// 检查是否已包含 claude-code-20250219
if beta_str.contains(CLAUDE_CODE_BETA) {
beta_str.to_string()
} else {
// 补充 claude-code-20250219
format!("{CLAUDE_CODE_BETA},{beta_str}")
}
} else {
CLAUDE_CODE_BETA.to_string()
}
} else {
// 如果客户端没有发送,使用默认值
CLAUDE_CODE_BETA.to_string()
};
request = request.header("anthropic-beta", &beta_value);
}
// 客户端 IP 透传(默认开启)
if let Some(xff) = headers.get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
request = request.header("x-forwarded-for", xff_str);
}
}
if let Some(real_ip) = headers.get("x-real-ip") {
if let Ok(real_ip_str) = real_ip.to_str() {
request = request.header("x-real-ip", real_ip_str);
}
}
// 流式请求保守禁用压缩,避免上游压缩 SSE 在连接中断时触发解压错误。
// 非流式请求不显式设置 Accept-Encoding,让 reqwest 自动协商压缩并透明解压。
if should_force_identity_encoding(effective_endpoint, &filtered_body, headers) {
request = request.header("accept-encoding", "identity");
}
// 使用适配器添加认证头
if let Some(mut auth) = adapter.extract_auth(provider) {
// 获取认证头(提前准备,用于内联替换)
let auth_headers = if let Some(mut auth) = adapter.extract_auth(provider) {
// GitHub Copilot 特殊处理:从 CopilotAuthManager 获取真实 token
if auth.strategy == AuthStrategy::GitHubCopilot {
if let Some(app_handle) = &self.app_handle {
@@ -972,17 +847,192 @@ impl RequestForwarder {
));
}
}
request = adapter.add_auth_headers(request, &auth);
adapter.get_auth_headers(&auth)
} else {
Vec::new()
};
// Copilot 指纹头名(由 get_auth_headers 注入,需在原始头中去重)
let copilot_fingerprint_headers: &[&str] = if is_copilot {
&[
"user-agent",
"editor-version",
"editor-plugin-version",
"copilot-integration-id",
"x-github-api-version",
"openai-intent",
]
} else {
&[]
};
// 预计算 anthropic-beta 值(仅 Claude
let anthropic_beta_value = if adapter.name() == "Claude" {
const CLAUDE_CODE_BETA: &str = "claude-code-20250219";
Some(if let Some(beta) = headers.get("anthropic-beta") {
if let Ok(beta_str) = beta.to_str() {
if beta_str.contains(CLAUDE_CODE_BETA) {
beta_str.to_string()
} else {
format!("{CLAUDE_CODE_BETA},{beta_str}")
}
} else {
CLAUDE_CODE_BETA.to_string()
}
} else {
CLAUDE_CODE_BETA.to_string()
})
} else {
None
};
// ============================================================
// 构建有序 HeaderMap — 内联替换,保持客户端原始顺序
// ============================================================
let mut ordered_headers = http::HeaderMap::new();
let mut saw_auth = false;
let mut saw_accept_encoding = false;
let mut saw_anthropic_beta = false;
let mut saw_anthropic_version = false;
for (key, value) in headers {
let key_str = key.as_str();
// --- 连接 / 追踪 / CDN 类 — 无条件跳过 ---
if matches!(
key_str,
"host"
| "content-length"
| "transfer-encoding"
| "x-forwarded-host"
| "x-forwarded-port"
| "x-forwarded-proto"
| "forwarded"
| "cf-connecting-ip"
| "cf-ipcountry"
| "cf-ray"
| "cf-visitor"
| "true-client-ip"
| "fastly-client-ip"
| "x-azure-clientip"
| "x-azure-fdid"
| "x-azure-ref"
| "akamai-origin-hop"
| "x-akamai-config-log-detail"
| "x-request-id"
| "x-correlation-id"
| "x-trace-id"
| "x-amzn-trace-id"
| "x-b3-traceid"
| "x-b3-spanid"
| "x-b3-parentspanid"
| "x-b3-sampled"
| "traceparent"
| "tracestate"
) {
continue;
}
// --- 认证类 — 用 adapter 提供的认证头替换(在原始位置) ---
if key_str.eq_ignore_ascii_case("authorization")
|| key_str.eq_ignore_ascii_case("x-api-key")
|| key_str.eq_ignore_ascii_case("x-goog-api-key")
{
if !saw_auth {
saw_auth = true;
for (ah_name, ah_value) in &auth_headers {
ordered_headers.append(ah_name.clone(), ah_value.clone());
}
}
continue;
}
// --- accept-encoding — 替换为与直连 HTTPS 一致的值 ---
if key_str.eq_ignore_ascii_case("accept-encoding") {
if !saw_accept_encoding {
saw_accept_encoding = true;
ordered_headers.append(
http::header::ACCEPT_ENCODING,
http::HeaderValue::from_static("br, gzip, deflate"),
);
}
continue;
}
// --- anthropic-beta — 用重建值替换(确保含 claude-code 标记) ---
if key_str.eq_ignore_ascii_case("anthropic-beta") {
if !saw_anthropic_beta {
saw_anthropic_beta = true;
if let Some(ref beta_val) = anthropic_beta_value {
if let Ok(hv) = http::HeaderValue::from_str(beta_val) {
ordered_headers.append("anthropic-beta", hv);
}
}
}
continue;
}
// --- anthropic-version — 透传客户端值 ---
if key_str.eq_ignore_ascii_case("anthropic-version") {
saw_anthropic_version = true;
ordered_headers.append(key.clone(), value.clone());
continue;
}
// --- Copilot 指纹头 — 跳过(由 auth_headers 提供) ---
if copilot_fingerprint_headers
.iter()
.any(|h| key_str.eq_ignore_ascii_case(h))
{
continue;
}
// --- 默认:透传 ---
ordered_headers.append(key.clone(), value.clone());
}
// anthropic-version 统一处理(仅 Claude):优先使用客户端的版本号,否则使用默认值
// 注意:只设置一次,避免重复
if adapter.name() == "Claude" {
let version_str = headers
.get("anthropic-version")
.and_then(|v| v.to_str().ok())
.unwrap_or("2023-06-01");
request = request.header("anthropic-version", version_str);
// 如果原始请求中没有认证头,在末尾追加
if !saw_auth && !auth_headers.is_empty() {
for (ah_name, ah_value) in &auth_headers {
ordered_headers.append(ah_name.clone(), ah_value.clone());
}
}
// 如果原始请求中没有 accept-encoding,追加
if !saw_accept_encoding {
ordered_headers.append(
http::header::ACCEPT_ENCODING,
http::HeaderValue::from_static("br, gzip, deflate"),
);
}
// 如果原始请求中没有 anthropic-beta 且有值需要添加,追加
if !saw_anthropic_beta {
if let Some(ref beta_val) = anthropic_beta_value {
if let Ok(hv) = http::HeaderValue::from_str(beta_val) {
ordered_headers.append("anthropic-beta", hv);
}
}
}
// anthropic-version:仅在缺失时补充默认值
if adapter.name() == "Claude" && !saw_anthropic_version {
ordered_headers.append(
"anthropic-version",
http::HeaderValue::from_static("2023-06-01"),
);
}
// 序列化请求体
let body_bytes = serde_json::to_vec(&filtered_body)
.map_err(|e| ProxyError::Internal(format!("Failed to serialize request body: {e}")))?;
// 确保 content-type 存在
if !ordered_headers.contains_key(http::header::CONTENT_TYPE) {
ordered_headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
}
// 输出请求信息日志
@@ -1000,16 +1050,55 @@ impl RequestForwarder {
);
}
// 确定超时
let timeout = if self.non_streaming_timeout.is_zero() {
std::time::Duration::from_secs(600) // 默认 600 秒
} else {
self.non_streaming_timeout
};
// 检查是否需要通过代理发送(供应商单独代理或全局代理)
let proxy_config = provider.meta.as_ref().and_then(|m| m.proxy_config.as_ref());
let use_reqwest_proxy = proxy_config.map(|c| c.enabled).unwrap_or(false)
|| super::http_client::get_current_proxy_url().is_some();
let uri: http::Uri = url
.parse()
.map_err(|e| ProxyError::ForwardFailed(format!("Invalid URL '{url}': {e}")))?;
// 发送请求
let response = request.json(&filtered_body).send().await.map_err(|e| {
if e.is_timeout() {
ProxyError::Timeout(format!("请求超时: {e}"))
} else if e.is_connect() {
ProxyError::ForwardFailed(format!("连接失败: {e}"))
} else {
ProxyError::ForwardFailed(e.to_string())
let response = if use_reqwest_proxy {
// 回退到 reqwest(支持 HTTP/SOCKS5 代理隧道)
let client = super::http_client::get_for_provider(proxy_config);
let mut request = client.post(&url);
if !self.non_streaming_timeout.is_zero() {
request = request.timeout(self.non_streaming_timeout);
}
})?;
for (key, value) in &ordered_headers {
request = request.header(key, value);
}
let reqwest_resp = request.body(body_bytes).send().await.map_err(|e| {
if e.is_timeout() {
ProxyError::Timeout(format!("请求超时: {e}"))
} else if e.is_connect() {
ProxyError::ForwardFailed(format!("连接失败: {e}"))
} else {
ProxyError::ForwardFailed(e.to_string())
}
})?;
ProxyResponse::Reqwest(reqwest_resp)
} else {
// 主路径:使用 hyper client(保持 header case + order
super::hyper_client::send_request(
uri,
http::Method::POST,
ordered_headers,
extensions.clone(),
body_bytes,
timeout,
)
.await?
};
// 检查响应状态
let status = response.status();
@@ -1018,7 +1107,7 @@ impl RequestForwarder {
Ok(response)
} else {
let status_code = status.as_u16();
let body_text = response.text().await.ok();
let body_text = String::from_utf8(response.bytes().await?.to_vec()).ok();
Err(ProxyError::UpstreamError {
status: status_code,
@@ -1173,30 +1262,6 @@ fn extract_json_error_message(body: &Value) -> Option<String> {
.find_map(|value| value.as_str().map(ToString::to_string))
}
fn should_force_identity_encoding(
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
) -> bool {
if body
.get("stream")
.and_then(|value| value.as_bool())
.unwrap_or(false)
{
return true;
}
if endpoint.contains("streamGenerateContent") || endpoint.contains("alt=sse") {
return true;
}
headers
.get(axum::http::header::ACCEPT)
.and_then(|value| value.to_str().ok())
.map(|accept| accept.contains("text/event-stream"))
.unwrap_or(false)
}
fn summarize_text_for_log(text: &str, max_chars: usize) -> String {
let normalized = text.split_whitespace().collect::<Vec<_>>().join(" ");
let trimmed = normalized.trim();
@@ -1213,7 +1278,6 @@ fn summarize_text_for_log(text: &str, max_chars: usize) -> String {
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{header::ACCEPT, HeaderMap, HeaderValue};
use serde_json::json;
#[test]
@@ -1280,49 +1344,4 @@ mod tests {
assert_eq!(summary, "line1 line2...");
}
#[test]
fn force_identity_for_stream_flag_requests() {
let headers = HeaderMap::new();
assert!(should_force_identity_encoding(
"/v1/responses",
&json!({ "stream": true }),
&headers
));
}
#[test]
fn force_identity_for_gemini_stream_endpoints() {
let headers = HeaderMap::new();
assert!(should_force_identity_encoding(
"/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse",
&json!({ "model": "gemini-2.5-pro" }),
&headers
));
}
#[test]
fn force_identity_for_sse_accept_header() {
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
assert!(should_force_identity_encoding(
"/v1/responses",
&json!({ "model": "gpt-5" }),
&headers
));
}
#[test]
fn non_streaming_requests_allow_automatic_compression() {
let headers = HeaderMap::new();
assert!(!should_force_identity_encoding(
"/v1/responses",
&json!({ "model": "gpt-5" }),
&headers
));
}
}
+68 -15
View File
@@ -27,6 +27,7 @@ use super::{
use crate::app_config::AppType;
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use bytes::Bytes;
use http_body_util::BodyExt;
use serde_json::{json, Value};
// ============================================================================
@@ -61,9 +62,19 @@ pub async fn get_status(State(state): State<ProxyState>) -> Result<Json<ProxySta
/// - 现在 OpenRouter 已推出 Claude Code 兼容接口,默认不再启用该转换(逻辑保留以备回退)
pub async fn handle_messages(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
request: axum::extract::Request,
) -> Result<axum::response::Response, ProxyError> {
let (parts, body) = request.into_parts();
let headers = parts.headers;
let extensions = parts.extensions;
let body_bytes = body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("Failed to read request body: {e}")))?
.to_bytes();
let body: Value = serde_json::from_slice(&body_bytes)
.map_err(|e| ProxyError::Internal(format!("Failed to parse request body: {e}")))?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Claude, "Claude", "claude").await?;
@@ -80,6 +91,7 @@ pub async fn handle_messages(
"/v1/messages",
body.clone(),
headers,
extensions,
ctx.get_providers(),
)
.await
@@ -114,7 +126,7 @@ pub async fn handle_messages(
///
/// 支持 OpenAI Chat Completions 和 Responses API 两种格式的转换
async fn handle_claude_transform(
response: reqwest::Response,
response: super::hyper_client::ProxyResponse,
ctx: &RequestContext,
state: &ProxyState,
_original_body: &Value,
@@ -201,10 +213,7 @@ async fn handle_claude_transform(
// 非流式响应转换 (OpenAI/Responses → Anthropic)
let response_headers = response.headers().clone();
let body_bytes = response.bytes().await.map_err(|e| {
log::error!("[Claude] 读取响应体失败: {e}");
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
})?;
let body_bytes = response.bytes().await?;
let body_str = String::from_utf8_lossy(&body_bytes);
@@ -287,9 +296,19 @@ async fn handle_claude_transform(
/// 处理 /v1/chat/completions 请求(OpenAI Chat Completions API - Codex CLI
pub async fn handle_chat_completions(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
request: axum::extract::Request,
) -> Result<axum::response::Response, ProxyError> {
let (parts, req_body) = request.into_parts();
let headers = parts.headers;
let extensions = parts.extensions;
let body_bytes = req_body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("Failed to read request body: {e}")))?
.to_bytes();
let body: Value = serde_json::from_slice(&body_bytes)
.map_err(|e| ProxyError::Internal(format!("Failed to parse request body: {e}")))?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
@@ -305,6 +324,7 @@ pub async fn handle_chat_completions(
"/chat/completions",
body,
headers,
extensions,
ctx.get_providers(),
)
.await
@@ -328,9 +348,19 @@ pub async fn handle_chat_completions(
/// 处理 /v1/responses 请求(OpenAI Responses API - Codex CLI 透传)
pub async fn handle_responses(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
request: axum::extract::Request,
) -> Result<axum::response::Response, ProxyError> {
let (parts, req_body) = request.into_parts();
let headers = parts.headers;
let extensions = parts.extensions;
let body_bytes = req_body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("Failed to read request body: {e}")))?
.to_bytes();
let body: Value = serde_json::from_slice(&body_bytes)
.map_err(|e| ProxyError::Internal(format!("Failed to parse request body: {e}")))?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
@@ -346,6 +376,7 @@ pub async fn handle_responses(
"/responses",
body,
headers,
extensions,
ctx.get_providers(),
)
.await
@@ -369,9 +400,19 @@ pub async fn handle_responses(
/// 处理 /v1/responses/compact 请求(OpenAI Responses Compact API - Codex CLI 透传)
pub async fn handle_responses_compact(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
request: axum::extract::Request,
) -> Result<axum::response::Response, ProxyError> {
let (parts, req_body) = request.into_parts();
let headers = parts.headers;
let extensions = parts.extensions;
let body_bytes = req_body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("Failed to read request body: {e}")))?
.to_bytes();
let body: Value = serde_json::from_slice(&body_bytes)
.map_err(|e| ProxyError::Internal(format!("Failed to parse request body: {e}")))?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
@@ -387,6 +428,7 @@ pub async fn handle_responses_compact(
"/responses/compact",
body,
headers,
extensions,
ctx.get_providers(),
)
.await
@@ -415,9 +457,19 @@ pub async fn handle_responses_compact(
pub async fn handle_gemini(
State(state): State<ProxyState>,
uri: axum::http::Uri,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
request: axum::extract::Request,
) -> Result<axum::response::Response, ProxyError> {
let (parts, req_body) = request.into_parts();
let headers = parts.headers;
let extensions = parts.extensions;
let body_bytes = req_body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("Failed to read request body: {e}")))?
.to_bytes();
let body: Value = serde_json::from_slice(&body_bytes)
.map_err(|e| ProxyError::Internal(format!("Failed to parse request body: {e}")))?;
// Gemini 的模型名称在 URI 中
let mut ctx = RequestContext::new(&state, &body, &headers, AppType::Gemini, "Gemini", "gemini")
.await?
@@ -441,6 +493,7 @@ pub async fn handle_gemini(
endpoint,
body,
headers,
extensions,
ctx.get_providers(),
)
.await
+9 -1
View File
@@ -219,7 +219,12 @@ fn build_client(proxy_url: Option<&str>) -> Result<Client, String> {
.timeout(Duration::from_secs(600))
.connect_timeout(Duration::from_secs(30))
.pool_max_idle_per_host(10)
.tcp_keepalive(Duration::from_secs(60));
.tcp_keepalive(Duration::from_secs(60))
// 禁用 reqwest 自动解压:防止 reqwest 覆盖客户端原始 accept-encoding header。
// 响应解压由 response_processor 根据 content-encoding 手动处理。
.no_gzip()
.no_brotli()
.no_deflate();
// 有代理地址则使用代理,否则跟随系统代理
if let Some(url) = proxy_url {
@@ -387,6 +392,9 @@ pub fn build_client_for_provider(proxy_config: Option<&ProviderProxyConfig>) ->
.connect_timeout(Duration::from_secs(30))
.pool_max_idle_per_host(10)
.tcp_keepalive(Duration::from_secs(60))
.no_gzip()
.no_brotli()
.no_deflate()
.proxy(proxy)
.build()
{
+178
View File
@@ -0,0 +1,178 @@
//! Hyper-based HTTP client for proxy forwarding
//!
//! Uses hyper directly (instead of reqwest) to support:
//! - `preserve_header_case(true)` — keeps original header name casing
//! - Header order preservation via `HeaderCaseMap` extension transfer
//!
//! Falls back to reqwest when an upstream proxy (HTTP/SOCKS5) is configured,
//! since hyper-util's legacy client doesn't natively support proxy tunneling.
use super::ProxyError;
use bytes::Bytes;
use futures::stream::Stream;
use http_body_util::BodyExt;
use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use std::sync::OnceLock;
type HyperClient = Client<
hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
http_body_util::Full<Bytes>,
>;
/// Lazily-initialized hyper client with header-case preservation enabled.
fn global_hyper_client() -> &'static HyperClient {
static CLIENT: OnceLock<HyperClient> = OnceLock::new();
CLIENT.get_or_init(|| {
let connector = HttpsConnectorBuilder::new()
.with_webpki_roots()
.https_or_http()
.enable_http1()
.build();
Client::builder(TokioExecutor::new())
.http1_preserve_header_case(true)
.build(connector)
})
}
/// Unified response wrapper that can hold either a hyper or reqwest response.
///
/// The hyper variant is used for the main (direct) path with header-case preservation.
/// The reqwest variant is the fallback when an upstream HTTP/SOCKS5 proxy is configured.
pub enum ProxyResponse {
Hyper(hyper::Response<hyper::body::Incoming>),
Reqwest(reqwest::Response),
}
impl ProxyResponse {
pub fn status(&self) -> http::StatusCode {
match self {
Self::Hyper(r) => r.status(),
Self::Reqwest(r) => r.status(),
}
}
pub fn headers(&self) -> &http::HeaderMap {
match self {
Self::Hyper(r) => r.headers(),
Self::Reqwest(r) => r.headers(),
}
}
/// Shortcut: extract `content-type` header value as `&str`.
pub fn content_type(&self) -> Option<&str> {
self.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
}
/// Check if the response is an SSE stream.
pub fn is_sse(&self) -> bool {
self.content_type()
.map(|ct| ct.contains("text/event-stream"))
.unwrap_or(false)
}
/// Consume the response and collect the full body into `Bytes`.
pub async fn bytes(self) -> Result<Bytes, ProxyError> {
match self {
Self::Hyper(r) => {
let collected = r.into_body().collect().await.map_err(|e| {
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
})?;
Ok(collected.to_bytes())
}
Self::Reqwest(r) => r.bytes().await.map_err(|e| {
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
}),
}
}
/// Consume the response and return a byte-chunk stream (for SSE pass-through).
pub fn bytes_stream(self) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
use futures::StreamExt;
match self {
Self::Hyper(r) => {
let body = r.into_body();
let stream = futures::stream::unfold(body, |mut body| async {
match body.frame().await {
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
if data.is_empty() {
Some((Ok(Bytes::new()), body))
} else {
Some((Ok(data), body))
}
} else {
Some((Ok(Bytes::new()), body))
}
}
Some(Err(e)) => Some((Err(std::io::Error::other(e.to_string())), body)),
None => None,
}
})
.filter(|result| {
futures::future::ready(!matches!(result, Ok(ref b) if b.is_empty()))
});
Box::pin(stream)
as std::pin::Pin<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send>>
}
Self::Reqwest(r) => {
let stream = r
.bytes_stream()
.map(|r| r.map_err(|e| std::io::Error::other(e.to_string())));
Box::pin(stream)
}
}
}
}
/// Send an HTTP request via the global hyper client (with header-case preservation).
///
/// `original_extensions` should carry the `HeaderCaseMap` populated by the
/// server-side hyper parser (via `preserve_header_case(true)`).
/// The hyper client will read it back and serialise headers with the original casing.
pub async fn send_request(
uri: http::Uri,
method: http::Method,
headers: http::HeaderMap,
original_extensions: http::Extensions,
body: Vec<u8>,
timeout: std::time::Duration,
) -> Result<ProxyResponse, ProxyError> {
let mut req = http::Request::builder()
.method(method)
.uri(&uri)
.body(http_body_util::Full::new(Bytes::from(body)))
.map_err(|e| ProxyError::ForwardFailed(format!("Failed to build request: {e}")))?;
// Set headers (order is preserved by http::HeaderMap insertion order)
*req.headers_mut() = headers;
// Transfer extensions from the incoming request — this carries the internal
// `HeaderCaseMap` that tells the hyper client how to case each header name.
// Debug: check extension count before transfer
log::debug!(
"[HyperClient] Transferring extensions to outgoing request (uri={})",
uri
);
*req.extensions_mut() = original_extensions;
let client = global_hyper_client();
let resp = tokio::time::timeout(timeout, client.request(req))
.await
.map_err(|_| ProxyError::Timeout(format!("请求超时: {}s", timeout.as_secs())))?
.map_err(|e| {
let msg = e.to_string();
if msg.contains("connect") {
ProxyError::ForwardFailed(format!("连接失败: {e}"))
} else {
ProxyError::ForwardFailed(e.to_string())
}
})?;
Ok(ProxyResponse::Hyper(resp))
}
+2
View File
@@ -26,6 +26,8 @@ pub mod srv {
pub const STOPPED: &str = "SRV-002";
pub const STOP_TIMEOUT: &str = "SRV-003";
pub const TASK_ERROR: &str = "SRV-004";
pub const ACCEPT_ERR: &str = "SRV-005";
pub const CONN_ERR: &str = "SRV-006";
}
/// 转发器日志码
+1
View File
@@ -14,6 +14,7 @@ pub mod handler_context;
mod handlers;
mod health;
pub mod http_client;
pub mod hyper_client;
pub mod log_codes;
pub mod model_mapper;
pub mod provider_router;
+4 -85
View File
@@ -5,7 +5,6 @@
use super::auth::AuthInfo;
use crate::provider::Provider;
use crate::proxy::error::ProxyError;
use reqwest::RequestBuilder;
use serde_json::Value;
/// 供应商适配器 Trait
@@ -14,116 +13,36 @@ use serde_json::Value;
/// - URL 构建
/// - 认证信息提取和头部注入
/// - 请求/响应格式转换(可选)
///
/// # 示例
///
/// ```ignore
/// pub struct ClaudeAdapter;
///
/// impl ProviderAdapter for ClaudeAdapter {
/// fn name(&self) -> &'static str { "Claude" }
///
/// fn extract_base_url(&self, provider: &Provider) -> Result<String, ProxyError> {
/// // 从 provider 配置中提取 base_url
/// }
///
/// fn extract_auth(&self, provider: &Provider) -> Option<AuthInfo> {
/// // 从 provider 配置中提取认证信息
/// }
///
/// fn build_url(&self, base_url: &str, endpoint: &str) -> String {
/// format!("{}{}", base_url.trim_end_matches('/'), endpoint)
/// }
///
/// fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder {
/// // 添加认证头
/// }
/// }
/// ```
pub trait ProviderAdapter: Send + Sync {
/// 适配器名称(用于日志和调试)
fn name(&self) -> &'static str;
/// 从 Provider 配置中提取 base_url
///
/// # Arguments
/// * `provider` - Provider 配置
///
/// # Returns
/// * `Ok(String)` - 提取到的 base_url(已去除尾部斜杠)
/// * `Err(ProxyError)` - 提取失败
fn extract_base_url(&self, provider: &Provider) -> Result<String, ProxyError>;
/// 从 Provider 配置中提取认证信息
///
/// # Arguments
/// * `provider` - Provider 配置
///
/// # Returns
/// * `Some(AuthInfo)` - 提取到的认证信息
/// * `None` - 未找到认证信息
fn extract_auth(&self, provider: &Provider) -> Option<AuthInfo>;
/// 构建请求 URL
///
/// # Arguments
/// * `base_url` - 基础 URL
/// * `endpoint` - 请求端点(如 `/v1/messages`
///
/// # Returns
/// 完整的请求 URL
fn build_url(&self, base_url: &str, endpoint: &str) -> String;
/// 添加认证头到请求
/// Return auth headers as `(name, value)` pairs.
///
/// # Arguments
/// * `request` - reqwest RequestBuilder
/// * `auth` - 认证信息
///
/// # Returns
/// 添加了认证头的 RequestBuilder
fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder;
/// The forwarder inserts these at the position of the original auth header
/// so that header order is preserved.
fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)>;
/// 是否需要格式转换
///
/// 默认返回 `false`(透传模式)。
/// 仅当供应商需要格式转换时(如 Claude + OpenRouter 旧 OpenAI 兼容接口)才返回 `true`。
///
/// # Arguments
/// * `provider` - Provider 配置
fn needs_transform(&self, _provider: &Provider) -> bool {
false
}
/// 转换请求体
///
/// 将请求体从一种格式转换为另一种格式(如 Anthropic → OpenAI)。
/// 默认实现直接返回原始请求体(透传)。
///
/// # Arguments
/// * `body` - 原始请求体
/// * `provider` - Provider 配置(用于获取模型映射等)
///
/// # Returns
/// * `Ok(Value)` - 转换后的请求体
/// * `Err(ProxyError)` - 转换失败
fn transform_request(&self, body: Value, _provider: &Provider) -> Result<Value, ProxyError> {
Ok(body)
}
/// 转换响应体
///
/// 将响应体从一种格式转换为另一种格式(如 OpenAI → Anthropic)。
/// 默认实现直接返回原始响应体(透传)。
///
/// # Arguments
/// * `body` - 原始响应体
///
/// # Returns
/// * `Ok(Value)` - 转换后的响应体
/// * `Err(ProxyError)` - 转换失败
///
/// Note: 响应转换将在 handler 层集成,目前预留接口
#[allow(dead_code)]
fn transform_response(&self, body: Value) -> Result<Value, ProxyError> {
Ok(body)
+40 -23
View File
@@ -16,7 +16,6 @@
use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType};
use crate::provider::Provider;
use crate::proxy::error::ProxyError;
use reqwest::RequestBuilder;
/// 获取 Claude 供应商的 API 格式
///
@@ -330,32 +329,50 @@ impl ProviderAdapter for ClaudeAdapter {
}
}
fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder {
fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> {
use http::{HeaderName, HeaderValue};
// 注意:anthropic-version 由 forwarder.rs 统一处理(透传客户端值或设置默认值)
// 这里不再设置 anthropic-version,避免 header 重复
let bearer = format!("Bearer {}", auth.api_key);
match auth.strategy {
// Anthropic 官方: Authorization Bearer + x-api-key
AuthStrategy::Anthropic => request
.header("Authorization", format!("Bearer {}", auth.api_key))
.header("x-api-key", &auth.api_key),
// ClaudeAuth 中转服务: 仅 Bearer,无 x-api-key
AuthStrategy::ClaudeAuth => {
request.header("Authorization", format!("Bearer {}", auth.api_key))
AuthStrategy::Anthropic | AuthStrategy::ClaudeAuth | AuthStrategy::Bearer => {
vec![(
HeaderName::from_static("authorization"),
HeaderValue::from_str(&bearer).unwrap(),
)]
}
// OpenRouter: Bearer
AuthStrategy::Bearer => {
request.header("Authorization", format!("Bearer {}", auth.api_key))
AuthStrategy::GitHubCopilot => {
vec![
(
HeaderName::from_static("authorization"),
HeaderValue::from_str(&bearer).unwrap(),
),
(
HeaderName::from_static("editor-version"),
HeaderValue::from_static(super::copilot_auth::COPILOT_EDITOR_VERSION),
),
(
HeaderName::from_static("editor-plugin-version"),
HeaderValue::from_static(super::copilot_auth::COPILOT_PLUGIN_VERSION),
),
(
HeaderName::from_static("copilot-integration-id"),
HeaderValue::from_static(super::copilot_auth::COPILOT_INTEGRATION_ID),
),
(
HeaderName::from_static("user-agent"),
HeaderValue::from_static(super::copilot_auth::COPILOT_USER_AGENT),
),
(
HeaderName::from_static("x-github-api-version"),
HeaderValue::from_static(super::copilot_auth::COPILOT_API_VERSION),
),
(
HeaderName::from_static("openai-intent"),
HeaderValue::from_static("conversation-panel"),
),
]
}
// GitHub Copilot: Bearer + 统一指纹头
AuthStrategy::GitHubCopilot => request
.header("Authorization", format!("Bearer {}", auth.api_key))
.header("editor-version", super::copilot_auth::COPILOT_EDITOR_VERSION)
.header("editor-plugin-version", super::copilot_auth::COPILOT_PLUGIN_VERSION)
.header("copilot-integration-id", super::copilot_auth::COPILOT_INTEGRATION_ID)
.header("user-agent", super::copilot_auth::COPILOT_USER_AGENT)
.header("x-github-api-version", super::copilot_auth::COPILOT_API_VERSION)
.header("openai-intent", "conversation-panel"),
_ => request,
_ => vec![],
}
}
+6 -3
View File
@@ -9,7 +9,6 @@ use super::{AuthInfo, AuthStrategy, ProviderAdapter};
use crate::provider::Provider;
use crate::proxy::error::ProxyError;
use regex::Regex;
use reqwest::RequestBuilder;
use std::sync::LazyLock;
/// 官方 Codex 客户端 User-Agent 正则
@@ -174,8 +173,12 @@ impl ProviderAdapter for CodexAdapter {
url
}
fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder {
request.header("Authorization", format!("Bearer {}", auth.api_key))
fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> {
let bearer = format!("Bearer {}", auth.api_key);
vec![(
http::HeaderName::from_static("authorization"),
http::HeaderValue::from_str(&bearer).unwrap(),
)]
}
}
+16 -8
View File
@@ -9,7 +9,6 @@
use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType};
use crate::provider::Provider;
use crate::proxy::error::ProxyError;
use reqwest::RequestBuilder;
/// Gemini 适配器
pub struct GeminiAdapter;
@@ -217,17 +216,26 @@ impl ProviderAdapter for GeminiAdapter {
url
}
fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder {
fn get_auth_headers(&self, auth: &AuthInfo) -> Vec<(http::HeaderName, http::HeaderValue)> {
use http::{HeaderName, HeaderValue};
match auth.strategy {
// OAuth Bearer 认证
AuthStrategy::GoogleOAuth => {
let token = auth.access_token.as_ref().unwrap_or(&auth.api_key);
request
.header("Authorization", format!("Bearer {token}"))
.header("x-goog-api-client", "GeminiCLI/1.0")
vec![
(
HeaderName::from_static("authorization"),
HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
),
(
HeaderName::from_static("x-goog-api-client"),
HeaderValue::from_static("GeminiCLI/1.0"),
),
]
}
// API Key 认证
_ => request.header("x-goog-api-key", &auth.api_key),
_ => vec![(
HeaderName::from_static("x-goog-api-key"),
HeaderValue::from_str(&auth.api_key).unwrap(),
)],
}
}
}
+8 -4
View File
@@ -88,8 +88,8 @@ struct ToolBlockState {
}
/// 创建 Anthropic SSE 流
pub fn create_anthropic_sse_stream(
stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
pub fn create_anthropic_sse_stream<E: std::error::Error + Send + 'static>(
stream: impl Stream<Item = Result<Bytes, E>> + Send + 'static,
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
async_stream::stream! {
let mut buffer = String::new();
@@ -598,7 +598,9 @@ mod tests {
"data: [DONE]\n\n"
);
let upstream = stream::iter(vec![Ok(Bytes::from(input.as_bytes().to_vec()))]);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
@@ -686,7 +688,9 @@ mod tests {
"data: [DONE]\n\n"
);
let upstream = stream::iter(vec![Ok(Bytes::from(input.as_bytes().to_vec()))]);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
@@ -96,8 +96,8 @@ fn resolve_content_index(
///
/// 状态机跟踪: message_id, current_model, has_sent_message_start, item/content index map
/// SSE 解析支持 named events (event: + data: 行)
pub fn create_anthropic_sse_stream_from_responses(
stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
pub fn create_anthropic_sse_stream_from_responses<E: std::error::Error + Send + 'static>(
stream: impl Stream<Item = Result<Bytes, E>> + Send + 'static,
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
async_stream::stream! {
let mut buffer = String::new();
@@ -758,7 +758,9 @@ mod tests {
"data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"usage\":{\"input_tokens\":12,\"output_tokens\":3}}}\n\n"
);
let upstream = stream::iter(vec![Ok(Bytes::from(input.as_bytes().to_vec()))]);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream_from_responses(upstream);
let chunks: Vec<_> = converted.collect().await;
@@ -800,7 +802,9 @@ mod tests {
"data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"usage\":{\"input_tokens\":8,\"output_tokens\":4}}}\n\n"
);
let upstream = stream::iter(vec![Ok(Bytes::from(input.as_bytes().to_vec()))]);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream_from_responses(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
@@ -871,7 +875,9 @@ mod tests {
"data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"usage\":{\"input_tokens\":5,\"output_tokens\":10}}}\n\n"
);
let upstream = stream::iter(vec![Ok(Bytes::from(input.as_bytes().to_vec()))]);
let upstream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(
input.as_bytes().to_vec(),
))]);
let converted = create_anthropic_sse_stream_from_responses(upstream);
let chunks: Vec<_> = converted.collect().await;
let merged = chunks
+1 -1
View File
@@ -50,7 +50,7 @@ pub fn resolve_reasoning_effort(body: &Value) -> Option<&'static str> {
"medium" => Some("medium"),
"high" => Some("high"),
"max" => Some("xhigh"), // OpenAI xhigh = maximum reasoning effort
_ => None, // unknown value — do not inject
_ => None, // unknown value — do not inject
};
}
+76 -19
View File
@@ -5,17 +5,19 @@
use super::{
handler_config::UsageParserConfig,
handler_context::{RequestContext, StreamingTimeoutConfig},
hyper_client::ProxyResponse,
server::ProxyState,
sse::strip_sse_field,
usage::parser::TokenUsage,
ProxyError,
};
use axum::http::header::HeaderMap;
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use reqwest::header::HeaderMap;
use serde_json::Value;
use std::{
io::Read,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
@@ -24,24 +26,61 @@ use std::{
};
use tokio::sync::Mutex;
// ============================================================================
// 响应解压
// ============================================================================
/// 根据 content-encoding 解压响应体字节
///
/// reqwest 自动解压已禁用(为了透传 accept-encoding),需要手动解压。
fn decompress_body(content_encoding: &str, body: &[u8]) -> Result<Vec<u8>, std::io::Error> {
match content_encoding {
"gzip" | "x-gzip" => {
let mut decoder = flate2::read::GzDecoder::new(body);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed)?;
Ok(decompressed)
}
"deflate" => {
let mut decoder = flate2::read::DeflateDecoder::new(body);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed)?;
Ok(decompressed)
}
"br" => {
let mut decompressed = Vec::new();
brotli::BrotliDecompress(&mut std::io::Cursor::new(body), &mut decompressed)?;
Ok(decompressed)
}
_ => {
log::warn!("未知的 content-encoding: {content_encoding},跳过解压");
Ok(body.to_vec())
}
}
}
/// 从响应头提取 content-encoding(忽略 identity 和 chunked
fn get_content_encoding(headers: &HeaderMap) -> Option<String> {
headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty() && s != "identity")
}
// ============================================================================
// 公共接口
// ============================================================================
/// 检测响应是否为 SSE 流式响应
#[inline]
pub fn is_sse_response(response: &reqwest::Response) -> bool {
response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.contains("text/event-stream"))
.unwrap_or(false)
pub fn is_sse_response(response: &ProxyResponse) -> bool {
response.is_sse()
}
/// 处理流式响应
pub async fn handle_streaming(
response: reqwest::Response,
response: ProxyResponse,
ctx: &RequestContext,
state: &ProxyState,
parser_config: &UsageParserConfig,
@@ -53,6 +92,15 @@ pub async fn handle_streaming(
status.as_u16(),
format_headers(response.headers())
);
// 检查流式响应是否被压缩(SSE 通常不压缩,如果压缩则 SSE 解析会失败)
if let Some(encoding) = get_content_encoding(response.headers()) {
log::warn!(
"[{}] 流式响应含 content-encoding={encoding}SSE 解析可能失败。\
上游在 accept-encoding 透传后压缩了 SSE 流。",
ctx.tag
);
}
let mut builder = axum::response::Response::builder().status(status);
// 复制响应头
@@ -61,9 +109,7 @@ pub async fn handle_streaming(
}
// 创建字节流
let stream = response
.bytes_stream()
.map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string())));
let stream = response.bytes_stream();
// 创建使用量收集器
let usage_collector = create_usage_collector(ctx, state, status.as_u16(), parser_config);
@@ -87,7 +133,7 @@ pub async fn handle_streaming(
/// 处理非流式响应
pub async fn handle_non_streaming(
response: reqwest::Response,
response: ProxyResponse,
ctx: &RequestContext,
state: &ProxyState,
parser_config: &UsageParserConfig,
@@ -96,18 +142,29 @@ pub async fn handle_non_streaming(
let status = response.status();
// 读取响应体
let body_bytes = response.bytes().await.map_err(|e| {
log::error!("[{}] 读取响应失败: {e}", ctx.tag);
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
})?;
let raw_bytes = response.bytes().await?;
log::debug!(
"[{}] 已接收上游响应体: status={}, bytes={}, headers={}",
ctx.tag,
status.as_u16(),
body_bytes.len(),
raw_bytes.len(),
format_headers(&response_headers)
);
// 手动解压(reqwest 自动解压已禁用以透传 accept-encoding
let body_bytes: Bytes = if let Some(encoding) = get_content_encoding(&response_headers) {
log::debug!("[{}] 解压非流式响应: content-encoding={encoding}", ctx.tag);
match decompress_body(&encoding, &raw_bytes) {
Ok(decompressed) => Bytes::from(decompressed),
Err(e) => {
log::warn!("[{}] 解压失败 ({encoding}): {e},使用原始数据", ctx.tag);
raw_bytes
}
}
} else {
raw_bytes
};
log::debug!(
"[{}] 上游响应体内容: {}",
ctx.tag,
@@ -190,7 +247,7 @@ pub async fn handle_non_streaming(
///
/// 根据响应类型自动选择流式或非流式处理
pub async fn process_response(
response: reqwest::Response,
response: ProxyResponse,
ctx: &RequestContext,
state: &ProxyState,
parser_config: &UsageParserConfig,
+50 -7
View File
@@ -1,6 +1,12 @@
//! HTTP代理服务器
//!
//! 基于Axum的HTTP服务器,处理代理请求
//!
//! Uses a manual hyper HTTP/1.1 accept loop with `preserve_header_case(true)` so
//! that the original header-name casing from the CLI client is captured in a
//! `HeaderCaseMap` extension. This map is later forwarded to the upstream via
//! the hyper-based HTTP client, producing wire-level header casing identical to
//! a direct (non-proxied) CLI request.
use super::{
failover_switch::FailoverSwitchManager, handlers, log_codes::srv as log_srv,
@@ -12,6 +18,7 @@ use axum::{
routing::{get, post},
Router,
};
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{oneshot, RwLock};
@@ -114,15 +121,51 @@ impl ProxyServer {
// 记录启动时间
*self.state.start_time.write().await = Some(std::time::Instant::now());
// 启动服务器
// 启动服务器 — 使用手动 hyper HTTP/1.1 accept loop
// 开启 preserve_header_case 以捕获客户端请求头的原始大小写
let state = self.state.clone();
let handle = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async {
shutdown_rx.await.ok();
})
.await
.ok();
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
let (stream, _remote_addr) = match result {
Ok(v) => v,
Err(e) => {
log::error!("[{SRV}] accept 失败: {e}", SRV = log_srv::ACCEPT_ERR);
continue;
}
};
let app = app.clone();
tokio::spawn(async move {
// service_fn 将 axum Routertower::Service)桥接到 hyper
let service = hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let mut router = app.clone();
async move {
// 将 hyper::body::Incoming 转为 axum::body::Body,保留 extensions
let (parts, body) = req.into_parts();
let body = axum::body::Body::new(body);
let axum_req = http::Request::from_parts(parts, body);
<Router as tower::Service<http::Request<axum::body::Body>>>::call(&mut router, axum_req).await
}
});
if let Err(e) = hyper::server::conn::http1::Builder::new()
.preserve_header_case(true)
.serve_connection(TokioIo::new(stream), service)
.await
{
// Connection reset / broken pipe 等在代理场景下很常见,debug 级别
log::debug!("[{SRV}] connection error: {e}", SRV = log_srv::CONN_ERR);
}
});
}
_ = &mut shutdown_rx => {
break;
}
}
}
// 服务器停止后更新状态
state.status.write().await.running = false;
+9 -3
View File
@@ -12,8 +12,8 @@ use std::time::Instant;
use crate::app_config::AppType;
use crate::error::AppError;
use crate::provider::Provider;
use crate::proxy::providers::transform::anthropic_to_openai;
use crate::proxy::providers::copilot_auth;
use crate::proxy::providers::transform::anthropic_to_openai;
use crate::proxy::providers::{get_adapter, AuthInfo, AuthStrategy};
/// 健康状态枚举
@@ -365,8 +365,14 @@ impl StreamCheckService {
.header("accept-encoding", "identity")
.header("user-agent", copilot_auth::COPILOT_USER_AGENT)
.header("editor-version", copilot_auth::COPILOT_EDITOR_VERSION)
.header("editor-plugin-version", copilot_auth::COPILOT_PLUGIN_VERSION)
.header("copilot-integration-id", copilot_auth::COPILOT_INTEGRATION_ID)
.header(
"editor-plugin-version",
copilot_auth::COPILOT_PLUGIN_VERSION,
)
.header(
"copilot-integration-id",
copilot_auth::COPILOT_INTEGRATION_ID,
)
.header("x-github-api-version", copilot_auth::COPILOT_API_VERSION)
.header("openai-intent", "conversation-panel");
} else if is_openai_chat {