mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-07 05:19:57 +08:00
refactor(proxy): modularize handlers.rs to reduce code duplication
Extract common request handling logic into dedicated modules: - handler_config.rs: Usage parser configurations for each API type - handler_context.rs: Request lifecycle context management - response_processor.rs: Unified streaming/non-streaming response handling Reduces handlers.rs from ~1130 lines to ~418 lines (-63%), eliminating repeated initialization and response processing patterns across the four API handlers (Claude, Codex Chat, Codex Responses, Gemini).
This commit is contained in:
164
src-tauri/src/proxy/handler_config.rs
Normal file
164
src-tauri/src/proxy/handler_config.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
//! Handler 配置模块
|
||||
//!
|
||||
//! 定义各 API 处理器的配置结构和使用量解析器
|
||||
|
||||
use crate::app_config::AppType;
|
||||
use crate::proxy::usage::parser::TokenUsage;
|
||||
use serde_json::Value;
|
||||
|
||||
/// Usage-parser type aliases.
pub type StreamUsageParser = fn(&[Value]) -> Option<TokenUsage>;
pub type ResponseUsageParser = fn(&Value) -> Option<TokenUsage>;

/// Model-extractor type alias.
/// Arguments: (list of streamed events, model name from the request) -> final model name to record.
pub type StreamModelExtractor = fn(&[Value], &str) -> String;

/// Usage-parsing configuration for one API flavor.
#[derive(Clone, Copy)]
pub struct UsageParserConfig {
    /// Streaming-response parser
    pub stream_parser: StreamUsageParser,
    /// Non-streaming-response parser
    pub response_parser: ResponseUsageParser,
    /// Model extractor for streaming responses
    pub model_extractor: StreamModelExtractor,
    /// App-type string (used for logging / DB lookups)
    pub app_type_str: &'static str,
}
|
||||
|
||||
// ============================================================================
|
||||
// 模型提取器实现
|
||||
// ============================================================================
|
||||
|
||||
/// Claude 流式响应模型提取(直接使用请求模型)
|
||||
fn claude_model_extractor(_events: &[Value], request_model: &str) -> String {
|
||||
request_model.to_string()
|
||||
}
|
||||
|
||||
/// OpenAI Chat Completions 流式响应模型提取
|
||||
fn openai_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
events
|
||||
.iter()
|
||||
.find_map(|e| e.get("model")?.as_str())
|
||||
.unwrap_or(request_model)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Codex Responses API 流式响应模型提取
|
||||
fn codex_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
events
|
||||
.iter()
|
||||
.find_map(|e| {
|
||||
if e.get("type")?.as_str()? == "response.completed" {
|
||||
e.get("response")?.get("model")?.as_str()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(request_model)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Gemini 流式响应模型提取(优先使用 usage.model)
|
||||
fn gemini_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
// 首先尝试从解析的 usage 中获取模型
|
||||
if let Some(usage) = TokenUsage::from_gemini_stream_chunks(events) {
|
||||
if let Some(model) = usage.model {
|
||||
return model;
|
||||
}
|
||||
}
|
||||
request_model.to_string()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 预定义配置
|
||||
// ============================================================================
|
||||
|
||||
/// Claude API parser configuration.
pub const CLAUDE_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
    stream_parser: TokenUsage::from_claude_stream_events,
    response_parser: TokenUsage::from_claude_response,
    model_extractor: claude_model_extractor,
    app_type_str: "claude",
};

/// OpenAI Chat Completions API parser configuration (used by Codex `/v1/chat/completions`).
pub const OPENAI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
    stream_parser: TokenUsage::from_openai_stream_events,
    response_parser: TokenUsage::from_openai_response,
    model_extractor: openai_model_extractor,
    app_type_str: "codex",
};

/// Codex Responses API parser configuration (used by `/v1/responses`).
pub const CODEX_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
    stream_parser: TokenUsage::from_codex_stream_events,
    response_parser: TokenUsage::from_codex_response,
    model_extractor: codex_model_extractor,
    app_type_str: "codex",
};

