From ebe2a665ae2fec35777d4ae759f6710faeeb14ef Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 11 Dec 2025 23:22:05 +0800 Subject: [PATCH] refactor(proxy): modularize handlers.rs to reduce code duplication Extract common request handling logic into dedicated modules: - handler_config.rs: Usage parser configurations for each API type - handler_context.rs: Request lifecycle context management - response_processor.rs: Unified streaming/non-streaming response handling Reduces handlers.rs from ~1130 lines to ~418 lines (-63%), eliminating repeated initialization and response processing patterns across the four API handlers (Claude, Codex Chat, Codex Responses, Gemini). --- src-tauri/src/proxy/handler_config.rs | 164 +++ src-tauri/src/proxy/handler_context.rs | 128 ++ src-tauri/src/proxy/handlers.rs | 1414 +++++---------------- src-tauri/src/proxy/mod.rs | 3 + src-tauri/src/proxy/response_processor.rs | 411 ++++++ 5 files changed, 1057 insertions(+), 1063 deletions(-) create mode 100644 src-tauri/src/proxy/handler_config.rs create mode 100644 src-tauri/src/proxy/handler_context.rs create mode 100644 src-tauri/src/proxy/response_processor.rs diff --git a/src-tauri/src/proxy/handler_config.rs b/src-tauri/src/proxy/handler_config.rs new file mode 100644 index 00000000..5710f0b4 --- /dev/null +++ b/src-tauri/src/proxy/handler_config.rs @@ -0,0 +1,164 @@ +//! Handler 配置模块 +//! +//! 定义各 API 处理器的配置结构和使用量解析器 + +use crate::app_config::AppType; +use crate::proxy::usage::parser::TokenUsage; +use serde_json::Value; + +/// 使用量解析器类型别名 +pub type StreamUsageParser = fn(&[Value]) -> Option; +pub type ResponseUsageParser = fn(&Value) -> Option; + +/// 模型提取器类型别名 +/// 参数: (流式事件列表, 请求中的模型名称) -> 最终使用的模型名称 +pub type StreamModelExtractor = fn(&[Value], &str) -> String; + +/// 各 API 的使用量解析配置 +#[derive(Clone, Copy)] +pub struct UsageParserConfig { + /// 流式响应解析器 + pub stream_parser: StreamUsageParser, + /// 非流式响应解析器 + pub response_parser: ResponseUsageParser, + /// 流式响应中的模型提取器 + pub model_extractor: StreamModelExtractor, + /// 应用类型字符串(用于日志记录) + pub app_type_str: &'static str, +} + +// ============================================================================ +// 模型提取器实现 +// ============================================================================ + +/// Claude 流式响应模型提取(直接使用请求模型) +fn claude_model_extractor(_events: &[Value], request_model: &str) -> String { + request_model.to_string() +} + +/// OpenAI Chat Completions 流式响应模型提取 +fn openai_model_extractor(events: &[Value], request_model: &str) -> String { + events + .iter() + .find_map(|e| e.get("model")?.as_str()) + .unwrap_or(request_model) + .to_string() +} + +/// Codex Responses API 流式响应模型提取 +fn codex_model_extractor(events: &[Value], request_model: &str) -> String { + events + .iter() + .find_map(|e| { + if e.get("type")?.as_str()? == "response.completed" { + e.get("response")?.get("model")?.as_str() + } else { + None + } + }) + .unwrap_or(request_model) + .to_string() +} + +/// Gemini 流式响应模型提取(优先使用 usage.model) +fn gemini_model_extractor(events: &[Value], request_model: &str) -> String { + // 首先尝试从解析的 usage 中获取模型 + if let Some(usage) = TokenUsage::from_gemini_stream_chunks(events) { + if let Some(model) = usage.model { + return model; + } + } + request_model.to_string() +} + +// ============================================================================ +// 预定义配置 +// ============================================================================ + +/// Claude API 解析配置 +pub const CLAUDE_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { + stream_parser: TokenUsage::from_claude_stream_events, + response_parser: TokenUsage::from_claude_response, + model_extractor: claude_model_extractor, + app_type_str: "claude", +}; + +/// OpenAI Chat Completions API 解析配置(用于 Codex /v1/chat/completions) +pub const OPENAI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { + stream_parser: TokenUsage::from_openai_stream_events, + response_parser: TokenUsage::from_openai_response, + model_extractor: openai_model_extractor, + app_type_str: "codex", +}; + +/// Codex Responses API 解析配置(用于 /v1/responses) +pub const CODEX_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { + stream_parser: TokenUsage::from_codex_stream_events, + response_parser: TokenUsage::from_codex_response, + model_extractor: codex_model_extractor, + app_type_str: "codex", +}; + +/// Gemini API 解析配置 +pub const GEMINI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { + stream_parser: TokenUsage::from_gemini_stream_chunks, + response_parser: TokenUsage::from_gemini_response, + model_extractor: gemini_model_extractor, + app_type_str: "gemini", +}; + +// ============================================================================ +// Handler 配置(预留,用于进一步简化) +// ============================================================================ + +/// Handler 基础配置 +/// +/// 预留结构,可用于进一步统一各 handler 的配置 +#[allow(dead_code)] +#[derive(Clone)] +pub struct HandlerConfig { + /// 应用类型 + pub app_type: AppType, + /// 日志标签 + pub tag: &'static str, + /// 应用类型字符串 + pub app_type_str: &'static str, + /// 使用量解析配置 + pub parser_config: &'static UsageParserConfig, +} + +/// Claude Handler 配置 +#[allow(dead_code)] +pub const CLAUDE_HANDLER_CONFIG: HandlerConfig = HandlerConfig { + app_type: AppType::Claude, + tag: "Claude", + app_type_str: "claude", + parser_config: &CLAUDE_PARSER_CONFIG, +}; + +/// Codex Chat Completions Handler 配置 +#[allow(dead_code)] +pub const CODEX_CHAT_HANDLER_CONFIG: HandlerConfig = HandlerConfig { + app_type: AppType::Codex, + tag: "Codex", + app_type_str: "codex", + parser_config: &OPENAI_PARSER_CONFIG, +}; + +/// Codex Responses Handler 配置 +#[allow(dead_code)] +pub const CODEX_RESPONSES_HANDLER_CONFIG: HandlerConfig = HandlerConfig { + app_type: AppType::Codex, + tag: "Codex", + app_type_str: "codex", + parser_config: &CODEX_PARSER_CONFIG, +}; + +/// Gemini Handler 配置 +#[allow(dead_code)] +pub const GEMINI_HANDLER_CONFIG: HandlerConfig = HandlerConfig { + app_type: AppType::Gemini, + tag: "Gemini", + app_type_str: "gemini", + parser_config: &GEMINI_PARSER_CONFIG, +}; diff --git a/src-tauri/src/proxy/handler_context.rs b/src-tauri/src/proxy/handler_context.rs new file mode 100644 index 00000000..08fbbf26 --- /dev/null +++ b/src-tauri/src/proxy/handler_context.rs @@ -0,0 +1,128 @@ +//! 请求上下文模块 +//! +//! 提供请求生命周期的上下文管理,封装通用初始化逻辑 + +use crate::app_config::AppType; +use crate::provider::Provider; +use crate::proxy::{ + forwarder::RequestForwarder, router::ProviderRouter, server::ProxyState, types::ProxyConfig, + ProxyError, +}; +use std::time::Instant; + +/// 请求上下文 +/// +/// 贯穿整个请求生命周期,包含: +/// - 计时信息 +/// - 代理配置 +/// - 选中的 Provider +/// - 请求模型名称 +/// - 日志标签 +pub struct RequestContext { + /// 请求开始时间 + pub start_time: Instant, + /// 代理配置快照 + pub config: ProxyConfig, + /// 选中的 Provider + pub provider: Provider, + /// 请求中的模型名称 + pub request_model: String, + /// 日志标签(如 "Claude"、"Codex"、"Gemini") + pub tag: &'static str, + /// 应用类型字符串(如 "claude"、"codex"、"gemini") + pub app_type_str: &'static str, + /// 应用类型(预留,目前通过 app_type_str 使用) + #[allow(dead_code)] + pub app_type: AppType, +} + +impl RequestContext { + /// 创建请求上下文 + /// + /// # Arguments + /// * `state` - 代理服务器状态 + /// * `body` - 请求体 JSON + /// * `app_type` - 应用类型 + /// * `tag` - 日志标签 + /// * `app_type_str` - 应用类型字符串 + /// + /// # Errors + /// 返回 `ProxyError` 如果 Provider 选择失败 + pub async fn new( + state: &ProxyState, + body: &serde_json::Value, + app_type: AppType, + tag: &'static str, + app_type_str: &'static str, + ) -> Result { + let start_time = Instant::now(); + let config = state.config.read().await.clone(); + + // 从请求体提取模型名称 + let request_model = body + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("unknown") + .to_string(); + + // Provider 选择 + let router = ProviderRouter::new(state.db.clone()); + let provider = router.select_provider(&app_type, &[]).await?; + + log::info!( + "[{}] Provider: {}, model: {}", + tag, + provider.name, + request_model + ); + + Ok(Self { + start_time, + config, + provider, + request_model, + tag, + app_type_str, + app_type, + }) + } + + /// 从 URI 提取模型名称(Gemini 专用) + /// + /// Gemini API 的模型名称在 URI 中,格式如: + /// `/v1beta/models/gemini-pro:generateContent` + pub fn with_model_from_uri(mut self, uri: &axum::http::Uri) -> Self { + let endpoint = uri + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or(uri.path()); + + self.request_model = endpoint + .split('/') + .find(|s| s.starts_with("models/")) + .and_then(|s| s.strip_prefix("models/")) + .map(|s| s.split(':').next().unwrap_or(s)) + .unwrap_or("unknown") + .to_string(); + + log::info!("[{}] 从 URI 提取模型: {}", self.tag, self.request_model); + self + } + + /// 创建 RequestForwarder + pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder { + RequestForwarder::new( + state.db.clone(), + self.config.request_timeout, + self.config.max_retries, + state.status.clone(), + state.current_providers.clone(), + ) + } + + /// 计算请求延迟(毫秒) + #[inline] + pub fn latency_ms(&self) -> u64 { + self.start_time.elapsed().as_millis() as u64 + } +} diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 679fd2a7..6e4704ca 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -1,87 +1,369 @@ //! 请求处理器 //! //! 处理各种API端点的HTTP请求 +//! +//! 重构后的结构: +//! - 通用逻辑提取到 `handler_context` 和 `response_processor` 模块 +//! - 各 handler 只保留独特的业务逻辑 +//! - Claude 的格式转换逻辑保留在此文件(独有功能) use super::{ - forwarder::RequestForwarder, - providers::{get_adapter, transform, ProviderType}, + handler_config::{ + CLAUDE_PARSER_CONFIG, CODEX_PARSER_CONFIG, GEMINI_PARSER_CONFIG, OPENAI_PARSER_CONFIG, + }, + handler_context::RequestContext, + providers::{get_adapter, streaming::create_anthropic_sse_stream, transform}, + response_processor::{ + create_logged_passthrough_stream, process_response, SseUsageCollector, + }, server::ProxyState, - session::ProxySession, types::*, - usage::{logger::UsageLogger, parser::TokenUsage}, + usage::parser::TokenUsage, ProxyError, }; use crate::app_config::AppType; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; -use bytes::Bytes; -use futures::stream::{Stream, StreamExt}; use rust_decimal::Decimal; use serde_json::{json, Value}; -use std::{ - str::FromStr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; -use tokio::sync::Mutex; +use std::str::FromStr; -/// 记录请求使用量(带 ProxySession 支持) -#[allow(dead_code, clippy::too_many_arguments)] -async fn log_usage_with_session( - state: &ProxyState, - session: &ProxySession, - provider_id: &str, - app_type: &str, - usage: TokenUsage, - latency_ms: u64, - first_token_ms: Option, - status_code: u16, - provider_type: Option<&ProviderType>, -) { - let logger = UsageLogger::new(&state.db); +// ============================================================================ +// 健康检查和状态查询(简单端点) +// ============================================================================ - // 获取 provider 的 cost_multiplier - let multiplier = match state.db.get_provider_by_id(provider_id, app_type) { - Ok(Some(p)) => { - if let Some(meta) = p.meta { - if let Some(cm) = meta.cost_multiplier { - Decimal::from_str(&cm).unwrap_or(Decimal::from(1)) - } else { - Decimal::from(1) - } - } else { - Decimal::from(1) - } - } - _ => Decimal::from(1), - }; - - let model = session - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - let provider_type_str = provider_type.map(|pt| pt.as_str().to_string()); - - if let Err(e) = logger.log_with_calculation( - session.session_id.clone(), - provider_id.to_string(), - app_type.to_string(), - model, - usage, - multiplier, - latency_ms, - first_token_ms, - status_code, - Some(session.session_id.clone()), - provider_type_str, - session.is_streaming, - ) { - log::warn!("记录使用量失败: {e}"); - } +/// 健康检查 +pub async fn health_check() -> (StatusCode, Json) { + ( + StatusCode::OK, + Json(json!({ + "status": "healthy", + "timestamp": chrono::Utc::now().to_rfc3339(), + })), + ) } -/// 记录请求使用量(兼容旧接口) +/// 获取服务状态 +pub async fn get_status(State(state): State) -> Result, ProxyError> { + let status = state.status.read().await.clone(); + Ok(Json(status)) +} + +// ============================================================================ +// Claude API 处理器(包含格式转换逻辑) +// ============================================================================ + +/// 处理 /v1/messages 请求(Claude API) +/// +/// Claude 处理器包含独特的格式转换逻辑: +/// - 当使用 OpenRouter 等中转服务时,需要将 Anthropic 格式转换为 OpenAI 格式 +/// - 响应需要从 OpenAI 格式转回 Anthropic 格式 +pub async fn handle_messages( + State(state): State, + headers: axum::http::HeaderMap, + Json(body): Json, +) -> Result { + let ctx = + RequestContext::new(&state, &body, AppType::Claude, "Claude", "claude").await?; + + // 检查是否需要格式转换(OpenRouter 等中转服务) + let adapter = get_adapter(&AppType::Claude); + let needs_transform = adapter.needs_transform(&ctx.provider); + + let is_stream = body + .get("stream") + .and_then(|s| s.as_bool()) + .unwrap_or(false); + + log::info!( + "[Claude] Provider: {}, needs_transform: {}, is_stream: {}", + ctx.provider.name, + needs_transform, + is_stream + ); + + // 转发请求 + let forwarder = ctx.create_forwarder(&state); + let response = forwarder + .forward_with_retry(&AppType::Claude, "/v1/messages", body.clone(), headers) + .await?; + + let status = response.status(); + log::info!("[Claude] 上游响应状态: {status}"); + + // Claude 特有:格式转换处理 + if needs_transform { + return handle_claude_transform(response, &ctx, &state, &body, is_stream).await; + } + + // 通用响应处理(透传模式) + process_response(response, &ctx, &state, &CLAUDE_PARSER_CONFIG).await +} + +/// Claude 格式转换处理(独有逻辑) +/// +/// 处理 OpenRouter 等需要格式转换的中转服务 +async fn handle_claude_transform( + response: reqwest::Response, + ctx: &RequestContext, + state: &ProxyState, + _original_body: &Value, + is_stream: bool, +) -> Result { + let status = response.status(); + + if is_stream { + // 流式响应转换 (OpenAI SSE → Anthropic SSE) + log::info!("[Claude] 开始流式响应转换 (OpenAI SSE → Anthropic SSE)"); + + let stream = response.bytes_stream(); + let sse_stream = create_anthropic_sse_stream(stream); + + // 创建使用量收集器 + let usage_collector = { + let state = state.clone(); + let provider_id = ctx.provider.id.clone(); + let model = ctx.request_model.clone(); + let status_code = status.as_u16(); + let start_time = ctx.start_time; + + SseUsageCollector::new(start_time, move |events, first_token_ms| { + if let Some(usage) = TokenUsage::from_claude_stream_events(&events) { + let latency_ms = start_time.elapsed().as_millis() as u64; + let state = state.clone(); + let provider_id = provider_id.clone(); + let model = model.clone(); + + tokio::spawn(async move { + log_usage( + &state, + &provider_id, + "claude", + &model, + usage, + latency_ms, + first_token_ms, + true, + status_code, + ) + .await; + }); + } else { + log::debug!("[Claude] OpenRouter 流式响应缺少 usage 统计,跳过消费记录"); + } + }) + }; + + let logged_stream = create_logged_passthrough_stream( + sse_stream, + "Claude/OpenRouter", + Some(usage_collector), + ); + + let mut headers = axum::http::HeaderMap::new(); + headers.insert( + "Content-Type", + axum::http::HeaderValue::from_static("text/event-stream"), + ); + headers.insert( + "Cache-Control", + axum::http::HeaderValue::from_static("no-cache"), + ); + headers.insert( + "Connection", + axum::http::HeaderValue::from_static("keep-alive"), + ); + + let body = axum::body::Body::from_stream(logged_stream); + log::info!("[Claude] ====== 请求结束 (流式转换) ======"); + return Ok((headers, body).into_response()); + } + + // 非流式响应转换 (OpenAI → Anthropic) + log::info!("[Claude] 开始转换响应 (OpenAI → Anthropic)"); + + let response_headers = response.headers().clone(); + + let body_bytes = response.bytes().await.map_err(|e| { + log::error!("[Claude] 读取响应体失败: {e}"); + ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) + })?; + + let body_str = String::from_utf8_lossy(&body_bytes); + log::info!("[Claude] OpenAI 响应长度: {} bytes", body_bytes.len()); + log::debug!("[Claude] OpenAI 原始响应: {body_str}"); + + let openai_response: Value = serde_json::from_slice(&body_bytes).map_err(|e| { + log::error!("[Claude] 解析 OpenAI 响应失败: {e}, body: {body_str}"); + ProxyError::TransformError(format!("Failed to parse OpenAI response: {e}")) + })?; + + log::info!("[Claude] 解析 OpenAI 响应成功"); + log::info!( + "[Claude] <<< OpenAI 响应 JSON:\n{}", + serde_json::to_string_pretty(&openai_response).unwrap_or_default() + ); + + let anthropic_response = transform::openai_to_anthropic(openai_response).map_err(|e| { + log::error!("[Claude] 转换响应失败: {e}"); + e + })?; + + log::info!("[Claude] 转换响应成功"); + log::info!( + "[Claude] <<< Anthropic 响应 JSON:\n{}", + serde_json::to_string_pretty(&anthropic_response).unwrap_or_default() + ); + + // 记录使用量 + if let Some(usage) = TokenUsage::from_claude_response(&anthropic_response) { + let model = anthropic_response + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("unknown"); + let latency_ms = ctx.latency_ms(); + + tokio::spawn({ + let state = state.clone(); + let provider_id = ctx.provider.id.clone(); + let model = model.to_string(); + async move { + log_usage( + &state, + &provider_id, + "claude", + &model, + usage, + latency_ms, + None, + false, + status.as_u16(), + ) + .await; + } + }); + } + + log::info!("[Claude] ====== 请求结束 ======"); + + // 构建响应 + let mut builder = axum::response::Response::builder().status(status); + + for (key, value) in response_headers.iter() { + if key.as_str().to_lowercase() != "content-length" + && key.as_str().to_lowercase() != "transfer-encoding" + { + builder = builder.header(key, value); + } + } + + builder = builder.header("content-type", "application/json"); + + let response_body = serde_json::to_vec(&anthropic_response).map_err(|e| { + log::error!("[Claude] 序列化响应失败: {e}"); + ProxyError::TransformError(format!("Failed to serialize response: {e}")) + })?; + + log::info!( + "[Claude] 返回转换后的响应, 长度: {} bytes", + response_body.len() + ); + + let body = axum::body::Body::from(response_body); + Ok(builder.body(body).unwrap()) +} + +// ============================================================================ +// Codex API 处理器 +// ============================================================================ + +/// 处理 /v1/chat/completions 请求(OpenAI Chat Completions API - Codex CLI) +pub async fn handle_chat_completions( + State(state): State, + headers: axum::http::HeaderMap, + Json(body): Json, +) -> Result { + log::info!("[Codex] ====== /v1/chat/completions 请求开始 ======"); + + let ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?; + + let is_stream = body + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + log::info!( + "[Codex] 请求模型: {}, 流式: {}", + ctx.request_model, + is_stream + ); + + let forwarder = ctx.create_forwarder(&state); + let response = forwarder + .forward_with_retry(&AppType::Codex, "/v1/chat/completions", body, headers) + .await?; + + log::info!("[Codex] 上游响应状态: {}", response.status()); + + process_response(response, &ctx, &state, &OPENAI_PARSER_CONFIG).await +} + +/// 处理 /v1/responses 请求(OpenAI Responses API - Codex CLI 透传) +pub async fn handle_responses( + State(state): State, + headers: axum::http::HeaderMap, + Json(body): Json, +) -> Result { + let ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?; + + let forwarder = ctx.create_forwarder(&state); + let response = forwarder + .forward_with_retry(&AppType::Codex, "/v1/responses", body, headers) + .await?; + + log::info!("[Codex] 上游响应状态: {}", response.status()); + + process_response(response, &ctx, &state, &CODEX_PARSER_CONFIG).await +} + +// ============================================================================ +// Gemini API 处理器 +// ============================================================================ + +/// 处理 Gemini API 请求(透传,包括查询参数) +pub async fn handle_gemini( + State(state): State, + uri: axum::http::Uri, + headers: axum::http::HeaderMap, + Json(body): Json, +) -> Result { + // Gemini 的模型名称在 URI 中 + let ctx = RequestContext::new(&state, &body, AppType::Gemini, "Gemini", "gemini") + .await? + .with_model_from_uri(&uri); + + // 提取完整的路径和查询参数 + let endpoint = uri + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or(uri.path()); + + log::info!("[Gemini] 请求端点: {}", endpoint); + + let forwarder = ctx.create_forwarder(&state); + let response = forwarder + .forward_with_retry(&AppType::Gemini, endpoint, body, headers) + .await?; + + log::info!("[Gemini] 上游响应状态: {}", response.status()); + + process_response(response, &ctx, &state, &GEMINI_PARSER_CONFIG).await +} + +// ============================================================================ +// 使用量记录(保留用于 Claude 转换逻辑) +// ============================================================================ + +/// 记录请求使用量 #[allow(clippy::too_many_arguments)] async fn log_usage( state: &ProxyState, @@ -94,6 +376,8 @@ async fn log_usage( is_streaming: bool, status_code: u16, ) { + use super::usage::logger::UsageLogger; + let logger = UsageLogger::new(&state.db); // 获取 provider 的 cost_multiplier @@ -131,999 +415,3 @@ async fn log_usage( log::warn!("记录使用量失败: {e}"); } } - -type UsageCallbackWithTiming = Arc, Option) + Send + Sync + 'static>; - -#[derive(Clone)] -struct SseUsageCollector { - inner: Arc, -} - -struct SseUsageCollectorInner { - events: Mutex>, - first_event_time: Mutex>, - start_time: std::time::Instant, - on_complete: UsageCallbackWithTiming, - finished: AtomicBool, -} - -impl SseUsageCollector { - fn new( - start_time: std::time::Instant, - callback: impl Fn(Vec, Option) + Send + Sync + 'static, - ) -> Self { - let on_complete: UsageCallbackWithTiming = Arc::new(callback); - Self { - inner: Arc::new(SseUsageCollectorInner { - events: Mutex::new(Vec::new()), - first_event_time: Mutex::new(None), - start_time, - on_complete, - finished: AtomicBool::new(false), - }), - } - } - - async fn push(&self, event: Value) { - // 记录首个事件时间 - { - let mut first_time = self.inner.first_event_time.lock().await; - if first_time.is_none() { - *first_time = Some(std::time::Instant::now()); - } - } - let mut events = self.inner.events.lock().await; - events.push(event); - } - - async fn finish(&self) { - if self.inner.finished.swap(true, Ordering::SeqCst) { - return; - } - - let events = { - let mut guard = self.inner.events.lock().await; - std::mem::take(&mut *guard) - }; - - let first_token_ms = { - let first_time = self.inner.first_event_time.lock().await; - first_time.map(|t| (t - self.inner.start_time).as_millis() as u64) - }; - - (self.inner.on_complete)(events, first_token_ms); - } -} - -/// 创建带日志记录的透传流 -fn create_logged_passthrough_stream( - stream: impl Stream> + Send + 'static, - tag: &'static str, - usage_collector: Option, -) -> impl Stream> + Send { - async_stream::stream! { - let mut buffer = String::new(); - let mut collector = usage_collector; - - tokio::pin!(stream); - - while let Some(chunk) = stream.next().await { - match chunk { - Ok(bytes) => { - let text = String::from_utf8_lossy(&bytes); - buffer.push_str(&text); - - // 尝试解析并记录完整的 SSE 事件 - while let Some(pos) = buffer.find("\n\n") { - let event_text = buffer[..pos].to_string(); - buffer = buffer[pos + 2..].to_string(); - - if !event_text.trim().is_empty() { - // 提取 data 部分并尝试解析为 JSON - for line in event_text.lines() { - if let Some(data) = line.strip_prefix("data: ") { - if data.trim() != "[DONE]" { - if let Ok(json_value) = serde_json::from_str::(data) { - if let Some(c) = &collector { - c.push(json_value.clone()).await; - } - log::info!( - "[{}] <<< SSE 事件:\n{}", - tag, - serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| data.to_string()) - ); - } else { - log::info!("[{tag}] <<< SSE 数据: {data}"); - } - } else { - log::info!("[{tag}] <<< SSE: [DONE]"); - } - } - } - } - } - - yield Ok(bytes); - } - Err(e) => { - log::error!("[{tag}] 流错误: {e}"); - yield Err(std::io::Error::other(e.to_string())); - break; - } - } - } - - log::info!("[{}] ====== 流结束 ======", tag); - - if let Some(c) = collector.take() { - c.finish().await; - } - } -} - -/// 健康检查 -pub async fn health_check() -> (StatusCode, Json) { - ( - StatusCode::OK, - Json(json!({ - "status": "healthy", - "timestamp": chrono::Utc::now().to_rfc3339(), - })), - ) -} - -/// 获取服务状态 -pub async fn get_status(State(state): State) -> Result, ProxyError> { - let status = state.status.read().await.clone(); - Ok(Json(status)) -} - -/// 处理 /v1/messages 请求(Claude API) -pub async fn handle_messages( - State(state): State, - headers: axum::http::HeaderMap, - Json(body): Json, -) -> Result { - let start_time = std::time::Instant::now(); - - let config = state.config.read().await.clone(); - let request_model = body - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown") - .to_string(); - - // 选择目标 Provider - let router = super::router::ProviderRouter::new(state.db.clone()); - let failed_ids = Vec::new(); - let provider = router - .select_provider(&AppType::Claude, &failed_ids) - .await?; - - // 检查是否需要转换(OpenRouter) - let adapter = get_adapter(&AppType::Claude); - let needs_transform = adapter.needs_transform(&provider); - - // 检查是否是流式请求 - let is_stream = body - .get("stream") - .and_then(|s| s.as_bool()) - .unwrap_or(false); - - log::info!( - "[Claude] Provider: {}, needs_transform: {}, is_stream: {}", - provider.name, - needs_transform, - is_stream - ); - - let forwarder = RequestForwarder::new( - state.db.clone(), - config.request_timeout, - config.max_retries, - state.status.clone(), - state.current_providers.clone(), - ); - - let response = forwarder - .forward_with_retry(&AppType::Claude, "/v1/messages", body, headers) - .await?; - - let status = response.status(); - log::info!("[Claude] 上游响应状态: {status}"); - - // 如果需要转换 - if needs_transform { - if is_stream { - // 流式响应转换 - log::info!("[Claude] 开始流式响应转换 (OpenAI SSE → Anthropic SSE)"); - - let stream = response.bytes_stream(); - let sse_stream = super::providers::streaming::create_anthropic_sse_stream(stream); - - let usage_collector = { - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = request_model.clone(); - let status_code = status.as_u16(); - let start_time_clone = start_time; - SseUsageCollector::new(start_time, move |events, first_token_ms| { - if let Some(usage) = TokenUsage::from_claude_stream_events(&events) { - let latency_ms = start_time_clone.elapsed().as_millis() as u64; - let state = state.clone(); - let provider_id = provider_id.clone(); - let model = model.clone(); - tokio::spawn(async move { - log_usage( - &state, - &provider_id, - "claude", - &model, - usage, - latency_ms, - first_token_ms, - true, // is_streaming - status_code, - ) - .await; - }); - } else { - log::debug!("[Claude] OpenRouter 流式响应缺少 usage 统计,跳过消费记录"); - } - }) - }; - - let logged_stream = create_logged_passthrough_stream( - sse_stream, - "Claude/OpenRouter", - Some(usage_collector), - ); - - let mut headers = axum::http::HeaderMap::new(); - headers.insert( - "Content-Type", - axum::http::HeaderValue::from_static("text/event-stream"), - ); - headers.insert( - "Cache-Control", - axum::http::HeaderValue::from_static("no-cache"), - ); - headers.insert( - "Connection", - axum::http::HeaderValue::from_static("keep-alive"), - ); - - let body = axum::body::Body::from_stream(logged_stream); - log::info!("[Claude] ====== 请求结束 (流式转换) ======"); - return Ok((headers, body).into_response()); - } else { - // 非流式响应转换 - log::info!("[Claude] 开始转换响应 (OpenAI → Anthropic)"); - - let response_headers = response.headers().clone(); - - // 读取响应体 - let body_bytes = response.bytes().await.map_err(|e| { - log::error!("[Claude] 读取响应体失败: {e}"); - ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) - })?; - - let body_str = String::from_utf8_lossy(&body_bytes); - log::info!("[Claude] OpenAI 响应长度: {} bytes", body_bytes.len()); - log::debug!("[Claude] OpenAI 原始响应: {body_str}"); - - // 解析并转换 - let openai_response: Value = serde_json::from_slice(&body_bytes).map_err(|e| { - log::error!("[Claude] 解析 OpenAI 响应失败: {e}, body: {body_str}"); - ProxyError::TransformError(format!("Failed to parse OpenAI response: {e}")) - })?; - - log::info!("[Claude] 解析 OpenAI 响应成功"); - log::info!( - "[Claude] <<< OpenAI 响应 JSON:\n{}", - serde_json::to_string_pretty(&openai_response).unwrap_or_default() - ); - - let anthropic_response = - transform::openai_to_anthropic(openai_response).map_err(|e| { - log::error!("[Claude] 转换响应失败: {e}"); - e - })?; - - log::info!("[Claude] 转换响应成功"); - log::info!( - "[Claude] <<< Anthropic 响应 JSON:\n{}", - serde_json::to_string_pretty(&anthropic_response).unwrap_or_default() - ); - - // 记录使用量 - if let Some(usage) = TokenUsage::from_claude_response(&anthropic_response) { - let model = anthropic_response - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown"); - let latency_ms = start_time.elapsed().as_millis() as u64; - - tokio::spawn({ - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = model.to_string(); - async move { - log_usage( - &state, - &provider_id, - "claude", - &model, - usage, - latency_ms, - None, - false, - status.as_u16(), - ) - .await; - } - }); - } - - log::info!("[Claude] ====== 请求结束 ======"); - - // 构建响应 - let mut builder = axum::response::Response::builder().status(status); - - // 复制响应头(排除 content-length,因为内容已改变) - for (key, value) in response_headers.iter() { - if key.as_str().to_lowercase() != "content-length" - && key.as_str().to_lowercase() != "transfer-encoding" - { - builder = builder.header(key, value); - } - } - - builder = builder.header("content-type", "application/json"); - - let response_body = serde_json::to_vec(&anthropic_response).map_err(|e| { - log::error!("[Claude] 序列化响应失败: {e}"); - ProxyError::TransformError(format!("Failed to serialize response: {e}")) - })?; - - log::info!( - "[Claude] 返回转换后的响应, 长度: {} bytes", - response_body.len() - ); - - let body = axum::body::Body::from(response_body); - return Ok(builder.body(body).unwrap()); - } - } - - // 透传响应(直连 Anthropic) - log::info!("[Claude] 透传响应模式"); - - // 检查是否流式响应 - let content_type = response - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); - let is_sse = content_type.contains("text/event-stream"); - - if is_sse { - // 流式透传:使用包装流记录 SSE 事件 - log::info!("[Claude] 流式透传响应 (SSE)"); - let mut builder = axum::response::Response::builder().status(status); - - for (key, value) in response.headers() { - builder = builder.header(key, value); - } - - let stream = response - .bytes_stream() - .map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string()))); - let usage_collector = { - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = request_model.clone(); - let status_code = status.as_u16(); - let start_time_clone = start_time; - SseUsageCollector::new(start_time, move |events, first_token_ms| { - if let Some(usage) = TokenUsage::from_claude_stream_events(&events) { - let latency_ms = start_time_clone.elapsed().as_millis() as u64; - let state = state.clone(); - let provider_id = provider_id.clone(); - let model = model.clone(); - tokio::spawn(async move { - log_usage( - &state, - &provider_id, - "claude", - &model, - usage, - latency_ms, - first_token_ms, - true, - status_code, - ) - .await; - }); - } else { - log::debug!("[Claude] 流式响应缺少 usage 统计,跳过消费记录"); - } - }) - }; - let logged_stream = - create_logged_passthrough_stream(stream, "Claude", Some(usage_collector)); - - let body = axum::body::Body::from_stream(logged_stream); - log::info!("[Claude] ====== 请求结束 (流式) ======"); - Ok(builder.body(body).unwrap()) - } else { - // 非流式透传:读取完整响应并记录 - let response_headers = response.headers().clone(); - let status = response.status(); - - let body_bytes = response.bytes().await.map_err(|e| { - log::error!("[Claude] 读取透传响应失败: {e}"); - ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) - })?; - - // 记录响应 JSON - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - log::info!( - "[Claude] <<< Anthropic 透传响应 JSON:\n{}", - serde_json::to_string_pretty(&json_value).unwrap_or_default() - ); - - // 记录使用量 - if let Some(usage) = TokenUsage::from_claude_response(&json_value) { - let model = json_value - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown"); - let latency_ms = start_time.elapsed().as_millis() as u64; - - tokio::spawn({ - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = model.to_string(); - async move { - log_usage( - &state, - &provider_id, - "claude", - &model, - usage, - latency_ms, - None, - false, - status.as_u16(), - ) - .await; - } - }); - } - } else { - log::info!( - "[Claude] <<< 透传响应 (非 JSON): {} bytes", - body_bytes.len() - ); - } - log::info!("[Claude] ====== 请求结束 ======"); - - let mut builder = axum::response::Response::builder().status(status); - for (key, value) in response_headers.iter() { - builder = builder.header(key, value); - } - - let body = axum::body::Body::from(body_bytes); - Ok(builder.body(body).unwrap()) - } -} - -/// 处理 Gemini API 请求(透传,包括查询参数) -pub async fn handle_gemini( - State(state): State, - uri: axum::http::Uri, - headers: axum::http::HeaderMap, - Json(body): Json, -) -> Result { - let start_time = std::time::Instant::now(); - - let config = state.config.read().await.clone(); - - // 选择目标 Provider - let router = super::router::ProviderRouter::new(state.db.clone()); - let failed_ids = Vec::new(); - let provider = router - .select_provider(&AppType::Gemini, &failed_ids) - .await?; - - let forwarder = RequestForwarder::new( - state.db.clone(), - config.request_timeout, - config.max_retries, - state.status.clone(), - state.current_providers.clone(), - ); - - // 提取完整的路径和查询参数 - let endpoint = uri - .path_and_query() - .map(|pq| pq.as_str()) - .unwrap_or(uri.path()); - let gemini_model = endpoint - .split('/') - .find(|s| s.starts_with("models/")) - .and_then(|s| s.strip_prefix("models/")) - .map(|s| s.split(':').next().unwrap_or(s)) - .unwrap_or("unknown") - .to_string(); - - log::info!("[Gemini] 请求端点: {endpoint}"); - - let response = forwarder - .forward_with_retry(&AppType::Gemini, endpoint, body, headers) - .await?; - - let status = response.status(); - log::info!("[Gemini] 上游响应状态: {status}"); - - // 检查是否流式响应 - let content_type = response - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); - let is_sse = content_type.contains("text/event-stream"); - - if is_sse { - // 流式透传 - log::info!("[Gemini] 流式透传响应 (SSE)"); - let mut builder = axum::response::Response::builder().status(status); - - for (key, value) in response.headers() { - builder = builder.header(key, value); - } - - let stream = response - .bytes_stream() - .map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string()))); - let usage_collector = { - let state = state.clone(); - let provider_id = provider.id.clone(); - let fallback_model = gemini_model.clone(); - let status_code = status.as_u16(); - let start_time_clone = start_time; - SseUsageCollector::new(start_time, move |events, first_token_ms| { - if let Some(usage) = TokenUsage::from_gemini_stream_chunks(&events) { - // 优先使用响应中的实际模型名称,否则使用从 URI 提取的模型名称 - let model = usage - .model - .clone() - .unwrap_or_else(|| fallback_model.clone()); - let latency_ms = start_time_clone.elapsed().as_millis() as u64; - let state = state.clone(); - let provider_id = provider_id.clone(); - tokio::spawn(async move { - log_usage( - &state, - &provider_id, - "gemini", - &model, - usage, - latency_ms, - first_token_ms, - true, - status_code, - ) - .await; - }); - } else { - log::debug!("[Gemini] 流式响应缺少 usage 统计,跳过消费记录"); - } - }) - }; - let logged_stream = - create_logged_passthrough_stream(stream, "Gemini", Some(usage_collector)); - - let body = axum::body::Body::from_stream(logged_stream); - Ok(builder.body(body).unwrap()) - } else { - // 非流式透传 - let response_headers = response.headers().clone(); - let status = response.status(); - - let body_bytes = response.bytes().await.map_err(|e| { - log::error!("[Gemini] 读取响应失败: {e}"); - ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) - })?; - - // 记录响应 JSON - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - log::info!( - "[Gemini] <<< 响应 JSON:\n{}", - serde_json::to_string_pretty(&json_value).unwrap_or_default() - ); - - // 记录使用量 - if let Some(usage) = TokenUsage::from_gemini_response(&json_value) { - // 优先使用响应中的实际模型名称,否则使用从 URI 提取的模型名称 - let model = usage.model.clone().unwrap_or_else(|| gemini_model.clone()); - let latency_ms = start_time.elapsed().as_millis() as u64; - tokio::spawn({ - let state = state.clone(); - let provider_id = provider.id.clone(); - async move { - log_usage( - &state, - &provider_id, - "gemini", - &model, - usage, - latency_ms, - None, - false, - status.as_u16(), - ) - .await; - } - }); - } - } else { - log::info!("[Gemini] <<< 响应 (非 JSON): {} bytes", body_bytes.len()); - } - log::info!("[Gemini] ====== 请求结束 ======"); - - let mut builder = axum::response::Response::builder().status(status); - for (key, value) in response_headers.iter() { - builder = builder.header(key, value); - } - - let body = axum::body::Body::from(body_bytes); - Ok(builder.body(body).unwrap()) - } -} - -/// 处理 /v1/responses 请求(OpenAI Responses API - Codex CLI 透传) -pub async fn handle_responses( - State(state): State, - headers: axum::http::HeaderMap, - Json(body): Json, -) -> Result { - let start_time = std::time::Instant::now(); - - let config = state.config.read().await.clone(); - let request_model = body - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown") - .to_string(); - - // 选择目标 Provider - let router = super::router::ProviderRouter::new(state.db.clone()); - let failed_ids = Vec::new(); - let provider = router.select_provider(&AppType::Codex, &failed_ids).await?; - - let forwarder = RequestForwarder::new( - state.db.clone(), - config.request_timeout, - config.max_retries, - state.status.clone(), - state.current_providers.clone(), - ); - - let response = forwarder - .forward_with_retry(&AppType::Codex, "/v1/responses", body, headers) - .await?; - - let status = response.status(); - log::info!("[Codex] 上游响应状态: {status}"); - - // 检查是否流式响应 - let content_type = response - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); - let is_sse = content_type.contains("text/event-stream"); - - if is_sse { - // 流式透传 - log::info!("[Codex] 流式透传响应 (SSE)"); - let mut builder = axum::response::Response::builder().status(status); - - for (key, value) in response.headers() { - builder = builder.header(key, value); - } - - let stream = response - .bytes_stream() - .map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string()))); - let usage_collector = { - let state = state.clone(); - let provider_id = provider.id.clone(); - let request_model = request_model.clone(); - let status_code = status.as_u16(); - let start_time_clone = start_time; - SseUsageCollector::new(start_time, move |events, first_token_ms| { - if let Some(usage) = TokenUsage::from_codex_stream_events(&events) { - // 尝试从事件中提取模型,回退到请求模型 - let model = events - .iter() - .find_map(|e| { - if e.get("type")?.as_str()? == "response.completed" { - e.get("response")?.get("model")?.as_str() - } else { - None - } - }) - .unwrap_or(&request_model) - .to_string(); - let latency_ms = start_time_clone.elapsed().as_millis() as u64; - - let state = state.clone(); - let provider_id = provider_id.clone(); - tokio::spawn(async move { - log_usage( - &state, - &provider_id, - "codex", - &model, - usage, - latency_ms, - first_token_ms, - true, - status_code, - ) - .await; - }); - } else { - log::debug!("[Codex] 流式响应缺少 usage 统计,跳过消费记录"); - } - }) - }; - let logged_stream = - create_logged_passthrough_stream(stream, "Codex", Some(usage_collector)); - - let body = axum::body::Body::from_stream(logged_stream); - Ok(builder.body(body).unwrap()) - } else { - // 非流式透传 - let response_headers = response.headers().clone(); - let status = response.status(); - - let body_bytes = response.bytes().await.map_err(|e| { - log::error!("[Codex] 读取响应失败: {e}"); - ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) - })?; - - // 记录响应 JSON - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - log::info!( - "[Codex] <<< 响应 JSON:\n{}", - serde_json::to_string_pretty(&json_value).unwrap_or_default() - ); - - // 记录使用量 - if let Some(usage) = TokenUsage::from_codex_response(&json_value) { - let model = json_value - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown"); - let latency_ms = start_time.elapsed().as_millis() as u64; - - log::info!( - "[Codex] 解析到 usage: input={}, output={}", - usage.input_tokens, - usage.output_tokens - ); - - tokio::spawn({ - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = model.to_string(); - async move { - log_usage( - &state, - &provider_id, - "codex", - &model, - usage, - latency_ms, - None, - false, - status.as_u16(), - ) - .await; - } - }); - } else { - log::warn!("[Codex] 未能解析 usage 信息,跳过记录"); - } - } else { - log::info!("[Codex] <<< 响应 (非 JSON): {} bytes", body_bytes.len()); - } - log::info!("[Codex] ====== 请求结束 ======"); - - let mut builder = axum::response::Response::builder().status(status); - for (key, value) in response_headers.iter() { - builder = builder.header(key, value); - } - - let body = axum::body::Body::from(body_bytes); - Ok(builder.body(body).unwrap()) - } -} - -/// 处理 /v1/chat/completions 请求(OpenAI Chat Completions API - Codex CLI) -pub async fn handle_chat_completions( - State(state): State, - headers: axum::http::HeaderMap, - Json(body): Json, -) -> Result { - let start_time = std::time::Instant::now(); - log::info!("[Codex] ====== /v1/chat/completions 请求开始 ======"); - - let config = state.config.read().await.clone(); - let request_model = body - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown") - .to_string(); - let is_stream = body - .get("stream") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - log::info!("[Codex] 请求模型: {request_model}, 流式: {is_stream}"); - - // 选择目标 Provider - let router = super::router::ProviderRouter::new(state.db.clone()); - let failed_ids = Vec::new(); - let provider = router.select_provider(&AppType::Codex, &failed_ids).await?; - - log::info!("[Codex] 选择 Provider: {}", provider.id); - - let forwarder = RequestForwarder::new( - state.db.clone(), - config.request_timeout, - config.max_retries, - state.status.clone(), - state.current_providers.clone(), - ); - - let response = forwarder - .forward_with_retry(&AppType::Codex, "/v1/chat/completions", body, headers) - .await?; - - let status = response.status(); - log::info!("[Codex] 上游响应状态: {status}"); - - // 检查是否流式响应 - let content_type = response - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); - let is_sse = content_type.contains("text/event-stream"); - - if is_sse { - // 流式透传 - log::info!("[Codex] 流式透传响应 (SSE)"); - let mut builder = axum::response::Response::builder().status(status); - - for (key, value) in response.headers() { - builder = builder.header(key, value); - } - - let stream = response - .bytes_stream() - .map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string()))); - - let usage_collector = { - let state = state.clone(); - let provider_id = provider.id.clone(); - let request_model = request_model.clone(); - let status_code = status.as_u16(); - let start_time_clone = start_time; - SseUsageCollector::new(start_time, move |events, first_token_ms| { - if let Some(usage) = TokenUsage::from_openai_stream_events(&events) { - let model = events - .iter() - .find_map(|e| e.get("model")?.as_str()) - .unwrap_or(&request_model) - .to_string(); - let latency_ms = start_time_clone.elapsed().as_millis() as u64; - - let state = state.clone(); - let provider_id = provider_id.clone(); - tokio::spawn(async move { - log_usage( - &state, - &provider_id, - "codex", - &model, - usage, - latency_ms, - first_token_ms, - true, - status_code, - ) - .await; - }); - } else { - log::debug!("[Codex] 流式响应缺少 usage 统计,跳过消费记录"); - } - }) - }; - let logged_stream = - create_logged_passthrough_stream(stream, "Codex", Some(usage_collector)); - - let body = axum::body::Body::from_stream(logged_stream); - Ok(builder.body(body).unwrap()) - } else { - // 非流式透传 - let response_headers = response.headers().clone(); - let status = response.status(); - - let body_bytes = response.bytes().await.map_err(|e| { - log::error!("[Codex] 读取响应失败: {e}"); - ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) - })?; - - // 记录响应 JSON - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - log::info!( - "[Codex] <<< 响应 JSON:\n{}", - serde_json::to_string_pretty(&json_value).unwrap_or_default() - ); - - // 记录使用量 (OpenAI 格式: prompt_tokens, completion_tokens) - if let Some(usage) = TokenUsage::from_openai_response(&json_value) { - let model = json_value - .get("model") - .and_then(|m| m.as_str()) - .unwrap_or("unknown"); - let latency_ms = start_time.elapsed().as_millis() as u64; - - log::info!( - "[Codex] 解析到 usage: input={}, output={}", - usage.input_tokens, - usage.output_tokens - ); - - tokio::spawn({ - let state = state.clone(); - let provider_id = provider.id.clone(); - let model = model.to_string(); - async move { - log_usage( - &state, - &provider_id, - "codex", - &model, - usage, - latency_ms, - None, - false, - status.as_u16(), - ) - .await; - } - }); - } else { - log::warn!("[Codex] 未能解析 usage 信息,跳过记录"); - } - } else { - log::info!("[Codex] <<< 响应 (非 JSON): {} bytes", body_bytes.len()); - } - log::info!("[Codex] ====== 请求结束 ======"); - - let mut builder = axum::response::Response::builder().status(status); - for (key, value) in response_headers.iter() { - builder = builder.header(key, value); - } - - let body = axum::body::Body::from(body_bytes); - Ok(builder.body(body).unwrap()) - } -} diff --git a/src-tauri/src/proxy/mod.rs b/src-tauri/src/proxy/mod.rs index c23d277c..5b052f0f 100644 --- a/src-tauri/src/proxy/mod.rs +++ b/src-tauri/src/proxy/mod.rs @@ -5,11 +5,14 @@ pub mod circuit_breaker; pub mod error; mod forwarder; +pub mod handler_config; +pub mod handler_context; mod handlers; mod health; pub mod provider_router; pub mod providers; pub mod response_handler; +pub mod response_processor; mod router; pub(crate) mod server; pub mod session; diff --git a/src-tauri/src/proxy/response_processor.rs b/src-tauri/src/proxy/response_processor.rs new file mode 100644 index 00000000..38b59fae --- /dev/null +++ b/src-tauri/src/proxy/response_processor.rs @@ -0,0 +1,411 @@ +//! 响应处理器模块 +//! +//! 统一处理流式和非流式 API 响应 + +use super::{ + handler_config::UsageParserConfig, handler_context::RequestContext, server::ProxyState, + usage::parser::TokenUsage, ProxyError, +}; +use axum::response::Response; +use bytes::Bytes; +use futures::stream::{Stream, StreamExt}; +use rust_decimal::Decimal; +use serde_json::Value; +use std::{ + str::FromStr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tokio::sync::Mutex; + +// ============================================================================ +// 公共接口 +// ============================================================================ + +/// 检测响应是否为 SSE 流式响应 +#[inline] +pub fn is_sse_response(response: &reqwest::Response) -> bool { + response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .map(|ct| ct.contains("text/event-stream")) + .unwrap_or(false) +} + +/// 处理流式响应 +pub async fn handle_streaming( + response: reqwest::Response, + ctx: &RequestContext, + state: &ProxyState, + parser_config: &UsageParserConfig, +) -> Response { + log::info!("[{}] 流式透传响应 (SSE)", ctx.tag); + + let status = response.status(); + let mut builder = axum::response::Response::builder().status(status); + + // 复制响应头 + for (key, value) in response.headers() { + builder = builder.header(key, value); + } + + // 创建字节流 + let stream = response + .bytes_stream() + .map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string()))); + + // 创建使用量收集器 + let usage_collector = create_usage_collector(ctx, state, status.as_u16(), parser_config); + + // 创建带日志的透传流 + let logged_stream = create_logged_passthrough_stream(stream, ctx.tag, Some(usage_collector)); + + let body = axum::body::Body::from_stream(logged_stream); + builder.body(body).unwrap() +} + +/// 处理非流式响应 +pub async fn handle_non_streaming( + response: reqwest::Response, + ctx: &RequestContext, + state: &ProxyState, + parser_config: &UsageParserConfig, +) -> Result { + let response_headers = response.headers().clone(); + let status = response.status(); + + // 读取响应体 + let body_bytes = response.bytes().await.map_err(|e| { + log::error!("[{}] 读取响应失败: {e}", ctx.tag); + ProxyError::ForwardFailed(format!("Failed to read response body: {e}")) + })?; + + // 解析并记录使用量 + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + log::info!( + "[{}] <<< 响应 JSON:\n{}", + ctx.tag, + serde_json::to_string_pretty(&json_value).unwrap_or_default() + ); + + // 解析使用量 + if let Some(usage) = (parser_config.response_parser)(&json_value) { + let model = json_value + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or(&ctx.request_model); + + spawn_log_usage(state, ctx, usage, model, status.as_u16(), false); + } else { + log::debug!( + "[{}] 未能解析 usage 信息,跳过记录", + parser_config.app_type_str + ); + } + } else { + log::info!( + "[{}] <<< 响应 (非 JSON): {} bytes", + ctx.tag, + body_bytes.len() + ); + } + + log::info!("[{}] ====== 请求结束 ======", ctx.tag); + + // 构建响应 + let mut builder = axum::response::Response::builder().status(status); + for (key, value) in response_headers.iter() { + builder = builder.header(key, value); + } + + let body = axum::body::Body::from(body_bytes); + Ok(builder.body(body).unwrap()) +} + +/// 通用响应处理入口 +/// +/// 根据响应类型自动选择流式或非流式处理 +pub async fn process_response( + response: reqwest::Response, + ctx: &RequestContext, + state: &ProxyState, + parser_config: &UsageParserConfig, +) -> Result { + if is_sse_response(&response) { + Ok(handle_streaming(response, ctx, state, parser_config).await) + } else { + handle_non_streaming(response, ctx, state, parser_config).await + } +} + +// ============================================================================ +// SSE 使用量收集器 +// ============================================================================ + +type UsageCallbackWithTiming = Arc, Option) + Send + Sync + 'static>; + +/// SSE 使用量收集器 +#[derive(Clone)] +pub struct SseUsageCollector { + inner: Arc, +} + +struct SseUsageCollectorInner { + events: Mutex>, + first_event_time: Mutex>, + start_time: std::time::Instant, + on_complete: UsageCallbackWithTiming, + finished: AtomicBool, +} + +impl SseUsageCollector { + /// 创建新的使用量收集器 + pub fn new( + start_time: std::time::Instant, + callback: impl Fn(Vec, Option) + Send + Sync + 'static, + ) -> Self { + let on_complete: UsageCallbackWithTiming = Arc::new(callback); + Self { + inner: Arc::new(SseUsageCollectorInner { + events: Mutex::new(Vec::new()), + first_event_time: Mutex::new(None), + start_time, + on_complete, + finished: AtomicBool::new(false), + }), + } + } + + /// 推送 SSE 事件 + pub async fn push(&self, event: Value) { + // 记录首个事件时间 + { + let mut first_time = self.inner.first_event_time.lock().await; + if first_time.is_none() { + *first_time = Some(std::time::Instant::now()); + } + } + let mut events = self.inner.events.lock().await; + events.push(event); + } + + /// 完成收集并触发回调 + pub async fn finish(&self) { + if self.inner.finished.swap(true, Ordering::SeqCst) { + return; + } + + let events = { + let mut guard = self.inner.events.lock().await; + std::mem::take(&mut *guard) + }; + + let first_token_ms = { + let first_time = self.inner.first_event_time.lock().await; + first_time.map(|t| (t - self.inner.start_time).as_millis() as u64) + }; + + (self.inner.on_complete)(events, first_token_ms); + } +} + +// ============================================================================ +// 内部辅助函数 +// ============================================================================ + +/// 创建使用量收集器 +fn create_usage_collector( + ctx: &RequestContext, + state: &ProxyState, + status_code: u16, + parser_config: &UsageParserConfig, +) -> SseUsageCollector { + let state = state.clone(); + let provider_id = ctx.provider.id.clone(); + let request_model = ctx.request_model.clone(); + let app_type_str = parser_config.app_type_str; + let tag = ctx.tag; + let start_time = ctx.start_time; + let stream_parser = parser_config.stream_parser; + let model_extractor = parser_config.model_extractor; + + SseUsageCollector::new(start_time, move |events, first_token_ms| { + if let Some(usage) = stream_parser(&events) { + let model = model_extractor(&events, &request_model); + let latency_ms = start_time.elapsed().as_millis() as u64; + + let state = state.clone(); + let provider_id = provider_id.clone(); + + tokio::spawn(async move { + log_usage_internal( + &state, + &provider_id, + app_type_str, + &model, + usage, + latency_ms, + first_token_ms, + true, // is_streaming + status_code, + ) + .await; + }); + } else { + log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录"); + } + }) +} + +/// 异步记录使用量 +fn spawn_log_usage( + state: &ProxyState, + ctx: &RequestContext, + usage: TokenUsage, + model: &str, + status_code: u16, + is_streaming: bool, +) { + let state = state.clone(); + let provider_id = ctx.provider.id.clone(); + let app_type_str = ctx.app_type_str.to_string(); + let model = model.to_string(); + let latency_ms = ctx.latency_ms(); + + tokio::spawn(async move { + log_usage_internal( + &state, + &provider_id, + &app_type_str, + &model, + usage, + latency_ms, + None, + is_streaming, + status_code, + ) + .await; + }); +} + +/// 内部使用量记录函数 +#[allow(clippy::too_many_arguments)] +async fn log_usage_internal( + state: &ProxyState, + provider_id: &str, + app_type: &str, + model: &str, + usage: TokenUsage, + latency_ms: u64, + first_token_ms: Option, + is_streaming: bool, + status_code: u16, +) { + use super::usage::logger::UsageLogger; + + let logger = UsageLogger::new(&state.db); + + // 获取 provider 的 cost_multiplier + let multiplier = match state.db.get_provider_by_id(provider_id, app_type) { + Ok(Some(p)) => { + if let Some(meta) = p.meta { + if let Some(cm) = meta.cost_multiplier { + Decimal::from_str(&cm).unwrap_or(Decimal::from(1)) + } else { + Decimal::from(1) + } + } else { + Decimal::from(1) + } + } + _ => Decimal::from(1), + }; + + let request_id = uuid::Uuid::new_v4().to_string(); + + if let Err(e) = logger.log_with_calculation( + request_id, + provider_id.to_string(), + app_type.to_string(), + model.to_string(), + usage, + multiplier, + latency_ms, + first_token_ms, + status_code, + None, + None, // provider_type + is_streaming, + ) { + log::warn!("记录使用量失败: {e}"); + } +} + +/// 创建带日志记录的透传流 +pub fn create_logged_passthrough_stream( + stream: impl Stream> + Send + 'static, + tag: &'static str, + usage_collector: Option, +) -> impl Stream> + Send { + async_stream::stream! { + let mut buffer = String::new(); + let mut collector = usage_collector; + + tokio::pin!(stream); + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + let text = String::from_utf8_lossy(&bytes); + buffer.push_str(&text); + + // 尝试解析并记录完整的 SSE 事件 + while let Some(pos) = buffer.find("\n\n") { + let event_text = buffer[..pos].to_string(); + buffer = buffer[pos + 2..].to_string(); + + if !event_text.trim().is_empty() { + // 提取 data 部分并尝试解析为 JSON + for line in event_text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + if data.trim() != "[DONE]" { + if let Ok(json_value) = serde_json::from_str::(data) { + if let Some(c) = &collector { + c.push(json_value.clone()).await; + } + log::info!( + "[{}] <<< SSE 事件:\n{}", + tag, + serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| data.to_string()) + ); + } else { + log::info!("[{tag}] <<< SSE 数据: {data}"); + } + } else { + log::info!("[{tag}] <<< SSE: [DONE]"); + } + } + } + } + } + + yield Ok(bytes); + } + Err(e) => { + log::error!("[{tag}] 流错误: {e}"); + yield Err(std::io::Error::other(e.to_string())); + break; + } + } + } + + log::info!("[{}] ====== 流结束 ======", tag); + + if let Some(c) = collector.take() { + c.finish().await; + } + } +}