Files
cc-switch/src-tauri/src/proxy/response_processor.rs
T

1009 lines
33 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! 响应处理器模块
//!
//! 统一处理流式和非流式 API 响应
use super::{
handler_config::UsageParserConfig,
handler_context::{RequestContext, StreamingTimeoutConfig},
hyper_client::ProxyResponse,
server::ProxyState,
sse::strip_sse_field,
usage::parser::TokenUsage,
ProxyError,
};
use axum::http::{header::HeaderMap, HeaderName};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use serde_json::Value;
use std::{
io::Read,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::Mutex;
// ============================================================================
// 响应解压
// ============================================================================
/// Decompresses a response body according to its `Content-Encoding` value.
///
/// Automatic decompression in the HTTP client is disabled (so that the caller's
/// `accept-encoding` can be passed through to the upstream), which means the
/// proxy has to decode bodies itself.
///
/// Supported codings: `gzip` / `x-gzip`, `deflate` and `br`. Any other value is
/// logged and the body is returned unchanged.
///
/// # Errors
///
/// Returns the underlying I/O error when the body is not valid data for the
/// announced coding.
fn decompress_body(content_encoding: &str, body: &[u8]) -> Result<Vec<u8>, std::io::Error> {
    match content_encoding {
        "gzip" | "x-gzip" => {
            let mut decompressed = Vec::new();
            flate2::read::GzDecoder::new(body).read_to_end(&mut decompressed)?;
            Ok(decompressed)
        }
        "deflate" => {
            // The HTTP "deflate" coding is a zlib-wrapped stream (RFC 9110 §8.4.1.2),
            // so try zlib first. Some servers send a bare deflate stream instead, so
            // fall back to the raw decoder when the zlib framing is not present.
            let mut decompressed = Vec::new();
            if flate2::read::ZlibDecoder::new(body)
                .read_to_end(&mut decompressed)
                .is_ok()
            {
                return Ok(decompressed);
            }
            // Discard any partial output produced by the failed zlib attempt.
            decompressed.clear();
            flate2::read::DeflateDecoder::new(body).read_to_end(&mut decompressed)?;
            Ok(decompressed)
        }
        "br" => {
            let mut decompressed = Vec::new();
            brotli::BrotliDecompress(&mut std::io::Cursor::new(body), &mut decompressed)?;
            Ok(decompressed)
        }
        _ => {
            log::warn!("未知的 content-encoding: {content_encoding},跳过解压");
            Ok(body.to_vec())
        }
    }
}
/// Extracts the normalized `content-encoding` of a response.
///
/// The value is trimmed and lower-cased. Returns `None` when the header is
/// absent, is not representable as a string, is empty, or is `identity`.
fn get_content_encoding(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get("content-encoding")?.to_str().ok()?;
    let normalized = raw.trim().to_lowercase();
    if normalized.is_empty() || normalized == "identity" {
        return None;
    }
    Some(normalized)
}
/// Response headers that, per RFC 2616 / RFC 7230, a proxy should not forward
/// any further (hop-by-hop headers). Names are lower-case.
const HOP_BY_HOP_RESPONSE_HEADERS: &[&str] = &[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"proxy-connection",
"te",
"trailer",
"trailers",
"transfer-encoding",
"upgrade",
];
/// Removes hop-by-hop headers from a response header map.
///
/// Besides the fixed list in `HOP_BY_HOP_RESPONSE_HEADERS`, every header named
/// inside a `Connection` header value (comma separated) is removed as well.
pub(crate) fn strip_hop_by_hop_response_headers(headers: &mut HeaderMap) {
    // Gather the extension names first: the `Connection` header itself is
    // removed by the fixed list below.
    let mut extension_names: Vec<HeaderName> = Vec::new();
    for connection_value in headers.get_all(axum::http::header::CONNECTION).iter() {
        let Ok(listed) = connection_value.to_str() else {
            continue;
        };
        for token in listed.split(',') {
            let token = token.trim();
            if token.is_empty() {
                continue;
            }
            // Tokens that are not valid header names are silently ignored.
            if let Ok(parsed) = HeaderName::from_bytes(token.as_bytes()) {
                extension_names.push(parsed);
            }
        }
    }

    for standard in HOP_BY_HOP_RESPONSE_HEADERS.iter().copied() {
        headers.remove(standard);
    }
    for extension in extension_names {
        headers.remove(extension);
    }
}
/// Removes the entity headers that no longer describe the body once the body
/// has been rebuilt: `Content-Encoding`, `Content-Length` and
/// `Transfer-Encoding`.
pub(crate) fn strip_entity_headers_for_rebuilt_body(headers: &mut HeaderMap) {
    use axum::http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};

    for stale in [CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING] {
        headers.remove(stale);
    }
}
/// Reads the whole response body and decompresses it when possible, keeping
/// the returned headers consistent with the returned body.
///
/// * `tag` - label used as a prefix in log messages.
/// * `body_timeout` - overall timeout for reading the body. When non-zero the
///   read is wrapped in `tokio::time::timeout`, so an upstream that sends the
///   response headers and then stalls cannot hang the request forever.
///   `Duration::ZERO` disables the timeout.
///
/// Returns the (possibly adjusted) headers, the status code and the body.
/// `Content-Encoding`, `Content-Length` and `Transfer-Encoding` are removed
/// only when the body was actually decoded. If the coding is not supported, or
/// decoding fails, the original bytes are returned together with the original
/// headers so the client can still decode the body itself.
///
/// # Errors
///
/// Returns `ProxyError::Timeout` when the body does not arrive within
/// `body_timeout`, or the error produced while reading the body.
pub(crate) async fn read_decoded_body(
    response: ProxyResponse,
    tag: &str,
    body_timeout: Duration,
) -> Result<(HeaderMap, http::StatusCode, Bytes), ProxyError> {
    let mut headers = response.headers().clone();
    let status = response.status();
    let raw_bytes = if body_timeout.is_zero() {
        response.bytes().await?
    } else {
        tokio::time::timeout(body_timeout, response.bytes())
            .await
            .map_err(|_| {
                ProxyError::Timeout(format!(
                    "响应体读取超时: {}s(上游发完响应头后 body 未到达)",
                    body_timeout.as_secs()
                ))
            })??
    };
    log::debug!(
        "[{tag}] 已接收上游响应体: status={}, bytes={}, headers={}",
        status.as_u16(),
        raw_bytes.len(),
        format_headers(&headers)
    );

    let Some(encoding) = get_content_encoding(&headers) else {
        return Ok((headers, status, raw_bytes));
    };

    // `decompress_body` hands back the input unchanged for codings it does not
    // know. Treating that as "decoded" would strip `Content-Encoding` from a
    // body that is still encoded, so unsupported codings are passed through
    // untouched instead.
    let supported = matches!(encoding.as_str(), "gzip" | "x-gzip" | "deflate" | "br");
    if !supported {
        log::warn!("[{tag}] 未知的 content-encoding: {encoding},原样透传响应体并保留实体头");
        return Ok((headers, status, raw_bytes));
    }

    log::debug!("[{tag}] 解压非流式响应: content-encoding={encoding}");
    match decompress_body(&encoding, &raw_bytes) {
        Ok(decompressed) => {
            strip_entity_headers_for_rebuilt_body(&mut headers);
            Ok((headers, status, Bytes::from(decompressed)))
        }
        Err(e) => {
            log::warn!("[{tag}] 解压失败 ({encoding}): {e},使用原始数据");
            Ok((headers, status, raw_bytes))
        }
    }
}
// ============================================================================
// 公共接口
// ============================================================================
/// Reports whether the upstream response is an SSE (server-sent events)
/// streaming response.
///
/// Delegates to `ProxyResponse::is_sse`.
#[inline]
pub fn is_sse_response(response: &ProxyResponse) -> bool {
response.is_sse()
}
/// Turns a streaming upstream response into a streaming client response.
///
/// The upstream status and headers (minus hop-by-hop headers) are copied, and
/// the body is passed through a logging stream that applies the streaming
/// timeouts from `ctx` and feeds parsed SSE events into a usage collector.
/// If the response cannot be built, an internal-error response is returned.
pub async fn handle_streaming(
    response: ProxyResponse,
    ctx: &RequestContext,
    state: &ProxyState,
    parser_config: &UsageParserConfig,
) -> Response {
    let status = response.status();
    log::debug!(
        "[{}] 已接收上游流式响应: status={}, headers={}",
        ctx.tag,
        status.as_u16(),
        format_headers(response.headers())
    );

    // SSE is normally sent uncompressed; a compressed stream is passed through
    // as-is, but the event parsing below will not understand it.
    if let Some(encoding) = get_content_encoding(response.headers()) {
        log::warn!(
            "[{}] 流式响应含 content-encoding={encoding}SSE 解析可能失败。\
             上游在 accept-encoding 透传后压缩了 SSE 流。",
            ctx.tag
        );
    }

    // Copy the upstream headers onto the outgoing response.
    let mut forwarded_headers = response.headers().clone();
    strip_hop_by_hop_response_headers(&mut forwarded_headers);
    let builder = forwarded_headers.iter().fold(
        axum::response::Response::builder().status(status),
        |partial, (name, value)| partial.header(name, value),
    );

    // Wire the upstream byte stream through the logging / timeout wrapper.
    let upstream_bytes = response.bytes_stream();
    let usage_collector = create_usage_collector(ctx, state, status.as_u16(), parser_config);
    let timeout_config = ctx.streaming_timeout_config();
    let logged_stream = create_logged_passthrough_stream(
        upstream_bytes,
        ctx.tag,
        Some(usage_collector),
        timeout_config,
    );

    builder
        .body(axum::body::Body::from_stream(logged_stream))
        .unwrap_or_else(|e| {
            log::error!("[{}] 构建流式响应失败: {e}", ctx.tag);
            ProxyError::Internal(format!("Failed to build streaming response: {e}"))
                .into_response()
        })
}
/// 处理非流式响应
pub async fn handle_non_streaming(
response: ProxyResponse,
ctx: &RequestContext,
state: &ProxyState,
parser_config: &UsageParserConfig,
) -> Result<Response, ProxyError> {
// 整包超时:仅在故障转移开启且配置值非零时生效
let body_timeout =
if ctx.app_config.auto_failover_enabled && ctx.app_config.non_streaming_timeout > 0 {
Duration::from_secs(ctx.app_config.non_streaming_timeout as u64)
} else {
Duration::ZERO
};
let (mut response_headers, status, body_bytes) =
read_decoded_body(response, ctx.tag, body_timeout).await?;
strip_hop_by_hop_response_headers(&mut response_headers);
log::debug!(
"[{}] 上游响应体内容: {}",
ctx.tag,
String::from_utf8_lossy(&body_bytes)
);
// 解析并记录使用量
if let Ok(json_value) = serde_json::from_slice::<Value>(&body_bytes) {
// 解析使用量
if let Some(usage) = (parser_config.response_parser)(&json_value) {
// 优先使用 usage 中解析出的模型名称,其次使用响应中的 model 字段,最后回退到请求模型
let model = if let Some(ref m) = usage.model {
m.clone()
} else if let Some(m) = json_value.get("model").and_then(|m| m.as_str()) {
m.to_string()
} else {
ctx.request_model.clone()
};
spawn_log_usage(
state,
ctx,
usage,
&model,
&ctx.request_model,
status.as_u16(),
false,
);
} else {
let model = json_value
.get("model")
.and_then(|m| m.as_str())
.unwrap_or(&ctx.request_model)
.to_string();
spawn_log_usage(
state,
ctx,
TokenUsage::default(),
&model,
&ctx.request_model,
status.as_u16(),
false,
);
log::debug!(
"[{}] 未能解析 usage 信息,跳过记录",
parser_config.app_type_str
);
}
} else {
log::debug!(
"[{}] <<< 响应 (非 JSON): {} bytes",
ctx.tag,
body_bytes.len()
);
spawn_log_usage(
state,
ctx,
TokenUsage::default(),
&ctx.request_model,
&ctx.request_model,
status.as_u16(),
false,
);
}
// 构建响应
let mut builder = axum::response::Response::builder().status(status);
for (key, value) in response_headers.iter() {
builder = builder.header(key, value);
}
let body = axum::body::Body::from(body_bytes);
builder.body(body).map_err(|e| {
log::error!("[{}] 构建响应失败: {e}", ctx.tag);
ProxyError::Internal(format!("Failed to build response: {e}"))
})
}
/// Generic entry point for response handling.
///
/// Dispatches to the streaming handler for SSE responses and to the
/// non-streaming handler for everything else.
pub async fn process_response(
    response: ProxyResponse,
    ctx: &RequestContext,
    state: &ProxyState,
    parser_config: &UsageParserConfig,
) -> Result<Response, ProxyError> {
    if !is_sse_response(&response) {
        return handle_non_streaming(response, ctx, state, parser_config).await;
    }
    let streamed = handle_streaming(response, ctx, state, parser_config).await;
    Ok(streamed)
}
// ============================================================================
// SSE 使用量收集器
// ============================================================================
/// Completion callback of the collector: receives the collected SSE events and,
/// if at least one event was seen, the time to the first event in milliseconds.
type UsageCallbackWithTiming = Arc<dyn Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static>;
/// Collects the SSE events of a streaming response for usage accounting.
///
/// Clones share the same inner state through an `Arc`.
#[derive(Clone)]
pub struct SseUsageCollector {
inner: Arc<SseUsageCollectorInner>,
}
/// Shared state behind [`SseUsageCollector`].
struct SseUsageCollectorInner {
// JSON events pushed so far.
events: Mutex<Vec<Value>>,
// Instant at which the first event was pushed; `None` until then.
first_event_time: Mutex<Option<std::time::Instant>>,
// Reference instant the first-event latency is measured from.
start_time: std::time::Instant,
// Invoked by `finish` with the collected events and the first-event latency.
on_complete: UsageCallbackWithTiming,
// Set by the first `finish` call so the callback runs at most once.
finished: AtomicBool,
}
impl SseUsageCollector {
    /// Creates a collector.
    ///
    /// `start_time` is the instant the first-event latency is measured from;
    /// `callback` is invoked once by [`finish`](Self::finish) with the collected
    /// events and the first-event latency in milliseconds (if any event arrived).
    pub fn new(
        start_time: std::time::Instant,
        callback: impl Fn(Vec<Value>, Option<u64>) + Send + Sync + 'static,
    ) -> Self {
        let inner = SseUsageCollectorInner {
            events: Mutex::new(Vec::new()),
            first_event_time: Mutex::new(None),
            start_time,
            on_complete: Arc::new(callback),
            finished: AtomicBool::new(false),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Appends one SSE event, remembering when the first event arrived.
    pub async fn push(&self, event: Value) {
        // Stamp the arrival time of the very first event only.
        let mut first_seen = self.inner.first_event_time.lock().await;
        first_seen.get_or_insert_with(std::time::Instant::now);
        drop(first_seen);

        self.inner.events.lock().await.push(event);
    }

    /// Ends the collection and invokes the completion callback.
    ///
    /// Only the first call has an effect; later calls return immediately.
    pub async fn finish(&self) {
        let already_finished = self.inner.finished.swap(true, Ordering::SeqCst);
        if already_finished {
            return;
        }

        let collected = std::mem::take(&mut *self.inner.events.lock().await);
        let first_token_ms = {
            let first_seen = self.inner.first_event_time.lock().await;
            first_seen.map(|seen| (seen - self.inner.start_time).as_millis() as u64)
        };
        (self.inner.on_complete)(collected, first_token_ms);
    }
}
// ============================================================================
// 内部辅助函数
// ============================================================================
/// Builds the usage collector for a streaming response.
///
/// The returned collector's completion callback parses usage and model from
/// the collected SSE events and spawns a task that writes the usage log. When
/// no usage can be parsed, a log entry with default (empty) token counts is
/// written instead. Nothing is logged when logging is disabled; the setting is
/// sampled once, when the collector is created, and treated as enabled if the
/// configuration lock cannot be taken at that moment.
fn create_usage_collector(
    ctx: &RequestContext,
    state: &ProxyState,
    status_code: u16,
    parser_config: &UsageParserConfig,
) -> SseUsageCollector {
    let logging_enabled = state
        .config
        .try_read()
        .map(|c| c.enable_logging)
        .unwrap_or(true);

    // Everything the callback needs is captured by value.
    let state = state.clone();
    let provider_id = ctx.provider.id.clone();
    let request_model = ctx.request_model.clone();
    let session_id = ctx.session_id.clone();
    let app_type_str = parser_config.app_type_str;
    let stream_parser = parser_config.stream_parser;
    let model_extractor = parser_config.model_extractor;
    let tag = ctx.tag;
    let start_time = ctx.start_time;

    SseUsageCollector::new(start_time, move |events, first_token_ms| {
        if !logging_enabled {
            return;
        }

        let parsed_usage = stream_parser(&events);
        let usage_missing = parsed_usage.is_none();
        let usage = parsed_usage.unwrap_or_else(TokenUsage::default);
        let model = model_extractor(&events, &request_model);
        let latency_ms = start_time.elapsed().as_millis() as u64;

        // The callback may not borrow from its captures inside the spawned
        // task, so the task gets its own copies.
        let state = state.clone();
        let provider_id = provider_id.clone();
        let session_id = session_id.clone();
        let request_model = request_model.clone();
        tokio::spawn(async move {
            log_usage_internal(
                &state,
                &provider_id,
                app_type_str,
                &model,
                &request_model,
                usage,
                latency_ms,
                first_token_ms,
                true, // is_streaming
                status_code,
                Some(session_id),
            )
            .await;
        });

        if usage_missing {
            log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录");
        }
    })
}
/// Records usage asynchronously by spawning a task that writes the usage log.
///
/// Does nothing when logging is disabled in the configuration. If the
/// configuration lock cannot be taken right now, logging proceeds.
fn spawn_log_usage(
    state: &ProxyState,
    ctx: &RequestContext,
    usage: TokenUsage,
    model: &str,
    request_model: &str,
    status_code: u16,
    is_streaming: bool,
) {
    // Check enable_logging before spawning the log task.
    let logging_disabled = state
        .config
        .try_read()
        .map(|cfg| !cfg.enable_logging)
        .unwrap_or(false);
    if logging_disabled {
        return;
    }

    // Owned copies for the spawned task.
    let latency_ms = ctx.latency_ms();
    let session_id = ctx.session_id.clone();
    let provider_id = ctx.provider.id.clone();
    let app_type_str = ctx.app_type_str.to_string();
    let request_model = request_model.to_string();
    let model = model.to_string();
    let state = state.clone();

    tokio::spawn(async move {
        log_usage_internal(
            &state,
            &provider_id,
            &app_type_str,
            &model,
            &request_model,
            usage,
            latency_ms,
            None,
            is_streaming,
            status_code,
            Some(session_id),
        )
        .await;
    });
}
/// Writes one usage log entry.
///
/// Resolves the cost multiplier and the pricing-model source for the provider,
/// picks the model used for pricing (the request model when the source is
/// `"request"`, otherwise the response model) and stores the entry. A failure
/// to store the entry is logged as a warning and otherwise ignored.
#[allow(clippy::too_many_arguments)]
async fn log_usage_internal(
    state: &ProxyState,
    provider_id: &str,
    app_type: &str,
    model: &str,
    request_model: &str,
    usage: TokenUsage,
    latency_ms: u64,
    first_token_ms: Option<u64>,
    is_streaming: bool,
    status_code: u16,
    session_id: Option<String>,
) {
    use super::usage::logger::UsageLogger;

    let logger = UsageLogger::new(&state.db);
    let (multiplier, pricing_model_source) =
        logger.resolve_pricing_config(provider_id, app_type).await;
    let pricing_model = if pricing_model_source != "request" {
        model
    } else {
        request_model
    };

    let request_id = usage.dedup_request_id();
    let session_label = session_id.as_deref().unwrap_or("none");
    log::debug!(
        "[{app_type}] 记录请求日志: id={request_id}, provider={provider_id}, model={model}, streaming={is_streaming}, status={status_code}, latency_ms={latency_ms}, first_token_ms={first_token_ms:?}, session={}, input={}, output={}, cache_read={}, cache_creation={}",
        session_label,
        usage.input_tokens,
        usage.output_tokens,
        usage.cache_read_tokens,
        usage.cache_creation_tokens
    );

    let outcome = logger.log_with_calculation(
        request_id,
        provider_id.to_string(),
        app_type.to_string(),
        model.to_string(),
        request_model.to_string(),
        pricing_model.to_string(),
        usage,
        multiplier,
        latency_ms,
        first_token_ms,
        status_code,
        session_id,
        None, // provider_type
        is_streaming,
    );
    if let Err(e) = outcome {
        log::warn!("[USG-001] 记录使用量失败: {e}");
    }
}
/// Wraps an upstream byte stream in a pass-through stream that logs SSE events
/// and enforces streaming timeouts.
///
/// Every chunk is forwarded unchanged. On the side, the chunks are accumulated
/// as text and split into SSE events at blank lines (`"\n\n"`). Each `data`
/// field that parses as JSON is pushed to `usage_collector`; `[DONE]` markers
/// and non-JSON payloads are only logged.
///
/// * `tag` - label used as a prefix in log messages.
/// * `usage_collector` - optional collector; it is finished when the stream
///   ends, fails or times out.
/// * `timeout_config` - `first_byte_timeout` applies while waiting for the
///   first chunk, `idle_timeout` between later chunks; a value of `0` disables
///   the respective timeout.
///
/// A timeout or an upstream error is yielded as an `Err` item, after which the
/// stream ends.
pub fn create_logged_passthrough_stream(
    stream: impl Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
    tag: &'static str,
    usage_collector: Option<SseUsageCollector>,
    timeout_config: StreamingTimeoutConfig,
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
    async_stream::stream! {
        let mut buffer = String::new();
        let mut utf8_remainder: Vec<u8> = Vec::new();
        let mut collector = usage_collector;
        let mut is_first_chunk = true;

        // Timeout configuration: 0 means "no limit".
        let first_byte_timeout = if timeout_config.first_byte_timeout > 0 {
            Some(Duration::from_secs(timeout_config.first_byte_timeout))
        } else {
            None
        };
        let idle_timeout = if timeout_config.idle_timeout > 0 {
            Some(Duration::from_secs(timeout_config.idle_timeout))
        } else {
            None
        };

        tokio::pin!(stream);
        loop {
            // First-byte timeout before the first chunk, idle timeout afterwards.
            let timeout_duration = if is_first_chunk {
                first_byte_timeout
            } else {
                idle_timeout
            };
            let chunk_result = match timeout_duration {
                Some(duration) => {
                    match tokio::time::timeout(duration, stream.next()).await {
                        Ok(Some(chunk)) => Some(chunk),
                        Ok(None) => None, // end of stream
                        Err(_) => {
                            // Timed out waiting for the next chunk.
                            let timeout_type = if is_first_chunk { "首字节" } else { "静默期" };
                            log::error!("[{tag}] 流式响应{}超时 ({}秒)", timeout_type, duration.as_secs());
                            yield Err(std::io::Error::other(format!("流式响应{timeout_type}超时")));
                            break;
                        }
                    }
                }
                None => stream.next().await, // no timeout configured
            };

            match chunk_result {
                Some(Ok(bytes)) => {
                    if is_first_chunk {
                        log::debug!(
                            "[{tag}] 已接收上游流式首包: bytes={}",
                            bytes.len()
                        );
                    }
                    is_first_chunk = false;
                    crate::proxy::sse::append_utf8_safe(&mut buffer, &mut utf8_remainder, &bytes);

                    // Parse and log every complete SSE event currently buffered.
                    while let Some(pos) = buffer.find("\n\n") {
                        let event_text = buffer[..pos].to_string();
                        // Drop the consumed event and its separator in place
                        // instead of re-allocating the remaining buffer.
                        buffer.drain(..pos + 2);

                        if !event_text.trim().is_empty() {
                            // Extract the data fields and try to parse them as JSON.
                            for line in event_text.lines() {
                                if let Some(data) = strip_sse_field(line, "data") {
                                    if data.trim() != "[DONE]" {
                                        if let Ok(json_value) = serde_json::from_str::<Value>(data) {
                                            if let Some(c) = &collector {
                                                // The value is not needed afterwards,
                                                // so it is moved rather than cloned.
                                                c.push(json_value).await;
                                            }
                                            log::debug!("[{tag}] <<< SSE 事件: {data}");
                                        } else {
                                            log::debug!("[{tag}] <<< SSE 数据: {data}");
                                        }
                                    } else {
                                        log::debug!("[{tag}] <<< SSE: [DONE]");
                                    }
                                }
                            }
                        }
                    }
                    yield Ok(bytes);
                }
                Some(Err(e)) => {
                    log::error!("[{tag}] 流错误: {e}");
                    yield Err(std::io::Error::other(e.to_string()));
                    break;
                }
                None => {
                    // The stream ended normally.
                    break;
                }
            }
        }

        // Reached after normal end, upstream error and timeout alike.
        if let Some(c) = collector.take() {
            c.finish().await;
        }
    }
}
/// Renders headers as `name=value` pairs joined by `", "` for log output.
///
/// Values that cannot be shown as a string are rendered as `<non-utf8>`.
fn format_headers(headers: &HeaderMap) -> String {
    let mut rendered: Vec<String> = Vec::new();
    for (name, value) in headers {
        let shown = match value.to_str() {
            Ok(text) => text,
            Err(_) => "<non-utf8>",
        };
        rendered.push(format!("{name}={shown}"));
    }
    rendered.join(", ")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::database::Database;
use crate::error::AppError;
use crate::provider::ProviderMeta;
use crate::proxy::failover_switch::FailoverSwitchManager;
use crate::proxy::provider_router::ProviderRouter;
use crate::proxy::types::{ProxyConfig, ProxyStatus};
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
/// `strip_sse_field` accepts the field name with or without a space after the
/// colon and rejects lines that carry a different field.
#[test]
fn test_strip_sse_field_accepts_optional_space() {
    // (line, field, expected value)
    let cases: [(&str, &str, Option<&str>); 5] = [
        ("data: {\"ok\":true}", "data", Some("{\"ok\":true}")),
        ("data:{\"ok\":true}", "data", Some("{\"ok\":true}")),
        ("event: message_start", "event", Some("message_start")),
        ("event:message_start", "event", Some("message_start")),
        ("id:1", "data", None),
    ];
    for (line, field, expected) in cases {
        assert_eq!(super::strip_sse_field(line, field), expected);
    }
}
/// The standard hop-by-hop headers are removed while end-to-end headers such
/// as `Content-Type` and `Content-Length` are kept.
#[test]
fn test_strip_hop_by_hop_response_headers_removes_standard_headers() {
    use axum::http::header::{
        HeaderName, CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING,
    };
    use axum::http::HeaderValue;

    let fixtures = [
        (CONNECTION, "keep-alive"),
        (HeaderName::from_static("keep-alive"), "timeout=5"),
        (TRANSFER_ENCODING, "chunked"),
        (HeaderName::from_static("proxy-connection"), "keep-alive"),
        (CONTENT_TYPE, "application/json"),
        (CONTENT_LENGTH, "12"),
    ];
    let mut headers = HeaderMap::new();
    for (name, value) in fixtures {
        headers.insert(name, HeaderValue::from_static(value));
    }

    strip_hop_by_hop_response_headers(&mut headers);

    // Hop-by-hop headers are gone.
    assert!(!headers.contains_key(CONNECTION));
    assert!(!headers.contains_key("keep-alive"));
    assert!(!headers.contains_key(TRANSFER_ENCODING));
    assert!(!headers.contains_key("proxy-connection"));

    // End-to-end headers survive unchanged.
    assert_eq!(
        headers.get(CONTENT_TYPE),
        Some(&HeaderValue::from_static("application/json"))
    );
    assert_eq!(
        headers.get(CONTENT_LENGTH),
        Some(&HeaderValue::from_static("12"))
    );
}
/// Headers named inside `Connection` values (across several `Connection`
/// entries) are removed together with `Connection` itself.
#[test]
fn test_strip_hop_by_hop_response_headers_removes_connection_listed_extensions() {
    use axum::http::header::{HeaderName, CONNECTION, CONTENT_TYPE, UPGRADE};
    use axum::http::HeaderValue;

    let mut headers = HeaderMap::new();
    // Two separate Connection entries, one of them listing two names.
    for listed in ["x-trace-hop, x-debug-hop", "upgrade"] {
        headers.append(CONNECTION, HeaderValue::from_static(listed));
    }
    let fixtures = [
        (HeaderName::from_static("x-trace-hop"), "trace"),
        (HeaderName::from_static("x-debug-hop"), "debug"),
        (UPGRADE, "websocket"),
        (CONTENT_TYPE, "text/event-stream"),
    ];
    for (name, value) in fixtures {
        headers.insert(name, HeaderValue::from_static(value));
    }

    strip_hop_by_hop_response_headers(&mut headers);

    assert!(!headers.contains_key(CONNECTION));
    assert!(!headers.contains_key("x-trace-hop"));
    assert!(!headers.contains_key("x-debug-hop"));
    assert!(!headers.contains_key(UPGRADE));
    assert_eq!(
        headers.get(CONTENT_TYPE),
        Some(&HeaderValue::from_static("text/event-stream"))
    );
}
/// Builds a `ProxyState` on top of `db` with default configuration and status,
/// no start time, no current providers and no app handle.
fn build_state(db: Arc<Database>) -> ProxyState {
    let provider_router = Arc::new(ProviderRouter::new(Arc::clone(&db)));
    let failover_manager = Arc::new(FailoverSwitchManager::new(Arc::clone(&db)));
    ProxyState {
        db,
        config: Arc::new(RwLock::new(ProxyConfig::default())),
        status: Arc::new(RwLock::new(ProxyStatus::default())),
        start_time: Arc::new(RwLock::new(None)),
        current_providers: Arc::new(RwLock::new(HashMap::new())),
        provider_router,
        app_handle: None,
        failover_manager,
    }
}
/// Inserts (or replaces) two pricing rows used by the tests:
/// `resp-model` at 1.0 per million input tokens and `req-model` at 2.0 per
/// million input tokens, both with an output cost of 0.
fn seed_pricing(db: &Database) -> Result<(), AppError> {
let conn = crate::database::lock_conn!(db.conn);
// Pricing for the model name reported by the response.
conn.execute(
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
VALUES (?1, ?2, ?3, ?4)",
rusqlite::params!["resp-model", "Resp Model", "1.0", "0"],
)
.map_err(|e| AppError::Database(e.to_string()))?;
// Pricing for the model name sent in the request.
conn.execute(
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
VALUES (?1, ?2, ?3, ?4)",
rusqlite::params!["req-model", "Req Model", "2.0", "0"],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
/// Inserts a provider row named "Test Provider" with an empty settings config
/// and `meta` serialized as JSON.
///
/// Serialization and database failures are both mapped to `AppError::Database`.
fn insert_provider(
db: &Database,
id: &str,
app_type: &str,
meta: ProviderMeta,
) -> Result<(), AppError> {
let meta_json =
serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?;
let conn = crate::database::lock_conn!(db.conn);
conn.execute(
"INSERT INTO providers (id, app_type, name, settings_config, meta)
VALUES (?1, ?2, ?3, ?4, ?5)",
rusqlite::params![id, app_type, "Test Provider", "{}", meta_json],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
/// Provider-level settings (multiplier 2, pricing source "request") take
/// precedence over the app-wide defaults (multiplier 1.5, source "response").
#[tokio::test]
async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> {
let db = Arc::new(Database::memory()?);
let app_type = "claude";
// App-wide defaults that the provider is expected to override.
db.set_default_cost_multiplier(app_type, "1.5").await?;
db.set_pricing_model_source(app_type, "response").await?;
seed_pricing(&db)?;
let mut meta = ProviderMeta::default();
meta.cost_multiplier = Some("2".to_string());
meta.pricing_model_source = Some("request".to_string());
insert_provider(&db, "provider-1", app_type, meta)?;
let state = build_state(db.clone());
// One million input tokens keeps the expected cost easy to read.
let usage = TokenUsage {
input_tokens: 1_000_000,
output_tokens: 0,
cache_read_tokens: 0,
cache_creation_tokens: 0,
model: None,
message_id: None,
};
log_usage_internal(
&state,
"provider-1",
app_type,
"resp-model",
"req-model",
usage,
10,
None,
false,
200,
None,
)
.await;
let conn = crate::database::lock_conn!(db.conn);
let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) =
conn.query_row(
"SELECT model, request_model, total_cost_usd, cost_multiplier
FROM proxy_request_logs WHERE provider_id = ?1",
["provider-1"],
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
)
.map_err(|e| AppError::Database(e.to_string()))?;
assert_eq!(model, "resp-model");
assert_eq!(request_model, "req-model");
assert_eq!(
Decimal::from_str(&cost_multiplier).unwrap(),
Decimal::from_str("2").unwrap()
);
// Priced with req-model (2.0 per million input tokens) times the multiplier 2.
assert_eq!(
Decimal::from_str(&total_cost).unwrap(),
Decimal::from_str("4").unwrap()
);
Ok(())
}
/// Without provider-level settings, the app-wide defaults (multiplier 1.5,
/// pricing source "response") are used.
#[tokio::test]
async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> {
let db = Arc::new(Database::memory()?);
let app_type = "claude";
db.set_default_cost_multiplier(app_type, "1.5").await?;
db.set_pricing_model_source(app_type, "response").await?;
seed_pricing(&db)?;
// Default meta: no provider-level multiplier or pricing source.
let meta = ProviderMeta::default();
insert_provider(&db, "provider-2", app_type, meta)?;
let state = build_state(db.clone());
let usage = TokenUsage {
input_tokens: 1_000_000,
output_tokens: 0,
cache_read_tokens: 0,
cache_creation_tokens: 0,
model: None,
message_id: None,
};
log_usage_internal(
&state,
"provider-2",
app_type,
"resp-model",
"req-model",
usage,
10,
None,
false,
200,
None,
)
.await;
let conn = crate::database::lock_conn!(db.conn);
let (total_cost, cost_multiplier): (String, String) = conn
.query_row(
"SELECT total_cost_usd, cost_multiplier
FROM proxy_request_logs WHERE provider_id = ?1",
["provider-2"],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.map_err(|e| AppError::Database(e.to_string()))?;
assert_eq!(
Decimal::from_str(&cost_multiplier).unwrap(),
Decimal::from_str("1.5").unwrap()
);
// Priced with resp-model (1.0 per million input tokens) times the multiplier 1.5.
assert_eq!(
Decimal::from_str(&total_cost).unwrap(),
Decimal::from_str("1.5").unwrap()
);
Ok(())
}
}