/// Gemini API parser configuration.
pub const GEMINI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
    stream_parser: TokenUsage::from_gemini_stream_chunks,
    response_parser: TokenUsage::from_gemini_response,
    model_extractor: gemini_model_extractor,
    app_type_str: "gemini",
};
|
||||
|
||||
// ============================================================================
|
||||
// Handler 配置(预留,用于进一步简化)
|
||||
// ============================================================================
|
||||
|
||||
/// Base handler configuration.
///
/// Reserved structure that can be used to further unify the per-handler
/// configuration.
#[allow(dead_code)]
#[derive(Clone)]
pub struct HandlerConfig {
    /// Application type
    pub app_type: AppType,
    /// Log tag
    pub tag: &'static str,
    /// Application-type string
    pub app_type_str: &'static str,
    /// Usage-parsing configuration
    pub parser_config: &'static UsageParserConfig,
}
|
||||
|
||||
/// Claude handler configuration.
#[allow(dead_code)]
pub const CLAUDE_HANDLER_CONFIG: HandlerConfig = HandlerConfig {
    app_type: AppType::Claude,
    tag: "Claude",
    app_type_str: "claude",
    parser_config: &CLAUDE_PARSER_CONFIG,
};

/// Codex Chat Completions handler configuration.
#[allow(dead_code)]
pub const CODEX_CHAT_HANDLER_CONFIG: HandlerConfig = HandlerConfig {
    app_type: AppType::Codex,
    tag: "Codex",
    app_type_str: "codex",
    parser_config: &OPENAI_PARSER_CONFIG,
};

/// Codex Responses handler configuration.
#[allow(dead_code)]
pub const CODEX_RESPONSES_HANDLER_CONFIG: HandlerConfig = HandlerConfig {
    app_type: AppType::Codex,
    tag: "Codex",
    app_type_str: "codex",
    parser_config: &CODEX_PARSER_CONFIG,
};

