mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-05-24 23:10:39 +08:00
1009 lines
33 KiB
Rust
1009 lines
33 KiB
Rust
//! 响应处理器模块
|
||
//!
|
||
//! 统一处理流式和非流式 API 响应
|
||
|
||
use super::{
|
||
handler_config::UsageParserConfig,
|
||
handler_context::{RequestContext, StreamingTimeoutConfig},
|
||
hyper_client::ProxyResponse,
|
||
server::ProxyState,
|
||
sse::strip_sse_field,
|
||
usage::parser::TokenUsage,
|
||
ProxyError,
|
||
};
|
||
use axum::http::{header::HeaderMap, HeaderName};
|
||
use axum::response::{IntoResponse, Response};
|
||
use bytes::Bytes;
|
||
use futures::stream::{Stream, StreamExt};
|
||
use serde_json::Value;
|
||
use std::{
|
||
io::Read,
|
||
sync::{
|
||
atomic::{AtomicBool, Ordering},
|
||
Arc,
|
||
},
|
||
time::Duration,
|
||
};
|
||
use tokio::sync::Mutex;
|
||
|
||
// ============================================================================
|
||
// 响应解压
|
||
// ============================================================================
|
||
|
||
/// 根据 content-encoding 解压响应体字节
|
||
///
|
||
/// reqwest 自动解压已禁用(为了透传 accept-encoding),需要手动解压。
|
||
fn decompress_body(content_encoding: &str, body: &[u8]) -> Result<Vec<u8>, std::io::Error> {
|
||
match content_encoding {
|
||
"gzip" | "x-gzip" => {
|
||
let mut decoder = flate2::read::GzDecoder::new(body);
|
||
let mut decompressed = Vec::new();
|
||
decoder.read_to_end(&mut decompressed)?;
|
||
Ok(decompressed)
|
||
}
|
||
"deflate" => {
|
||
let mut decoder = flate2::read::DeflateDecoder::new(body);
|
||
let mut decompressed = Vec::new();
|
||
decoder.read_to_end(&mut decompressed)?;
|
||
Ok(decompressed)
|
||
}
|
||
"br" => {
|
||
let mut decompressed = Vec::new();
|
||
brotli::BrotliDecompress(&mut std::io::Cursor::new(body), &mut decompressed)?;
|
||
Ok(decompressed)
|
||
}
|
||
_ => {
|
||
log::warn!("未知的 content-encoding: {content_encoding},跳过解压");
|
||
Ok(body.to_vec())
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 从响应头提取 content-encoding(忽略 identity 和 chunked)
|
||
fn get_content_encoding(headers: &HeaderMap) -> Option<String> {
|
||
headers
|
||
.get("content-encoding")
|
||
.and_then(|v| v.to_str().ok())
|
||
.map(|s| s.trim().to_lowercase())
|
||
.filter(|s| !s.is_empty() && s != "identity")
|
||
}
|
||
|
||
/// Response headers that RFC 2616 / RFC 7230 treat as hop-by-hop, i.e. that a
/// proxy must not forward to the next hop.
///
/// All names are lower-case so they can be handed straight to
/// `HeaderMap::remove`.
const HOP_BY_HOP_RESPONSE_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
|
||
|
||
/// 移除响应侧 hop-by-hop 头,以及 `Connection` 中点名的扩展头。
|
||
pub(crate) fn strip_hop_by_hop_response_headers(headers: &mut HeaderMap) {
|
||
let connection_listed_headers: Vec<HeaderName> = headers
|
||
.get_all(axum::http::header::CONNECTION)
|
||
.iter()
|
||
.filter_map(|value| value.to_str().ok())
|
||
.flat_map(|value| value.split(','))
|
||
.map(str::trim)
|
||
.filter(|name| !name.is_empty())
|
||
.filter_map(|name| HeaderName::from_bytes(name.as_bytes()).ok())
|
||
.collect();
|
||
|
||
for name in HOP_BY_HOP_RESPONSE_HEADERS {
|
||
headers.remove(*name);
|
||
}
|
||
|
||
for name in connection_listed_headers {
|
||
headers.remove(name);
|
||
}
|
||
}
|
||
|
||
/// 移除在重建响应体后会失真的实体头。
|
||
pub(crate) fn strip_entity_headers_for_rebuilt_body(headers: &mut HeaderMap) {
|
||
headers.remove(axum::http::header::CONTENT_ENCODING);
|
||
headers.remove(axum::http::header::CONTENT_LENGTH);
|
||
headers.remove(axum::http::header::TRANSFER_ENCODING);
|
||
}
|
||
|
||
/// 读取响应体并在需要时解压,确保 headers 与返回 body 一致。
|
||
///
|
||
/// `body_timeout`: 整包超时。当非零时用 `tokio::time::timeout` 包住 `.bytes()` 调用,
|
||
/// 防止上游发完响应头后卡住 body 导致请求永远挂住。
|
||
/// 传入 `Duration::ZERO` 表示不启用超时(故障转移关闭时)。
|
||
pub(crate) async fn read_decoded_body(
|
||
response: ProxyResponse,
|
||
tag: &str,
|
||
body_timeout: Duration,
|
||
) -> Result<(HeaderMap, http::StatusCode, Bytes), ProxyError> {
|
||
let mut headers = response.headers().clone();
|
||
let status = response.status();
|
||
let raw_bytes = if body_timeout.is_zero() {
|
||
response.bytes().await?
|
||
} else {
|
||
tokio::time::timeout(body_timeout, response.bytes())
|
||
.await
|
||
.map_err(|_| {
|
||
ProxyError::Timeout(format!(
|
||
"响应体读取超时: {}s(上游发完响应头后 body 未到达)",
|
||
body_timeout.as_secs()
|
||
))
|
||
})??
|
||
};
|
||
|
||
log::debug!(
|
||
"[{tag}] 已接收上游响应体: status={}, bytes={}, headers={}",
|
||
status.as_u16(),
|
||
raw_bytes.len(),
|
||
format_headers(&headers)
|
||
);
|
||
|
||
let mut body_bytes = raw_bytes.clone();
|
||
let mut decoded = false;
|
||
|
||
if let Some(encoding) = get_content_encoding(&headers) {
|
||
log::debug!("[{tag}] 解压非流式响应: content-encoding={encoding}");
|
||
match decompress_body(&encoding, &raw_bytes) {
|
||
Ok(decompressed) => {
|
||
body_bytes = Bytes::from(decompressed);
|
||
decoded = true;
|
||
}
|
||
Err(e) => {
|
||
log::warn!("[{tag}] 解压失败 ({encoding}): {e},使用原始数据");
|
||
}
|
||
}
|
||
}
|
||
|
||
if decoded {
|
||
strip_entity_headers_for_rebuilt_body(&mut headers);
|
||
}
|
||
|
||
Ok((headers, status, body_bytes))
|
||
}
|
||
|
||
// ============================================================================
|
||
// 公共接口
|
||
// ============================================================================
|
||
|
||
/// 检测响应是否为 SSE 流式响应
|
||
#[inline]
|
||
pub fn is_sse_response(response: &ProxyResponse) -> bool {
|
||
response.is_sse()
|
||
}
|
||
|
||
/// 处理流式响应
|
||
pub async fn handle_streaming(
|
||
response: ProxyResponse,
|
||
ctx: &RequestContext,
|
||
state: &ProxyState,
|
||
parser_config: &UsageParserConfig,
|
||
) -> Response {
|
||
let status = response.status();
|
||
log::debug!(
|
||
"[{}] 已接收上游流式响应: status={}, headers={}",
|
||
ctx.tag,
|
||
status.as_u16(),
|
||
format_headers(response.headers())
|
||
);
|
||
// 检查流式响应是否被压缩(SSE 通常不压缩,如果压缩则 SSE 解析会失败)
|
||
if let Some(encoding) = get_content_encoding(response.headers()) {
|
||
log::warn!(
|
||
"[{}] 流式响应含 content-encoding={encoding},SSE 解析可能失败。\
|
||
上游在 accept-encoding 透传后压缩了 SSE 流。",
|
||
ctx.tag
|
||
);
|
||
}
|
||
|
||
let mut response_headers = response.headers().clone();
|
||
strip_hop_by_hop_response_headers(&mut response_headers);
|
||
|
||
let mut builder = axum::response::Response::builder().status(status);
|
||
|
||
// 复制响应头
|
||
for (key, value) in &response_headers {
|
||
builder = builder.header(key, value);
|
||
}
|
||
|
||
// 创建字节流
|
||
let stream = response.bytes_stream();
|
||
|
||
// 创建使用量收集器
|
||
let usage_collector = create_usage_collector(ctx, state, status.as_u16(), parser_config);
|
||
|
||
// 获取流式超时配置
|
||
let timeout_config = ctx.streaming_timeout_config();
|
||
|
||
// 创建带日志和超时的透传流
|
||
let logged_stream =
|
||
create_logged_passthrough_stream(stream, ctx.tag, Some(usage_collector), timeout_config);
|
||
|
||
let body = axum::body::Body::from_stream(logged_stream);
|
||
match builder.body(body) {
|
||
Ok(resp) => resp,
|
||
Err(e) => {
|
||
log::error!("[{}] 构建流式响应失败: {e}", ctx.tag);
|
||
ProxyError::Internal(format!("Failed to build streaming response: {e}")).into_response()
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 处理非流式响应
|
||
pub async fn handle_non_streaming(
|
||
response: ProxyResponse,
|
||
ctx: &RequestContext,
|
||
state: &ProxyState,
|
||
parser_config: &UsageParserConfig,
|
||
) -> Result<Response, ProxyError> {
|
||
// 整包超时:仅在故障转移开启且配置值非零时生效
|
||
let body_timeout =
|
||
if ctx.app_config.auto_failover_enabled && ctx.app_config.non_streaming_timeout > 0 {
|
||
Duration::from_secs(ctx.app_config.non_streaming_timeout as u64)
|
||
} else {
|
||
Duration::ZERO
|
||
};
|
||
let (mut response_headers, status, body_bytes) =
|
||
read_decoded_body(response, ctx.tag, body_timeout).await?;
|
||
strip_hop_by_hop_response_headers(&mut response_headers);
|
||
|
||
log::debug!(
|
||
"[{}] 上游响应体内容: {}",
|
||
ctx.tag,
|
||
String::from_utf8_lossy(&body_bytes)
|
||
);
|
||
|
||
// 解析并记录使用量
|
||
if let Ok(json_value) = serde_json::from_slice::<Value>(&body_bytes) {
|
||
// 解析使用量
|
||
if let Some(usage) = (parser_config.response_parser)(&json_value) {
|
||
// 优先使用 usage 中解析出的模型名称,其次使用响应中的 model 字段,最后回退到请求模型
|
||
let model = if let Some(ref m) = usage.model {
|
||
m.clone()
|
||
} else if let Some(m) = json_value.get("model").and_then(|m| m.as_str()) {
|
||
m.to_string()
|
||
} else {
|
||
ctx.request_model.clone()
|
||
};
|
||
|
||
spawn_log_usage(
|
||
state,
|
||
ctx,
|
||
usage,
|
||
&model,
|
||
&ctx.request_model,
|
||
status.as_u16(),
|
||
false,
|
||
);
|
||
} else {
|
||
let model = json_value
|
||
.get("model")
|
||
.and_then(|m| m.as_str())
|
||
.unwrap_or(&ctx.request_model)
|
||
.to_string();
|
||
spawn_log_usage(
|
||
state,
|
||
ctx,
|
||
TokenUsage::default(),
|
||
&model,
|
||
&ctx.request_model,
|
||
status.as_u16(),
|
||
false,
|
||
);
|
||
log::debug!(
|
||
"[{}] 未能解析 usage 信息,跳过记录",
|
||
parser_config.app_type_str
|
||
);
|
||
}
|
||
} else {
|
||
log::debug!(
|
||
"[{}] <<< 响应 (非 JSON): {} bytes",
|
||
ctx.tag,
|
||
body_bytes.len()
|
||
);
|
||
spawn_log_usage(
|
||
state,
|
||
ctx,
|
||
TokenUsage::default(),
|
||
&ctx.request_model,
|
||
&ctx.request_model,
|
||
status.as_u16(),
|
||
false,
|
||
);
|
||
}
|
||
|
||
// 构建响应
|
||
let mut builder = axum::response::Response::builder().status(status);
|
||
for (key, value) in response_headers.iter() {
|
||
builder = builder.header(key, value);
|
||
}
|
||
|
||
let body = axum::body::Body::from(body_bytes);
|
||
builder.body(body).map_err(|e| {
|
||
log::error!("[{}] 构建响应失败: {e}", ctx.tag);
|
||
ProxyError::Internal(format!("Failed to build response: {e}"))
|
||
})
|
||
}
|
||
|
||
/// 通用响应处理入口
|
||
///
|
||
/// 根据响应类型自动选择流式或非流式处理
|
||
pub async fn process_response(
|
||
response: ProxyResponse,
|
||
ctx: &RequestContext,
|
||
state: &ProxyState,
|
||
parser_config: &UsageParserConfig,
|
||
) -> Result<Response, ProxyError> {
|
||
if is_sse_response(&response) {
|
||
Ok(handle_streaming(response, ctx, state, parser_config).await)
|
||
} else {
|
||
handle_non_streaming(response, ctx, state, parser_config).await
|
||
}
|
||
}
|
||
|
||
// ============================================================================
|
||
// SSE 使用量收集器
|
||
// ============================================================================
|
||
|
||
/// Shared completion callback of an SSE usage collector: it is invoked with
/// the collected SSE JSON events and the optional time to the first event in
/// milliseconds.
type UsageCallbackWithTiming = Arc<dyn Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static>;
|
||
|
||
/// SSE 使用量收集器
|
||
#[derive(Clone)]
|
||
pub struct SseUsageCollector {
|
||
inner: Arc<SseUsageCollectorInner>,
|
||
}
|
||
|
||
struct SseUsageCollectorInner {
|
||
events: Mutex<Vec<Value>>,
|
||
first_event_time: Mutex<Option<std::time::Instant>>,
|
||
start_time: std::time::Instant,
|
||
on_complete: UsageCallbackWithTiming,
|
||
finished: AtomicBool,
|
||
}
|
||
|
||
impl SseUsageCollector {
|
||
/// 创建新的使用量收集器
|
||
pub fn new(
|
||
start_time: std::time::Instant,
|
||
callback: impl Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static,
|
||
) -> Self {
|
||
let on_complete: UsageCallbackWithTiming = Arc::new(callback);
|
||
Self {
|
||
inner: Arc::new(SseUsageCollectorInner {
|
||
events: Mutex::new(Vec::new()),
|
||
first_event_time: Mutex::new(None),
|
||
start_time,
|
||
on_complete,
|
||
finished: AtomicBool::new(false),
|
||
}),
|
||
}
|
||
}
|
||
|
||
/// 推送 SSE 事件
|
||
pub async fn push(&self, event: Value) {
|
||
// 记录首个事件时间
|
||
{
|
||
let mut first_time = self.inner.first_event_time.lock().await;
|
||
if first_time.is_none() {
|
||
*first_time = Some(std::time::Instant::now());
|
||
}
|
||
}
|
||
let mut events = self.inner.events.lock().await;
|
||
events.push(event);
|
||
}
|
||
|
||
/// 完成收集并触发回调
|
||
pub async fn finish(&self) {
|
||
if self.inner.finished.swap(true, Ordering::SeqCst) {
|
||
return;
|
||
}
|
||
|
||
let events = {
|
||
let mut guard = self.inner.events.lock().await;
|
||
std::mem::take(&mut *guard)
|
||
};
|
||
|
||
let first_token_ms = {
|
||
let first_time = self.inner.first_event_time.lock().await;
|
||
first_time.map(|t| (t - self.inner.start_time).as_millis() as u64)
|
||
};
|
||
|
||
(self.inner.on_complete)(events, first_token_ms);
|
||
}
|
||
}
|
||
|
||
// ============================================================================
|
||
// 内部辅助函数
|
||
// ============================================================================
|
||
|
||
/// 创建使用量收集器
|
||
fn create_usage_collector(
|
||
ctx: &RequestContext,
|
||
state: &ProxyState,
|
||
status_code: u16,
|
||
parser_config: &UsageParserConfig,
|
||
) -> SseUsageCollector {
|
||
let logging_enabled = state
|
||
.config
|
||
.try_read()
|
||
.map(|c| c.enable_logging)
|
||
.unwrap_or(true);
|
||
let state = state.clone();
|
||
let provider_id = ctx.provider.id.clone();
|
||
let request_model = ctx.request_model.clone();
|
||
let app_type_str = parser_config.app_type_str;
|
||
let tag = ctx.tag;
|
||
let start_time = ctx.start_time;
|
||
let stream_parser = parser_config.stream_parser;
|
||
let model_extractor = parser_config.model_extractor;
|
||
let session_id = ctx.session_id.clone();
|
||
|
||
SseUsageCollector::new(start_time, move |events, first_token_ms| {
|
||
if !logging_enabled {
|
||
return;
|
||
}
|
||
if let Some(usage) = stream_parser(&events) {
|
||
let model = model_extractor(&events, &request_model);
|
||
let latency_ms = start_time.elapsed().as_millis() as u64;
|
||
|
||
let state = state.clone();
|
||
let provider_id = provider_id.clone();
|
||
let session_id = session_id.clone();
|
||
let request_model = request_model.clone();
|
||
|
||
tokio::spawn(async move {
|
||
log_usage_internal(
|
||
&state,
|
||
&provider_id,
|
||
app_type_str,
|
||
&model,
|
||
&request_model,
|
||
usage,
|
||
latency_ms,
|
||
first_token_ms,
|
||
true, // is_streaming
|
||
status_code,
|
||
Some(session_id),
|
||
)
|
||
.await;
|
||
});
|
||
} else {
|
||
let model = model_extractor(&events, &request_model);
|
||
let latency_ms = start_time.elapsed().as_millis() as u64;
|
||
let state = state.clone();
|
||
let provider_id = provider_id.clone();
|
||
let session_id = session_id.clone();
|
||
let request_model = request_model.clone();
|
||
|
||
tokio::spawn(async move {
|
||
log_usage_internal(
|
||
&state,
|
||
&provider_id,
|
||
app_type_str,
|
||
&model,
|
||
&request_model,
|
||
TokenUsage::default(),
|
||
latency_ms,
|
||
first_token_ms,
|
||
true, // is_streaming
|
||
status_code,
|
||
Some(session_id),
|
||
)
|
||
.await;
|
||
});
|
||
log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录");
|
||
}
|
||
})
|
||
}
|
||
|
||
/// 异步记录使用量
|
||
fn spawn_log_usage(
|
||
state: &ProxyState,
|
||
ctx: &RequestContext,
|
||
usage: TokenUsage,
|
||
model: &str,
|
||
request_model: &str,
|
||
status_code: u16,
|
||
is_streaming: bool,
|
||
) {
|
||
// Check enable_logging before spawning the log task
|
||
if let Ok(config) = state.config.try_read() {
|
||
if !config.enable_logging {
|
||
return;
|
||
}
|
||
}
|
||
|
||
let state = state.clone();
|
||
let provider_id = ctx.provider.id.clone();
|
||
let app_type_str = ctx.app_type_str.to_string();
|
||
let model = model.to_string();
|
||
let request_model = request_model.to_string();
|
||
let latency_ms = ctx.latency_ms();
|
||
let session_id = ctx.session_id.clone();
|
||
|
||
tokio::spawn(async move {
|
||
log_usage_internal(
|
||
&state,
|
||
&provider_id,
|
||
&app_type_str,
|
||
&model,
|
||
&request_model,
|
||
usage,
|
||
latency_ms,
|
||
None,
|
||
is_streaming,
|
||
status_code,
|
||
Some(session_id),
|
||
)
|
||
.await;
|
||
});
|
||
}
|
||
|
||
/// 内部使用量记录函数
|
||
#[allow(clippy::too_many_arguments)]
|
||
async fn log_usage_internal(
|
||
state: &ProxyState,
|
||
provider_id: &str,
|
||
app_type: &str,
|
||
model: &str,
|
||
request_model: &str,
|
||
usage: TokenUsage,
|
||
latency_ms: u64,
|
||
first_token_ms: Option<u64>,
|
||
is_streaming: bool,
|
||
status_code: u16,
|
||
session_id: Option<String>,
|
||
) {
|
||
use super::usage::logger::UsageLogger;
|
||
|
||
let logger = UsageLogger::new(&state.db);
|
||
let (multiplier, pricing_model_source) =
|
||
logger.resolve_pricing_config(provider_id, app_type).await;
|
||
let pricing_model = if pricing_model_source == "request" {
|
||
request_model
|
||
} else {
|
||
model
|
||
};
|
||
|
||
let request_id = usage.dedup_request_id();
|
||
|
||
log::debug!(
|
||
"[{app_type}] 记录请求日志: id={request_id}, provider={provider_id}, model={model}, streaming={is_streaming}, status={status_code}, latency_ms={latency_ms}, first_token_ms={first_token_ms:?}, session={}, input={}, output={}, cache_read={}, cache_creation={}",
|
||
session_id.as_deref().unwrap_or("none"),
|
||
usage.input_tokens,
|
||
usage.output_tokens,
|
||
usage.cache_read_tokens,
|
||
usage.cache_creation_tokens
|
||
);
|
||
|
||
if let Err(e) = logger.log_with_calculation(
|
||
request_id,
|
||
provider_id.to_string(),
|
||
app_type.to_string(),
|
||
model.to_string(),
|
||
request_model.to_string(),
|
||
pricing_model.to_string(),
|
||
usage,
|
||
multiplier,
|
||
latency_ms,
|
||
first_token_ms,
|
||
status_code,
|
||
session_id,
|
||
None, // provider_type
|
||
is_streaming,
|
||
) {
|
||
log::warn!("[USG-001] 记录使用量失败: {e}");
|
||
}
|
||
}
|
||
|
||
/// 创建带日志记录和超时控制的透传流
|
||
pub fn create_logged_passthrough_stream(
|
||
stream: impl Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
|
||
tag: &'static str,
|
||
usage_collector: Option<SseUsageCollector>,
|
||
timeout_config: StreamingTimeoutConfig,
|
||
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
|
||
async_stream::stream! {
|
||
let mut buffer = String::new();
|
||
let mut utf8_remainder: Vec<u8> = Vec::new();
|
||
let mut collector = usage_collector;
|
||
let mut is_first_chunk = true;
|
||
|
||
// 超时配置
|
||
let first_byte_timeout = if timeout_config.first_byte_timeout > 0 {
|
||
Some(Duration::from_secs(timeout_config.first_byte_timeout))
|
||
} else {
|
||
None
|
||
};
|
||
let idle_timeout = if timeout_config.idle_timeout > 0 {
|
||
Some(Duration::from_secs(timeout_config.idle_timeout))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
tokio::pin!(stream);
|
||
|
||
loop {
|
||
// 选择超时时间:首字节超时或静默期超时
|
||
let timeout_duration = if is_first_chunk {
|
||
first_byte_timeout
|
||
} else {
|
||
idle_timeout
|
||
};
|
||
|
||
let chunk_result = match timeout_duration {
|
||
Some(duration) => {
|
||
match tokio::time::timeout(duration, stream.next()).await {
|
||
Ok(Some(chunk)) => Some(chunk),
|
||
Ok(None) => None, // 流结束
|
||
Err(_) => {
|
||
// 超时
|
||
let timeout_type = if is_first_chunk { "首字节" } else { "静默期" };
|
||
log::error!("[{tag}] 流式响应{}超时 ({}秒)", timeout_type, duration.as_secs());
|
||
yield Err(std::io::Error::other(format!("流式响应{timeout_type}超时")));
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
None => stream.next().await, // 无超时限制
|
||
};
|
||
|
||
match chunk_result {
|
||
Some(Ok(bytes)) => {
|
||
if is_first_chunk {
|
||
log::debug!(
|
||
"[{tag}] 已接收上游流式首包: bytes={}",
|
||
bytes.len()
|
||
);
|
||
}
|
||
is_first_chunk = false;
|
||
crate::proxy::sse::append_utf8_safe(&mut buffer, &mut utf8_remainder, &bytes);
|
||
|
||
// 尝试解析并记录完整的 SSE 事件
|
||
while let Some(pos) = buffer.find("\n\n") {
|
||
let event_text = buffer[..pos].to_string();
|
||
buffer = buffer[pos + 2..].to_string();
|
||
|
||
if !event_text.trim().is_empty() {
|
||
// 提取 data 部分并尝试解析为 JSON
|
||
for line in event_text.lines() {
|
||
if let Some(data) = strip_sse_field(line, "data") {
|
||
if data.trim() != "[DONE]" {
|
||
if let Ok(json_value) = serde_json::from_str::<Value>(data) {
|
||
if let Some(c) = &collector {
|
||
c.push(json_value.clone()).await;
|
||
}
|
||
log::debug!("[{tag}] <<< SSE 事件: {data}");
|
||
} else {
|
||
log::debug!("[{tag}] <<< SSE 数据: {data}");
|
||
}
|
||
} else {
|
||
log::debug!("[{tag}] <<< SSE: [DONE]");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
yield Ok(bytes);
|
||
}
|
||
Some(Err(e)) => {
|
||
log::error!("[{tag}] 流错误: {e}");
|
||
yield Err(std::io::Error::other(e.to_string()));
|
||
break;
|
||
}
|
||
None => {
|
||
// 流正常结束
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(c) = collector.take() {
|
||
c.finish().await;
|
||
}
|
||
}
|
||
}
|
||
|
||
fn format_headers(headers: &HeaderMap) -> String {
|
||
headers
|
||
.iter()
|
||
.map(|(key, value)| {
|
||
let value_str = value.to_str().unwrap_or("<non-utf8>");
|
||
format!("{key}={value_str}")
|
||
})
|
||
.collect::<Vec<_>>()
|
||
.join(", ")
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::Database;
    use crate::error::AppError;
    use crate::provider::ProviderMeta;
    use crate::proxy::failover_switch::FailoverSwitchManager;
    use crate::proxy::provider_router::ProviderRouter;
    use crate::proxy::types::{ProxyConfig, ProxyStatus};
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::str::FromStr;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Builds a header map from static name/value pairs; repeated names are
    /// appended rather than replaced.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.append(
                axum::http::header::HeaderName::from_static(name),
                axum::http::HeaderValue::from_static(value),
            );
        }
        headers
    }

    /// Parses a decimal literal used in assertions.
    fn dec(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    #[test]
    fn test_strip_sse_field_accepts_optional_space() {
        let cases: [(&str, &str, Option<&str>); 5] = [
            ("data: {\"ok\":true}", "data", Some("{\"ok\":true}")),
            ("data:{\"ok\":true}", "data", Some("{\"ok\":true}")),
            ("event: message_start", "event", Some("message_start")),
            ("event:message_start", "event", Some("message_start")),
            ("id:1", "data", None),
        ];

        for (line, field, expected) in cases {
            assert_eq!(super::strip_sse_field(line, field), expected);
        }
    }

    #[test]
    fn test_strip_hop_by_hop_response_headers_removes_standard_headers() {
        let mut headers = headers_from(&[
            ("connection", "keep-alive"),
            ("keep-alive", "timeout=5"),
            ("transfer-encoding", "chunked"),
            ("proxy-connection", "keep-alive"),
            ("content-type", "application/json"),
            ("content-length", "12"),
        ]);

        strip_hop_by_hop_response_headers(&mut headers);

        // Hop-by-hop headers are gone ...
        assert!(!headers.contains_key(axum::http::header::CONNECTION));
        assert!(!headers.contains_key("keep-alive"));
        assert!(!headers.contains_key(axum::http::header::TRANSFER_ENCODING));
        assert!(!headers.contains_key("proxy-connection"));

        // ... while end-to-end headers are untouched.
        assert_eq!(
            headers.get(axum::http::header::CONTENT_TYPE),
            Some(&axum::http::HeaderValue::from_static("application/json"))
        );
        assert_eq!(
            headers.get(axum::http::header::CONTENT_LENGTH),
            Some(&axum::http::HeaderValue::from_static("12"))
        );
    }

    #[test]
    fn test_strip_hop_by_hop_response_headers_removes_connection_listed_extensions() {
        let mut headers = headers_from(&[
            ("connection", "x-trace-hop, x-debug-hop"),
            ("connection", "upgrade"),
            ("x-trace-hop", "trace"),
            ("x-debug-hop", "debug"),
            ("upgrade", "websocket"),
            ("content-type", "text/event-stream"),
        ]);

        strip_hop_by_hop_response_headers(&mut headers);

        assert!(!headers.contains_key(axum::http::header::CONNECTION));
        assert!(!headers.contains_key("x-trace-hop"));
        assert!(!headers.contains_key("x-debug-hop"));
        assert!(!headers.contains_key(axum::http::header::UPGRADE));
        assert_eq!(
            headers.get(axum::http::header::CONTENT_TYPE),
            Some(&axum::http::HeaderValue::from_static("text/event-stream"))
        );
    }

    /// Creates a proxy state with default settings on top of `db`.
    fn build_state(db: Arc<Database>) -> ProxyState {
        ProxyState {
            db: db.clone(),
            config: Arc::new(RwLock::new(ProxyConfig::default())),
            status: Arc::new(RwLock::new(ProxyStatus::default())),
            start_time: Arc::new(RwLock::new(None)),
            current_providers: Arc::new(RwLock::new(HashMap::new())),
            provider_router: Arc::new(ProviderRouter::new(db.clone())),
            app_handle: None,
            failover_manager: Arc::new(FailoverSwitchManager::new(db)),
        }
    }

    /// Inserts pricing rows for `resp-model` (1.0 per million input tokens) and
    /// `req-model` (2.0 per million input tokens).
    fn seed_pricing(db: &Database) -> Result<(), AppError> {
        let conn = crate::database::lock_conn!(db.conn);
        let rows = [
            ("resp-model", "Resp Model", "1.0"),
            ("req-model", "Req Model", "2.0"),
        ];

        for (model_id, display_name, input_cost) in rows {
            conn.execute(
                "INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
                 VALUES (?1, ?2, ?3, ?4)",
                rusqlite::params![model_id, display_name, input_cost, "0"],
            )
            .map_err(|e| AppError::Database(e.to_string()))?;
        }
        Ok(())
    }

    /// Inserts a provider row carrying the given metadata.
    fn insert_provider(
        db: &Database,
        id: &str,
        app_type: &str,
        meta: ProviderMeta,
    ) -> Result<(), AppError> {
        let meta_json =
            serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?;
        let conn = crate::database::lock_conn!(db.conn);
        conn.execute(
            "INSERT INTO providers (id, app_type, name, settings_config, meta)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![id, app_type, "Test Provider", "{}", meta_json],
        )
        .map_err(|e| AppError::Database(e.to_string()))?;
        Ok(())
    }

    /// In-memory database with the global defaults shared by the usage tests:
    /// multiplier 1.5, pricing by response model, pricing rows seeded.
    async fn prepare_db(app_type: &str) -> Result<Arc<Database>, AppError> {
        let db = Arc::new(Database::memory()?);
        db.set_default_cost_multiplier(app_type, "1.5").await?;
        db.set_pricing_model_source(app_type, "response").await?;
        seed_pricing(&db)?;
        Ok(db)
    }

    /// Usage of exactly one million input tokens and nothing else.
    fn million_input_tokens() -> TokenUsage {
        TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 0,
            cache_read_tokens: 0,
            cache_creation_tokens: 0,
            model: None,
            message_id: None,
        }
    }

    /// Logs the sample non-streaming request used by the usage tests.
    async fn log_sample_request(state: &ProxyState, provider_id: &str, app_type: &str) {
        log_usage_internal(
            state,
            provider_id,
            app_type,
            "resp-model",
            "req-model",
            million_input_tokens(),
            10,
            None,
            false,
            200,
            None,
        )
        .await;
    }

    #[tokio::test]
    async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> {
        let app_type = "claude";
        let db = prepare_db(app_type).await?;

        // Provider-level settings must win over the global defaults.
        let meta = ProviderMeta {
            cost_multiplier: Some("2".to_string()),
            pricing_model_source: Some("request".to_string()),
            ..ProviderMeta::default()
        };
        insert_provider(&db, "provider-1", app_type, meta)?;

        let state = build_state(db.clone());
        log_sample_request(&state, "provider-1", app_type).await;

        let conn = crate::database::lock_conn!(db.conn);
        let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) =
            conn.query_row(
                "SELECT model, request_model, total_cost_usd, cost_multiplier
                 FROM proxy_request_logs WHERE provider_id = ?1",
                ["provider-1"],
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
            )
            .map_err(|e| AppError::Database(e.to_string()))?;

        assert_eq!(model, "resp-model");
        assert_eq!(request_model, "req-model");
        assert_eq!(dec(&cost_multiplier), dec("2"));
        // 1M input tokens priced as req-model (2.0) times multiplier 2.
        assert_eq!(dec(&total_cost), dec("4"));
        Ok(())
    }

    #[tokio::test]
    async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> {
        let app_type = "claude";
        let db = prepare_db(app_type).await?;

        insert_provider(&db, "provider-2", app_type, ProviderMeta::default())?;

        let state = build_state(db.clone());
        log_sample_request(&state, "provider-2", app_type).await;

        let conn = crate::database::lock_conn!(db.conn);
        let (total_cost, cost_multiplier): (String, String) = conn
            .query_row(
                "SELECT total_cost_usd, cost_multiplier
                 FROM proxy_request_logs WHERE provider_id = ?1",
                ["provider-2"],
                |row| Ok((row.get(0)?, row.get(1)?)),
            )
            .map_err(|e| AppError::Database(e.to_string()))?;

        assert_eq!(dec(&cost_multiplier), dec("1.5"));
        // 1M input tokens priced as resp-model (1.0) times multiplier 1.5.
        assert_eq!(dec(&total_cost), dec("1.5"));
        Ok(())
    }
}
|