Files
cc-switch/src-tauri/src/proxy/handlers.rs
T
YoVinchen 766ca23eca feat(proxy): extract session ID from client requests for logging
- Add SessionIdExtractor to parse session ID from Claude/Codex requests
- Support extraction from metadata.user_id, headers, previous_response_id
- Pass session_id through RequestContext to usage logger
- Enable request correlation by session in proxy_request_logs
2025-12-31 21:31:20 +08:00

538 lines
17 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 请求处理器
//!
//! 处理各种API端点的HTTP请求
//!
//! 重构后的结构:
//! - 通用逻辑提取到 `handler_context` 和 `response_processor` 模块
//! - 各 handler 只保留独特的业务逻辑
//! - Claude 的格式转换逻辑保留在此文件(用于 OpenRouter 旧接口回退)
use super::{
error_mapper::{get_error_message, map_proxy_error_to_status},
handler_config::{
CLAUDE_PARSER_CONFIG, CODEX_PARSER_CONFIG, GEMINI_PARSER_CONFIG, OPENAI_PARSER_CONFIG,
},
handler_context::RequestContext,
providers::{get_adapter, streaming::create_anthropic_sse_stream, transform},
response_processor::{create_logged_passthrough_stream, process_response, SseUsageCollector},
server::ProxyState,
types::*,
usage::parser::TokenUsage,
ProxyError,
};
use crate::app_config::AppType;
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use rust_decimal::Decimal;
use serde_json::{json, Value};
use std::str::FromStr;
// ============================================================================
// 健康检查和状态查询(简单端点)
// ============================================================================
/// 健康检查
pub async fn health_check() -> (StatusCode, Json<Value>) {
(
StatusCode::OK,
Json(json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
})),
)
}
/// 获取服务状态
pub async fn get_status(State(state): State<ProxyState>) -> Result<Json<ProxyStatus>, ProxyError> {
let status = state.status.read().await.clone();
Ok(Json(status))
}
// ============================================================================
// Claude API 处理器(包含格式转换逻辑)
// ============================================================================
/// 处理 /v1/messages 请求(Claude API
///
/// Claude 处理器包含独特的格式转换逻辑:
/// - 过去用于 OpenRouter 的 OpenAI Chat Completions 兼容接口(Anthropic ↔ OpenAI 转换)
/// - 现在 OpenRouter 已推出 Claude Code 兼容接口,默认不再启用该转换(逻辑保留以备回退)
pub async fn handle_messages(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Claude, "Claude", "claude").await?;
let is_stream = body
.get("stream")
.and_then(|s| s.as_bool())
.unwrap_or(false);
// 转发请求
let forwarder = ctx.create_forwarder(&state);
let result = match forwarder
.forward_with_retry(
&AppType::Claude,
"/v1/messages",
body.clone(),
headers,
ctx.get_providers(),
)
.await
{
Ok(result) => result,
Err(mut err) => {
if let Some(provider) = err.provider.take() {
ctx.provider = provider;
}
log_forward_error(&state, &ctx, is_stream, &err.error);
return Err(err.error);
}
};
ctx.provider = result.provider;
let response = result.response;
// 检查是否需要格式转换(OpenRouter 等中转服务)
let adapter = get_adapter(&AppType::Claude);
let needs_transform = adapter.needs_transform(&ctx.provider);
log::info!(
"[Claude] Provider: {}, needs_transform: {}, is_stream: {}",
ctx.provider.name,
needs_transform,
is_stream
);
let status = response.status();
log::info!("[Claude] 上游响应状态: {status}");
// Claude 特有:格式转换处理
if needs_transform {
return handle_claude_transform(response, &ctx, &state, &body, is_stream).await;
}
// 通用响应处理(透传模式)
process_response(response, &ctx, &state, &CLAUDE_PARSER_CONFIG).await
}
/// Claude 格式转换处理(独有逻辑)
///
/// 处理 OpenRouter 旧 OpenAI 兼容接口的回退方案(当前默认不启用)
async fn handle_claude_transform(
response: reqwest::Response,
ctx: &RequestContext,
state: &ProxyState,
_original_body: &Value,
is_stream: bool,
) -> Result<axum::response::Response, ProxyError> {
let status = response.status();
if is_stream {
// 流式响应转换 (OpenAI SSE → Anthropic SSE)
log::info!("[Claude] 开始流式响应转换 (OpenAI SSE → Anthropic SSE)");
let stream = response.bytes_stream();
let sse_stream = create_anthropic_sse_stream(stream);
// 创建使用量收集器
let usage_collector = {
let state = state.clone();
let provider_id = ctx.provider.id.clone();
let model = ctx.request_model.clone();
let status_code = status.as_u16();
let start_time = ctx.start_time;
SseUsageCollector::new(start_time, move |events, first_token_ms| {
if let Some(usage) = TokenUsage::from_claude_stream_events(&events) {
let latency_ms = start_time.elapsed().as_millis() as u64;
let state = state.clone();
let provider_id = provider_id.clone();
let model = model.clone();
tokio::spawn(async move {
log_usage(
&state,
&provider_id,
"claude",
&model,
usage,
latency_ms,
first_token_ms,
true,
status_code,
)
.await;
});
} else {
log::debug!("[Claude] OpenRouter 流式响应缺少 usage 统计,跳过消费记录");
}
})
};
// 获取流式超时配置
let timeout_config = ctx.streaming_timeout_config();
let logged_stream = create_logged_passthrough_stream(
sse_stream,
"Claude/OpenRouter",
Some(usage_collector),
timeout_config,
);
let mut headers = axum::http::HeaderMap::new();
headers.insert(
"Content-Type",
axum::http::HeaderValue::from_static("text/event-stream"),
);
headers.insert(
"Cache-Control",
axum::http::HeaderValue::from_static("no-cache"),
);
headers.insert(
"Connection",
axum::http::HeaderValue::from_static("keep-alive"),
);
let body = axum::body::Body::from_stream(logged_stream);
log::info!("[Claude] ====== 请求结束 (流式转换) ======");
return Ok((headers, body).into_response());
}
// 非流式响应转换 (OpenAI → Anthropic)
log::info!("[Claude] 开始转换响应 (OpenAI → Anthropic)");
let response_headers = response.headers().clone();
let body_bytes = response.bytes().await.map_err(|e| {
log::error!("[Claude] 读取响应体失败: {e}");
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
})?;
let body_str = String::from_utf8_lossy(&body_bytes);
log::info!("[Claude] OpenAI 响应长度: {} bytes", body_bytes.len());
log::debug!("[Claude] OpenAI 原始响应: {body_str}");
let openai_response: Value = serde_json::from_slice(&body_bytes).map_err(|e| {
log::error!("[Claude] 解析 OpenAI 响应失败: {e}, body: {body_str}");
ProxyError::TransformError(format!("Failed to parse OpenAI response: {e}"))
})?;
log::info!("[Claude] 解析 OpenAI 响应成功");
log::info!(
"[Claude] <<< OpenAI 响应 JSON:\n{}",
serde_json::to_string_pretty(&openai_response).unwrap_or_default()
);
let anthropic_response = transform::openai_to_anthropic(openai_response).map_err(|e| {
log::error!("[Claude] 转换响应失败: {e}");
e
})?;
log::info!("[Claude] 转换响应成功");
log::info!(
"[Claude] <<< Anthropic 响应 JSON:\n{}",
serde_json::to_string_pretty(&anthropic_response).unwrap_or_default()
);
// 记录使用量
if let Some(usage) = TokenUsage::from_claude_response(&anthropic_response) {
let model = anthropic_response
.get("model")
.and_then(|m| m.as_str())
.unwrap_or("unknown");
let latency_ms = ctx.latency_ms();
tokio::spawn({
let state = state.clone();
let provider_id = ctx.provider.id.clone();
let model = model.to_string();
async move {
log_usage(
&state,
&provider_id,
"claude",
&model,
usage,
latency_ms,
None,
false,
status.as_u16(),
)
.await;
}
});
}
log::info!("[Claude] ====== 请求结束 ======");
// 构建响应
let mut builder = axum::response::Response::builder().status(status);
for (key, value) in response_headers.iter() {
if key.as_str().to_lowercase() != "content-length"
&& key.as_str().to_lowercase() != "transfer-encoding"
{
builder = builder.header(key, value);
}
}
builder = builder.header("content-type", "application/json");
let response_body = serde_json::to_vec(&anthropic_response).map_err(|e| {
log::error!("[Claude] 序列化响应失败: {e}");
ProxyError::TransformError(format!("Failed to serialize response: {e}"))
})?;
log::info!(
"[Claude] 返回转换后的响应, 长度: {} bytes",
response_body.len()
);
let body = axum::body::Body::from(response_body);
Ok(builder.body(body).unwrap())
}
// ============================================================================
// Codex API 处理器
// ============================================================================
/// 处理 /v1/chat/completions 请求(OpenAI Chat Completions API - Codex CLI
pub async fn handle_chat_completions(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
log::info!("[Codex] ====== /v1/chat/completions 请求开始 ======");
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
let is_stream = body
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
log::info!(
"[Codex] 请求模型: {}, 流式: {}",
ctx.request_model,
is_stream
);
let forwarder = ctx.create_forwarder(&state);
let result = match forwarder
.forward_with_retry(
&AppType::Codex,
"/v1/chat/completions",
body,
headers,
ctx.get_providers(),
)
.await
{
Ok(result) => result,
Err(mut err) => {
if let Some(provider) = err.provider.take() {
ctx.provider = provider;
}
log_forward_error(&state, &ctx, is_stream, &err.error);
return Err(err.error);
}
};
ctx.provider = result.provider;
let response = result.response;
log::info!("[Codex] 上游响应状态: {}", response.status());
process_response(response, &ctx, &state, &OPENAI_PARSER_CONFIG).await
}
/// 处理 /v1/responses 请求(OpenAI Responses API - Codex CLI 透传)
pub async fn handle_responses(
State(state): State<ProxyState>,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
let is_stream = body
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let forwarder = ctx.create_forwarder(&state);
let result = match forwarder
.forward_with_retry(
&AppType::Codex,
"/v1/responses",
body,
headers,
ctx.get_providers(),
)
.await
{
Ok(result) => result,
Err(mut err) => {
if let Some(provider) = err.provider.take() {
ctx.provider = provider;
}
log_forward_error(&state, &ctx, is_stream, &err.error);
return Err(err.error);
}
};
ctx.provider = result.provider;
let response = result.response;
log::info!("[Codex] 上游响应状态: {}", response.status());
process_response(response, &ctx, &state, &CODEX_PARSER_CONFIG).await
}
// ============================================================================
// Gemini API 处理器
// ============================================================================
/// 处理 Gemini API 请求(透传,包括查询参数)
pub async fn handle_gemini(
State(state): State<ProxyState>,
uri: axum::http::Uri,
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
// Gemini 的模型名称在 URI 中
let mut ctx = RequestContext::new(&state, &body, &headers, AppType::Gemini, "Gemini", "gemini")
.await?
.with_model_from_uri(&uri);
// 提取完整的路径和查询参数
let endpoint = uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or(uri.path());
log::info!("[Gemini] 请求端点: {endpoint}");
let is_stream = body
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let forwarder = ctx.create_forwarder(&state);
let result = match forwarder
.forward_with_retry(
&AppType::Gemini,
endpoint,
body,
headers,
ctx.get_providers(),
)
.await
{
Ok(result) => result,
Err(mut err) => {
if let Some(provider) = err.provider.take() {
ctx.provider = provider;
}
log_forward_error(&state, &ctx, is_stream, &err.error);
return Err(err.error);
}
};
ctx.provider = result.provider;
let response = result.response;
log::info!("[Gemini] 上游响应状态: {}", response.status());
process_response(response, &ctx, &state, &GEMINI_PARSER_CONFIG).await
}
// ============================================================================
// 使用量记录(保留用于 Claude 转换逻辑)
// ============================================================================
fn log_forward_error(
state: &ProxyState,
ctx: &RequestContext,
is_streaming: bool,
error: &ProxyError,
) {
use super::usage::logger::UsageLogger;
let logger = UsageLogger::new(&state.db);
let status_code = map_proxy_error_to_status(error);
let error_message = get_error_message(error);
let request_id = uuid::Uuid::new_v4().to_string();
if let Err(e) = logger.log_error_with_context(
request_id,
ctx.provider.id.clone(),
ctx.app_type_str.to_string(),
ctx.request_model.clone(),
status_code,
error_message,
ctx.latency_ms(),
is_streaming,
Some(ctx.session_id.clone()),
None,
) {
log::warn!("记录失败请求日志失败: {e}");
}
}
/// 记录请求使用量
#[allow(clippy::too_many_arguments)]
async fn log_usage(
state: &ProxyState,
provider_id: &str,
app_type: &str,
model: &str,
usage: TokenUsage,
latency_ms: u64,
first_token_ms: Option<u64>,
is_streaming: bool,
status_code: u16,
) {
use super::usage::logger::UsageLogger;
let logger = UsageLogger::new(&state.db);
// 获取 provider 的 cost_multiplier
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
Ok(Some(p)) => {
if let Some(meta) = p.meta {
if let Some(cm) = meta.cost_multiplier {
Decimal::from_str(&cm).unwrap_or(Decimal::from(1))
} else {
Decimal::from(1)
}
} else {
Decimal::from(1)
}
}
_ => Decimal::from(1),
};
let request_id = uuid::Uuid::new_v4().to_string();
if let Err(e) = logger.log_with_calculation(
request_id,
provider_id.to_string(),
app_type.to_string(),
model.to_string(),
usage,
multiplier,
latency_ms,
first_token_ms,
status_code,
None,
None, // provider_type
is_streaming,
) {
log::warn!("记录使用量失败: {e}");
}
}