/// Gemini handler configuration.
#[allow(dead_code)]
pub const GEMINI_HANDLER_CONFIG: HandlerConfig = HandlerConfig {
    app_type: AppType::Gemini,
    tag: "Gemini",
    app_type_str: "gemini",
    parser_config: &GEMINI_PARSER_CONFIG,
};
|
||||
128
src-tauri/src/proxy/handler_context.rs
Normal file
128
src-tauri/src/proxy/handler_context.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
//! 请求上下文模块
|
||||
//!
|
||||
//! 提供请求生命周期的上下文管理,封装通用初始化逻辑
|
||||
|
||||
use crate::app_config::AppType;
|
||||
use crate::provider::Provider;
|
||||
use crate::proxy::{
|
||||
forwarder::RequestForwarder, router::ProviderRouter, server::ProxyState, types::ProxyConfig,
|
||||
ProxyError,
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Request context.
///
/// Lives for the whole request lifecycle and carries:
/// - timing information
/// - the proxy configuration snapshot
/// - the selected provider
/// - the model name from the request
/// - the log tag
pub struct RequestContext {
    /// Request start time
    pub start_time: Instant,
    /// Proxy configuration snapshot
    pub config: ProxyConfig,
    /// Selected provider
    pub provider: Provider,
    /// Model name from the request
    pub request_model: String,
    /// Log tag (e.g. "Claude", "Codex", "Gemini")
    pub tag: &'static str,
    /// App-type string (e.g. "claude", "codex", "gemini")
    pub app_type_str: &'static str,
    /// Application type (reserved; currently `app_type_str` is what gets used)
    #[allow(dead_code)]
    pub app_type: AppType,
}
|
||||
|
||||
impl RequestContext {
|
||||
/// 创建请求上下文
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `state` - 代理服务器状态
|
||||
/// * `body` - 请求体 JSON
|
||||
/// * `app_type` - 应用类型
|
||||
/// * `tag` - 日志标签
|
||||
/// * `app_type_str` - 应用类型字符串
|
||||
///
|
||||
/// # Errors
|
||||
/// 返回 `ProxyError` 如果 Provider 选择失败
|
||||
pub async fn new(
|
||||
state: &ProxyState,
|
||||
body: &serde_json::Value,
|
||||
app_type: AppType,
|
||||
tag: &'static str,
|
||||
app_type_str: &'static str,
|
||||
) -> Result<Self, ProxyError> {
|
||||
let start_time = Instant::now();
|
||||
let config = state.config.read().await.clone();
|
||||
|
||||
// 从请求体提取模型名称
|
||||
let request_model = body
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Provider 选择
|
||||
let router = ProviderRouter::new(state.db.clone());
|
||||
let provider = router.select_provider(&app_type, &[]).await?;
|
||||
|
||||
log::info!(
|
||||
"[{}] Provider: {}, model: {}",
|
||||
tag,
|
||||
provider.name,
|
||||
request_model
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
start_time,
|
||||
config,
|
||||
provider,
|
||||
request_model,
|
||||
tag,
|
||||
app_type_str,
|
||||
app_type,
|
||||
})
|
||||
}
|
||||
|
||||
/// 从 URI 提取模型名称(Gemini 专用)
|
||||
///
|
||||
/// Gemini API 的模型名称在 URI 中,格式如:
|
||||
/// `/v1beta/models/gemini-pro:generateContent`
|
||||
pub fn with_model_from_uri(mut self, uri: &axum::http::Uri) -> Self {
|
||||
let endpoint = uri
|
||||
.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or(uri.path());
|
||||
|
||||
self.request_model = endpoint
|
||||
.split('/')
|
||||
.find(|s| s.starts_with("models/"))
|
||||
.and_then(|s| s.strip_prefix("models/"))
|
||||
.map(|s| s.split(':').next().unwrap_or(s))
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
log::info!("[{}] 从 URI 提取模型: {}", self.tag, self.request_model);
|
||||
self
|
||||
}
|
||||
|
||||
/// 创建 RequestForwarder
|
||||
pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder {
|
||||
RequestForwarder::new(
|
||||
state.db.clone(),
|
||||
self.config.request_timeout,
|
||||
self.config.max_retries,
|
||||
state.status.clone(),
|
||||
state.current_providers.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
/// 计算请求延迟(毫秒)
|
||||
#[inline]
|
||||
pub fn latency_ms(&self) -> u64 {
|
||||
self.start_time.elapsed().as_millis() as u64
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,11 +5,14 @@
|
||||
pub mod circuit_breaker;
|
||||
pub mod error;
|
||||
mod forwarder;
|
||||
pub mod handler_config;
|
||||
pub mod handler_context;
|
||||
mod handlers;
|
||||
mod health;
|
||||
pub mod provider_router;
|
||||
pub mod providers;
|
||||
pub mod response_handler;
|
||||
pub mod response_processor;
|
||||
mod router;
|
||||
pub(crate) mod server;
|
||||
pub mod session;
|
||||
|
||||
411
src-tauri/src/proxy/response_processor.rs
Normal file
411
src-tauri/src/proxy/response_processor.rs
Normal file
@@ -0,0 +1,411 @@
|
||||
//! 响应处理器模块
|
||||
//!
|
||||
//! 统一处理流式和非流式 API 响应
|
||||
|
||||
use super::{
|
||||
handler_config::UsageParserConfig, handler_context::RequestContext, server::ProxyState,
|
||||
usage::parser::TokenUsage, ProxyError,
|
||||
};
|
||||
use axum::response::Response;
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use rust_decimal::Decimal;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
str::FromStr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
// ============================================================================
|
||||
// 公共接口
|
||||
// ============================================================================
|
||||
|
||||
/// 检测响应是否为 SSE 流式响应
|
||||
#[inline]
|
||||
pub fn is_sse_response(response: &reqwest::Response) -> bool {
|
||||
response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|ct| ct.contains("text/event-stream"))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// 处理流式响应
|
||||
pub async fn handle_streaming(
|
||||
response: reqwest::Response,
|
||||
ctx: &RequestContext,
|
||||
state: &ProxyState,
|
||||
parser_config: &UsageParserConfig,
|
||||
) -> Response {
|
||||
log::info!("[{}] 流式透传响应 (SSE)", ctx.tag);
|
||||
|
||||
let status = response.status();
|
||||
let mut builder = axum::response::Response::builder().status(status);
|
||||
|
||||
// 复制响应头
|
||||
for (key, value) in response.headers() {
|
||||
builder = builder.header(key, value);
|
||||
}
|
||||
|
||||
// 创建字节流
|
||||
let stream = response
|
||||
.bytes_stream()
|
||||
.map(|chunk| chunk.map_err(|e| std::io::Error::other(e.to_string())));
|
||||
|
||||
// 创建使用量收集器
|
||||
let usage_collector = create_usage_collector(ctx, state, status.as_u16(), parser_config);
|
||||
|
||||
// 创建带日志的透传流
|
||||
let logged_stream = create_logged_passthrough_stream(stream, ctx.tag, Some(usage_collector));
|
||||
|
||||
let body = axum::body::Body::from_stream(logged_stream);
|
||||
builder.body(body).unwrap()
|
||||
}
|
||||
|
||||
/// 处理非流式响应
|
||||
pub async fn handle_non_streaming(
|
||||
response: reqwest::Response,
|
||||
ctx: &RequestContext,
|
||||
state: &ProxyState,
|
||||
parser_config: &UsageParserConfig,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let response_headers = response.headers().clone();
|
||||
let status = response.status();
|
||||
|
||||
// 读取响应体
|
||||
let body_bytes = response.bytes().await.map_err(|e| {
|
||||
log::error!("[{}] 读取响应失败: {e}", ctx.tag);
|
||||
ProxyError::ForwardFailed(format!("Failed to read response body: {e}"))
|
||||
})?;
|
||||
|
||||
// 解析并记录使用量
|
||||
if let Ok(json_value) = serde_json::from_slice::<Value>(&body_bytes) {
|
||||
log::info!(
|
||||
"[{}] <<< 响应 JSON:\n{}",
|
||||
ctx.tag,
|
||||
serde_json::to_string_pretty(&json_value).unwrap_or_default()
|
||||
);
|
||||
|
||||
// 解析使用量
|
||||
if let Some(usage) = (parser_config.response_parser)(&json_value) {
|
||||
let model = json_value
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or(&ctx.request_model);
|
||||
|
||||
spawn_log_usage(state, ctx, usage, model, status.as_u16(), false);
|
||||
} else {
|
||||
log::debug!(
|
||||
"[{}] 未能解析 usage 信息,跳过记录",
|
||||
parser_config.app_type_str
|
||||
);
|
||||
}
|
||||
} else {
|
||||
log::info!(
|
||||
"[{}] <<< 响应 (非 JSON): {} bytes",
|
||||
ctx.tag,
|
||||
body_bytes.len()
|
||||
);
|
||||
}
|
||||
|
||||
log::info!("[{}] ====== 请求结束 ======", ctx.tag);
|
||||
|
||||
// 构建响应
|
||||
let mut builder = axum::response::Response::builder().status(status);
|
||||
for (key, value) in response_headers.iter() {
|
||||
builder = builder.header(key, value);
|
||||
}
|
||||
|
||||
let body = axum::body::Body::from(body_bytes);
|
||||
Ok(builder.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// 通用响应处理入口
|
||||
///
|
||||
/// 根据响应类型自动选择流式或非流式处理
|
||||
pub async fn process_response(
|
||||
response: reqwest::Response,
|
||||
ctx: &RequestContext,
|
||||
state: &ProxyState,
|
||||
parser_config: &UsageParserConfig,
|
||||
) -> Result<Response, ProxyError> {
|
||||
if is_sse_response(&response) {
|
||||
Ok(handle_streaming(response, ctx, state, parser_config).await)
|
||||
} else {
|
||||
handle_non_streaming(response, ctx, state, parser_config).await
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE 使用量收集器
|
||||
// ============================================================================
|
||||
|
||||
// Completion callback: receives (collected events, first-token latency in ms).
type UsageCallbackWithTiming = Arc<dyn Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static>;

/// SSE usage collector.
///
/// Cheaply cloneable handle (Arc-backed) that accumulates parsed SSE event
/// payloads and fires a completion callback exactly once.
#[derive(Clone)]
pub struct SseUsageCollector {
    inner: Arc<SseUsageCollectorInner>,
}

struct SseUsageCollectorInner {
    // Collected SSE event payloads (parsed JSON from `data:` lines).
    events: Mutex<Vec<Value>>,
    // Instant of the first pushed event; used for time-to-first-token.
    first_event_time: Mutex<Option<std::time::Instant>>,
    // Request start time the first-token latency is measured against.
    start_time: std::time::Instant,
    // Callback fired once from `finish` with (events, first_token_ms).
    on_complete: UsageCallbackWithTiming,
    // Guards against invoking `on_complete` more than once.
    finished: AtomicBool,
}
|
||||
|
||||
impl SseUsageCollector {
    /// Create a new usage collector.
    ///
    /// `start_time` is the request start instant; `callback` is invoked
    /// exactly once (from `finish`) with the collected events and the
    /// first-event latency in milliseconds, if any event arrived.
    pub fn new(
        start_time: std::time::Instant,
        callback: impl Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static,
    ) -> Self {
        let on_complete: UsageCallbackWithTiming = Arc::new(callback);
        Self {
            inner: Arc::new(SseUsageCollectorInner {
                events: Mutex::new(Vec::new()),
                first_event_time: Mutex::new(None),
                start_time,
                on_complete,
                finished: AtomicBool::new(false),
            }),
        }
    }

    /// Push one parsed SSE event payload.
    pub async fn push(&self, event: Value) {
        // Record the instant of the first event (time-to-first-token);
        // the lock is dropped before the events lock is taken.
        {
            let mut first_time = self.inner.first_event_time.lock().await;
            if first_time.is_none() {
                *first_time = Some(std::time::Instant::now());
            }
        }
        let mut events = self.inner.events.lock().await;
        events.push(event);
    }

    /// Finish collecting and fire the completion callback.
    pub async fn finish(&self) {
        // `swap` makes finish idempotent: only the first caller proceeds.
        if self.inner.finished.swap(true, Ordering::SeqCst) {
            return;
        }

        // Take the events out so the callback owns them without cloning.
        let events = {
            let mut guard = self.inner.events.lock().await;
            std::mem::take(&mut *guard)
        };

        // First-token latency relative to the request start, if any event
        // was observed.
        let first_token_ms = {
            let first_time = self.inner.first_event_time.lock().await;
            first_time.map(|t| (t - self.inner.start_time).as_millis() as u64)
        };

        (self.inner.on_complete)(events, first_token_ms);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// 内部辅助函数
|
||||
// ============================================================================
|
||||
|
||||
/// Build the SSE usage collector for one streaming response.
///
/// Captures owned copies of everything the completion callback needs
/// (state, provider id, request model, parser functions), then — when the
/// stream ends — parses usage from the collected events and records it on
/// a spawned task.
fn create_usage_collector(
    ctx: &RequestContext,
    state: &ProxyState,
    status_code: u16,
    parser_config: &UsageParserConfig,
) -> SseUsageCollector {
    // Owned/copy captures for the 'static callback below.
    let state = state.clone();
    let provider_id = ctx.provider.id.clone();
    let request_model = ctx.request_model.clone();
    let app_type_str = parser_config.app_type_str;
    let tag = ctx.tag;
    let start_time = ctx.start_time;
    let stream_parser = parser_config.stream_parser;
    let model_extractor = parser_config.model_extractor;

    SseUsageCollector::new(start_time, move |events, first_token_ms| {
        if let Some(usage) = stream_parser(&events) {
            let model = model_extractor(&events, &request_model);
            // Total latency measured at stream completion.
            let latency_ms = start_time.elapsed().as_millis() as u64;

            // Re-clone per invocation so the spawned task owns its data.
            let state = state.clone();
            let provider_id = provider_id.clone();

            // NOTE(review): tokio::spawn requires a runtime context; the
            // callback is invoked from `SseUsageCollector::finish`, which is
            // async, so a runtime is expected to be current here.
            tokio::spawn(async move {
                log_usage_internal(
                    &state,
                    &provider_id,
                    app_type_str,
                    &model,
                    usage,
                    latency_ms,
                    first_token_ms,
                    true, // is_streaming
                    status_code,
                )
                .await;
            });
        } else {
            log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录");
        }
    })
}
|
||||
|
||||
/// 异步记录使用量
|
||||
fn spawn_log_usage(
|
||||
state: &ProxyState,
|
||||
ctx: &RequestContext,
|
||||
usage: TokenUsage,
|
||||
model: &str,
|
||||
status_code: u16,
|
||||
is_streaming: bool,
|
||||
) {
|
||||
let state = state.clone();
|
||||
let provider_id = ctx.provider.id.clone();
|
||||
let app_type_str = ctx.app_type_str.to_string();
|
||||
let model = model.to_string();
|
||||
let latency_ms = ctx.latency_ms();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
&state,
|
||||
&provider_id,
|
||||
&app_type_str,
|
||||
&model,
|
||||
usage,
|
||||
latency_ms,
|
||||
None,
|
||||
is_streaming,
|
||||
status_code,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
/// 内部使用量记录函数
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn log_usage_internal(
|
||||
state: &ProxyState,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
model: &str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: u64,
|
||||
first_token_ms: Option<u64>,
|
||||
is_streaming: bool,
|
||||
status_code: u16,
|
||||
) {
|
||||
use super::usage::logger::UsageLogger;
|
||||
|
||||
let logger = UsageLogger::new(&state.db);
|
||||
|
||||
// 获取 provider 的 cost_multiplier
|
||||
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or(Decimal::from(1))
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
}
|
||||
_ => Decimal::from(1),
|
||||
};
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
if let Err(e) = logger.log_with_calculation(
|
||||
request_id,
|
||||
provider_id.to_string(),
|
||||
app_type.to_string(),
|
||||
model.to_string(),
|
||||
usage,
|
||||
multiplier,
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
status_code,
|
||||
None,
|
||||
None, // provider_type
|
||||
is_streaming,
|
||||
) {
|
||||
log::warn!("记录使用量失败: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a logged passthrough stream.
///
/// Forwards every chunk downstream unchanged while, on the side, buffering
/// the text and splitting it into complete SSE events (delimited by a blank
/// line, "\n\n"). `data:` payloads that parse as JSON are pushed into the
/// optional `usage_collector`; the collector's `finish` is invoked when the
/// stream ends (including after an error).
///
/// NOTE(review): only "\n\n" is recognized as the event delimiter; servers
/// that frame events with "\r\n\r\n" would not be split here — confirm
/// upstream framing.
pub fn create_logged_passthrough_stream(
    stream: impl Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
    tag: &'static str,
    usage_collector: Option<SseUsageCollector>,
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
    async_stream::stream! {
        // Accumulates decoded text across chunks until a full event arrives.
        let mut buffer = String::new();
        let mut collector = usage_collector;

        tokio::pin!(stream);

        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(bytes) => {
                    // Lossy decode: invalid UTF-8 bytes become U+FFFD in the
                    // *logging* buffer only; the original bytes are yielded
                    // downstream untouched.
                    let text = String::from_utf8_lossy(&bytes);
                    buffer.push_str(&text);

                    // Parse and log each complete SSE event in the buffer.
                    while let Some(pos) = buffer.find("\n\n") {
                        let event_text = buffer[..pos].to_string();
                        buffer = buffer[pos + 2..].to_string();

                        if !event_text.trim().is_empty() {
                            // Extract the `data:` part and try to parse it as JSON.
                            for line in event_text.lines() {
                                if let Some(data) = line.strip_prefix("data: ") {
                                    if data.trim() != "[DONE]" {
                                        if let Ok(json_value) = serde_json::from_str::<Value>(data) {
                                            // Feed the parsed event to the usage collector.
                                            if let Some(c) = &collector {
                                                c.push(json_value.clone()).await;
                                            }
                                            log::info!(
                                                "[{}] <<< SSE 事件:\n{}",
                                                tag,
                                                serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| data.to_string())
                                            );
                                        } else {
                                            log::info!("[{tag}] <<< SSE 数据: {data}");
                                        }
                                    } else {
                                        log::info!("[{tag}] <<< SSE: [DONE]");
                                    }
                                }
                            }
                        }
                    }

                    // Always forward the raw bytes unchanged.
                    yield Ok(bytes);
                }
                Err(e) => {
                    log::error!("[{tag}] 流错误: {e}");
                    yield Err(std::io::Error::other(e.to_string()));
                    break;
                }
            }
        }

        log::info!("[{}] ====== 流结束 ======", tag);

        // Fire usage accounting exactly once when the stream terminates.
        if let Some(c) = collector.take() {
            c.finish().await;
        }
    }
}
|
||||
Reference in New Issue
Block a user