From 63b874aff1b434c23e7ae4fcfe6a9aed279f8e57 Mon Sep 17 00:00:00 2001 From: YoVinchen Date: Mon, 26 Jan 2026 01:40:18 +0800 Subject: [PATCH] fix(proxy): apply cost multiplier to total cost only - Move multiplier calculation from per-item to total cost - Add resolve_pricing_config for provider-level override - Include request_model and cost_multiplier in usage logs - Return new fields in get_request_logs API --- src-tauri/src/proxy/handlers.rs | 35 ++-- src-tauri/src/proxy/response_processor.rs | 233 +++++++++++++++++++--- src-tauri/src/proxy/usage/calculator.rs | 27 +-- src-tauri/src/proxy/usage/logger.rs | 120 ++++++++++- src-tauri/src/services/usage_stats.rs | 83 ++++---- 5 files changed, 395 insertions(+), 103 deletions(-) diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 664b849b..e5a4cfd7 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -22,9 +22,7 @@ use super::{ }; use crate::app_config::AppType; use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; -use rust_decimal::Decimal; use serde_json::{json, Value}; -use std::str::FromStr; // ============================================================================ // 健康检查和状态查询(简单端点) @@ -145,6 +143,7 @@ async fn handle_claude_transform( &provider_id, "claude", &model, + &model, usage, latency_ms, first_token_ms, @@ -215,6 +214,7 @@ async fn handle_claude_transform( .unwrap_or("unknown"); let latency_ms = ctx.latency_ms(); + let request_model = ctx.request_model.clone(); tokio::spawn({ let state = state.clone(); let provider_id = ctx.provider.id.clone(); @@ -225,6 +225,7 @@ async fn handle_claude_transform( &provider_id, "claude", &model, + &request_model, usage, latency_ms, None, @@ -441,6 +442,7 @@ async fn log_usage( provider_id: &str, app_type: &str, model: &str, + request_model: &str, usage: TokenUsage, latency_ms: u64, first_token_ms: Option, @@ -450,26 +452,13 @@ async fn log_usage( use super::usage::logger::UsageLogger; let logger = UsageLogger::new(&state.db); - - // 获取 provider 的 cost_multiplier - let multiplier = match state.db.get_provider_by_id(provider_id, app_type) { - Ok(Some(p)) => { - if let Some(meta) = p.meta { - if let Some(cm) = meta.cost_multiplier { - Decimal::from_str(&cm).unwrap_or_else(|e| { - log::warn!( - "cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}" - ); - Decimal::from(1) - }) - } else { - Decimal::from(1) - } - } else { - Decimal::from(1) - } - } - _ => Decimal::from(1), + let (multiplier, pricing_model_source) = logger + .resolve_pricing_config(provider_id, app_type) + .await; + let pricing_model = if pricing_model_source == "request" { + request_model + } else { + model }; let request_id = uuid::Uuid::new_v4().to_string(); @@ -479,6 +468,8 @@ async fn log_usage( provider_id.to_string(), app_type.to_string(), model.to_string(), + request_model.to_string(), + pricing_model.to_string(), usage, multiplier, latency_ms, diff --git a/src-tauri/src/proxy/response_processor.rs b/src-tauri/src/proxy/response_processor.rs index b7af4352..10608ec7 100644 --- a/src-tauri/src/proxy/response_processor.rs +++ b/src-tauri/src/proxy/response_processor.rs @@ -13,10 +13,8 @@ use axum::response::{IntoResponse, Response}; use bytes::Bytes; use futures::stream::{Stream, StreamExt}; use reqwest::header::HeaderMap; -use rust_decimal::Decimal; use serde_json::Value; use std::{ - str::FromStr, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -128,7 +126,15 @@ pub async fn handle_non_streaming( ctx.request_model.clone() }; - spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false); + spawn_log_usage( + state, + ctx, + usage, + &model, + &ctx.request_model, + status.as_u16(), + false, + ); } else { let model = json_value .get("model") @@ -140,6 +146,7 @@ pub async fn handle_non_streaming( ctx, TokenUsage::default(), &model, + &ctx.request_model, status.as_u16(), false, ); @@ -159,6 +166,7 @@ pub async fn handle_non_streaming( ctx, TokenUsage::default(), &ctx.request_model, + &ctx.request_model, status.as_u16(), false, ); @@ -293,6 +301,7 @@ fn create_usage_collector( let state = state.clone(); let provider_id = provider_id.clone(); let session_id = session_id.clone(); + let request_model = request_model.clone(); tokio::spawn(async move { log_usage_internal( @@ -300,6 +309,7 @@ fn create_usage_collector( &provider_id, app_type_str, &model, + &request_model, usage, latency_ms, first_token_ms, @@ -315,6 +325,7 @@ fn create_usage_collector( let state = state.clone(); let provider_id = provider_id.clone(); let session_id = session_id.clone(); + let request_model = request_model.clone(); tokio::spawn(async move { log_usage_internal( @@ -322,6 +333,7 @@ fn create_usage_collector( &provider_id, app_type_str, &model, + &request_model, TokenUsage::default(), latency_ms, first_token_ms, @@ -342,6 +354,7 @@ fn spawn_log_usage( ctx: &RequestContext, usage: TokenUsage, model: &str, + request_model: &str, status_code: u16, is_streaming: bool, ) { @@ -349,6 +362,7 @@ fn spawn_log_usage( let provider_id = ctx.provider.id.clone(); let app_type_str = ctx.app_type_str.to_string(); let model = model.to_string(); + let request_model = request_model.to_string(); let latency_ms = ctx.latency_ms(); let session_id = ctx.session_id.clone(); @@ -358,6 +372,7 @@ fn spawn_log_usage( &provider_id, &app_type_str, &model, + &request_model, usage, latency_ms, None, @@ -376,6 +391,7 @@ async fn log_usage_internal( provider_id: &str, app_type: &str, model: &str, + request_model: &str, usage: TokenUsage, latency_ms: u64, first_token_ms: Option, @@ -386,26 +402,13 @@ async fn log_usage_internal( use super::usage::logger::UsageLogger; let logger = UsageLogger::new(&state.db); - - // 获取 provider 的 cost_multiplier - let multiplier = match state.db.get_provider_by_id(provider_id, app_type) { - Ok(Some(p)) => { - if let Some(meta) = p.meta { - if let Some(cm) = meta.cost_multiplier { - Decimal::from_str(&cm).unwrap_or_else(|e| { - log::warn!( - "cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}" - ); - Decimal::from(1) - }) - } else { - Decimal::from(1) - } - } else { - Decimal::from(1) - } - } - _ => Decimal::from(1), + let (multiplier, pricing_model_source) = logger + .resolve_pricing_config(provider_id, app_type) + .await; + let pricing_model = if pricing_model_source == "request" { + request_model + } else { + model }; let request_id = uuid::Uuid::new_v4().to_string(); @@ -424,6 +427,8 @@ async fn log_usage_internal( provider_id.to_string(), app_type.to_string(), model.to_string(), + request_model.to_string(), + pricing_model.to_string(), usage, multiplier, latency_ms, @@ -556,3 +561,185 @@ fn format_headers(headers: &HeaderMap) -> String { .collect::>() .join(", ") } + +#[cfg(test)] +mod tests { + use super::*; + use crate::database::Database; + use crate::error::AppError; + use crate::provider::ProviderMeta; + use crate::proxy::failover_switch::FailoverSwitchManager; + use crate::proxy::provider_router::ProviderRouter; + use crate::proxy::types::{ProxyConfig, ProxyStatus}; + use rust_decimal::Decimal; + use std::collections::HashMap; + use std::str::FromStr; + use std::sync::Arc; + use tokio::sync::RwLock; + + fn build_state(db: Arc) -> ProxyState { + ProxyState { + db: db.clone(), + config: Arc::new(RwLock::new(ProxyConfig::default())), + status: Arc::new(RwLock::new(ProxyStatus::default())), + start_time: Arc::new(RwLock::new(None)), + current_providers: Arc::new(RwLock::new(HashMap::new())), + provider_router: Arc::new(ProviderRouter::new(db.clone())), + app_handle: None, + failover_manager: Arc::new(FailoverSwitchManager::new(db)), + } + } + + fn seed_pricing(db: &Database) -> Result<(), AppError> { + let conn = crate::database::lock_conn!(db.conn); + conn.execute( + "INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million) + VALUES (?1, ?2, ?3, ?4)", + rusqlite::params!["resp-model", "Resp Model", "1.0", "0"], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + conn.execute( + "INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million) + VALUES (?1, ?2, ?3, ?4)", + rusqlite::params!["req-model", "Req Model", "2.0", "0"], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + Ok(()) + } + + fn insert_provider( + db: &Database, + id: &str, + app_type: &str, + meta: ProviderMeta, + ) -> Result<(), AppError> { + let meta_json = + serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?; + let conn = crate::database::lock_conn!(db.conn); + conn.execute( + "INSERT INTO providers (id, app_type, name, settings_config, meta) + VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![id, app_type, "Test Provider", "{}", meta_json], + ) + .map_err(|e| AppError::Database(e.to_string()))?; + Ok(()) + } + + #[tokio::test] + async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> { + let db = Arc::new(Database::memory()?); + let app_type = "claude"; + + db.set_default_cost_multiplier(app_type, "1.5").await?; + db.set_pricing_model_source(app_type, "response").await?; + seed_pricing(&db)?; + + let mut meta = ProviderMeta::default(); + meta.cost_multiplier = Some("2".to_string()); + meta.pricing_model_source = Some("request".to_string()); + insert_provider(&db, "provider-1", app_type, meta)?; + + let state = build_state(db.clone()); + let usage = TokenUsage { + input_tokens: 1_000_000, + output_tokens: 0, + cache_read_tokens: 0, + cache_creation_tokens: 0, + model: None, + }; + + log_usage_internal( + &state, + "provider-1", + app_type, + "resp-model", + "req-model", + usage, + 10, + None, + false, + 200, + None, + ) + .await; + + let conn = crate::database::lock_conn!(db.conn); + let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) = + conn.query_row( + "SELECT model, request_model, total_cost_usd, cost_multiplier + FROM proxy_request_logs WHERE provider_id = ?1", + ["provider-1"], + |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)), + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + assert_eq!(model, "resp-model"); + assert_eq!(request_model, "req-model"); + assert_eq!( + Decimal::from_str(&cost_multiplier).unwrap(), + Decimal::from_str("2").unwrap() + ); + assert_eq!( + Decimal::from_str(&total_cost).unwrap(), + Decimal::from_str("4").unwrap() + ); + Ok(()) + } + + #[tokio::test] + async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> { + let db = Arc::new(Database::memory()?); + let app_type = "claude"; + + db.set_default_cost_multiplier(app_type, "1.5").await?; + db.set_pricing_model_source(app_type, "response").await?; + seed_pricing(&db)?; + + let meta = ProviderMeta::default(); + insert_provider(&db, "provider-2", app_type, meta)?; + + let state = build_state(db.clone()); + let usage = TokenUsage { + input_tokens: 1_000_000, + output_tokens: 0, + cache_read_tokens: 0, + cache_creation_tokens: 0, + model: None, + }; + + log_usage_internal( + &state, + "provider-2", + app_type, + "resp-model", + "req-model", + usage, + 10, + None, + false, + 200, + None, + ) + .await; + + let conn = crate::database::lock_conn!(db.conn); + let (total_cost, cost_multiplier): (String, String) = conn + .query_row( + "SELECT total_cost_usd, cost_multiplier + FROM proxy_request_logs WHERE provider_id = ?1", + ["provider-2"], + |row| Ok((row.get(0)?, row.get(1)?)), + ) + .map_err(|e| AppError::Database(e.to_string()))?; + + assert_eq!( + Decimal::from_str(&cost_multiplier).unwrap(), + Decimal::from_str("1.5").unwrap() + ); + assert_eq!( + Decimal::from_str(&total_cost).unwrap(), + Decimal::from_str("1.5").unwrap() + ); + Ok(()) + } +} diff --git a/src-tauri/src/proxy/usage/calculator.rs b/src-tauri/src/proxy/usage/calculator.rs index 80fd2c67..34015023 100644 --- a/src-tauri/src/proxy/usage/calculator.rs +++ b/src-tauri/src/proxy/usage/calculator.rs @@ -40,6 +40,7 @@ impl CostCalculator { /// - input_cost: (input_tokens - cache_read_tokens) × 输入价格 /// - cache_read_cost: cache_read_tokens × 缓存读取价格 /// - 这样避免缓存部分被重复计费 + /// - total_cost: 各项成本之和 × 倍率(倍率只作用于最终总价) pub fn calculate( usage: &TokenUsage, pricing: &ModelPricing, @@ -50,21 +51,20 @@ impl CostCalculator { // 计算实际需要按输入价格计费的 token 数(减去缓存命中部分) let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens); - let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million - / million - * cost_multiplier; - let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million - / million - * cost_multiplier; + // 各项基础成本(不含倍率) + let input_cost = + Decimal::from(billable_input_tokens) * pricing.input_cost_per_million / million; + let output_cost = + Decimal::from(usage.output_tokens) * pricing.output_cost_per_million / million; let cache_read_cost = - Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million - * cost_multiplier; + Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million; let cache_creation_cost = Decimal::from(usage.cache_creation_tokens) * pricing.cache_creation_cost_per_million - / million - * cost_multiplier; + / million; - let total_cost = input_cost + output_cost + cache_read_cost + cache_creation_cost; + // 总成本 = 各项基础成本之和 × 倍率 + let base_total = input_cost + output_cost + cache_read_cost + cache_creation_cost; + let total_cost = base_total * cost_multiplier; CostBreakdown { input_cost, @@ -151,8 +151,9 @@ mod tests { let cost = CostCalculator::calculate(&usage, &pricing, multiplier); - // input: 1000 * 3.0 / 1M * 1.5 = 0.0045 - assert_eq!(cost.input_cost, Decimal::from_str("0.0045").unwrap()); + // input_cost: 基础价格(不含倍率)= 1000 * 3.0 / 1M = 0.003 + assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap()); + // total_cost: 基础价格 × 倍率 = 0.003 * 1.5 = 0.0045 assert_eq!(cost.total_cost, Decimal::from_str("0.0045").unwrap()); } diff --git a/src-tauri/src/proxy/usage/logger.rs b/src-tauri/src/proxy/usage/logger.rs index ea1470fa..f6a511a7 100644 --- a/src-tauri/src/proxy/usage/logger.rs +++ b/src-tauri/src/proxy/usage/logger.rs @@ -6,7 +6,7 @@ use crate::database::Database; use crate::error::AppError; use crate::services::usage_stats::find_model_pricing_row; use rust_decimal::Decimal; -use std::time::SystemTime; +use std::{str::FromStr, time::SystemTime}; /// 请求日志 #[derive(Debug, Clone)] @@ -15,6 +15,7 @@ pub struct RequestLog { pub provider_id: String, pub app_type: String, pub model: String, + pub request_model: String, pub usage: TokenUsage, pub cost: Option, pub latency_ms: u64, @@ -73,17 +74,18 @@ impl<'a> UsageLogger<'a> { conn.execute( "INSERT INTO proxy_request_logs ( - request_id, provider_id, app_type, model, + request_id, provider_id, app_type, model, request_model, input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd, latency_ms, first_token_ms, status_code, error_message, session_id, provider_type, is_streaming, cost_multiplier, created_at - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)", + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23)", rusqlite::params![ log.request_id, log.provider_id, log.app_type, log.model, + log.request_model, log.usage.input_tokens, log.usage.output_tokens, log.usage.cache_read_tokens, @@ -123,11 +125,13 @@ impl<'a> UsageLogger<'a> { error_message: String, latency_ms: u64, ) -> Result<(), AppError> { + let request_model = model.clone(); let log = RequestLog { request_id, provider_id, app_type, model, + request_model, usage: TokenUsage::default(), cost: None, latency_ms, @@ -160,11 +164,13 @@ impl<'a> UsageLogger<'a> { session_id: Option, provider_type: Option, ) -> Result<(), AppError> { + let request_model = model.clone(); let log = RequestLog { request_id, provider_id, app_type, model, + request_model, usage: TokenUsage::default(), cost: None, latency_ms, @@ -194,6 +200,96 @@ impl<'a> UsageLogger<'a> { } } + /// 获取有效的倍率与计费模式来源(供应商优先,未配置则回退全局默认) + pub async fn resolve_pricing_config( + &self, + provider_id: &str, + app_type: &str, + ) -> (Decimal, String) { + let default_multiplier_raw = match self.db.get_default_cost_multiplier(app_type).await { + Ok(value) => value, + Err(e) => { + log::warn!( + "[USG-003] 获取默认倍率失败 (app_type={app_type}): {e}" + ); + "1".to_string() + } + }; + let default_multiplier = match Decimal::from_str(&default_multiplier_raw) { + Ok(value) => value, + Err(e) => { + log::warn!( + "[USG-003] 默认倍率解析失败 (app_type={app_type}): {default_multiplier_raw} - {e}" + ); + Decimal::from(1) + } + }; + + let default_pricing_source_raw = match self.db.get_pricing_model_source(app_type).await { + Ok(value) => value, + Err(e) => { + log::warn!( + "[USG-003] 获取默认计费模式失败 (app_type={app_type}): {e}" + ); + "response".to_string() + } + }; + let default_pricing_source = if matches!( + default_pricing_source_raw.as_str(), + "response" | "request" + ) { + default_pricing_source_raw + } else { + log::warn!( + "[USG-003] 默认计费模式无效 (app_type={app_type}): {default_pricing_source_raw}" + ); + "response".to_string() + }; + + let provider = self + .db + .get_provider_by_id(provider_id, app_type) + .ok() + .flatten(); + + let (provider_multiplier, provider_pricing_source) = provider + .as_ref() + .and_then(|p| p.meta.as_ref()) + .map(|meta| { + ( + meta.cost_multiplier.as_deref(), + meta.pricing_model_source.as_deref(), + ) + }) + .unwrap_or((None, None)); + + let cost_multiplier = match provider_multiplier { + Some(value) => match Decimal::from_str(value) { + Ok(parsed) => parsed, + Err(e) => { + log::warn!( + "[USG-003] 供应商倍率解析失败 (provider_id={provider_id}): {value} - {e}" + ); + default_multiplier + } + }, + None => default_multiplier, + }; + + let pricing_model_source = match provider_pricing_source { + Some(value) if matches!(value, "response" | "request") => value.to_string(), + Some(value) => { + log::warn!( + "[USG-003] 供应商计费模式无效 (provider_id={provider_id}): {value}" + ); + default_pricing_source.clone() + } + None => default_pricing_source.clone(), + }; + + (cost_multiplier, pricing_model_source) + } + /// 计算并记录请求 #[allow(clippy::too_many_arguments)] pub fn log_with_calculation( @@ -202,6 +298,8 @@ impl<'a> UsageLogger<'a> { provider_id: String, app_type: String, model: String, + request_model: String, + pricing_model: String, usage: TokenUsage, cost_multiplier: Decimal, latency_ms: u64, @@ -211,10 +309,12 @@ impl<'a> UsageLogger<'a> { provider_type: Option, is_streaming: bool, ) -> Result<(), AppError> { - let pricing = self.get_model_pricing(&model)?; + let pricing = self.get_model_pricing(&pricing_model)?; if pricing.is_none() { - log::warn!("[USG-002] 模型定价未找到,成本将记录为 0"); + log::warn!( + "[USG-002] 模型定价未找到,成本将记录为 0: {pricing_model}" + ); } let cost = CostCalculator::try_calculate(&usage, pricing.as_ref(), cost_multiplier); @@ -224,6 +324,7 @@ impl<'a> UsageLogger<'a> { provider_id, app_type, model, + request_model, usage, cost, latency_ms, @@ -274,6 +375,8 @@ mod tests { "provider-1".to_string(), "claude".to_string(), "test-model".to_string(), + "req-model".to_string(), + "test-model".to_string(), usage, Decimal::from(1), 100, @@ -286,14 +389,15 @@ mod tests { // 验证记录已插入 let conn = crate::database::lock_conn!(db.conn); - let count: i64 = conn + let (count, request_model): (i64, String) = conn .query_row( - "SELECT COUNT(*) FROM proxy_request_logs WHERE request_id = 'req-123'", + "SELECT COUNT(*), request_model FROM proxy_request_logs WHERE request_id = 'req-123'", [], - |row| row.get(0), + |row| Ok((row.get(0)?, row.get(1)?)), ) .unwrap(); assert_eq!(count, 1); + assert_eq!(request_model, "req-model"); Ok(()) } diff --git a/src-tauri/src/services/usage_stats.rs b/src-tauri/src/services/usage_stats.rs index 81316aa4..2389ed17 100644 --- a/src-tauri/src/services/usage_stats.rs +++ b/src-tauri/src/services/usage_stats.rs @@ -94,6 +94,9 @@ pub struct RequestLogDetail { pub provider_name: Option, pub app_type: String, pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub request_model: Option, + pub cost_multiplier: String, pub input_tokens: u32, pub output_tokens: u32, pub cache_read_tokens: u32, @@ -140,7 +143,7 @@ impl Database { }; let sql = format!( - "SELECT + "SELECT COUNT(*) as total_requests, COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -218,7 +221,7 @@ impl Database { } let sql = " - SELECT + SELECT CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx, COUNT(*) as request_count, COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost, @@ -295,7 +298,7 @@ impl Database { pub fn get_provider_stats(&self) -> Result, AppError> { let conn = lock_conn!(self.conn); - let sql = "SELECT + let sql = "SELECT l.provider_id, p.name as provider_name, COUNT(*) as request_count, @@ -343,7 +346,7 @@ impl Database { pub fn get_model_stats(&self) -> Result, AppError> { let conn = lock_conn!(self.conn); - let sql = "SELECT + let sql = "SELECT model, COUNT(*) as request_count, COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens, @@ -424,7 +427,7 @@ impl Database { // 获取总数 let count_sql = format!( - "SELECT COUNT(*) FROM proxy_request_logs l + "SELECT COUNT(*) FROM proxy_request_logs l LEFT JOIN providers p ON l.provider_id = p.id AND l.app_type = p.app_type {where_clause}" ); @@ -440,6 +443,7 @@ impl Database { let sql = format!( "SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model, + l.request_model, l.cost_multiplier, l.input_tokens, l.output_tokens, l.cache_read_tokens, l.cache_creation_tokens, l.input_cost_usd, l.output_cost_usd, l.cache_read_cost_usd, l.cache_creation_cost_usd, l.total_cost_usd, l.is_streaming, l.latency_ms, l.first_token_ms, l.duration_ms, @@ -460,22 +464,24 @@ impl Database { provider_name: row.get(2)?, app_type: row.get(3)?, model: row.get(4)?, - input_tokens: row.get::<_, i64>(5)? as u32, - output_tokens: row.get::<_, i64>(6)? as u32, - cache_read_tokens: row.get::<_, i64>(7)? as u32, - cache_creation_tokens: row.get::<_, i64>(8)? as u32, - input_cost_usd: row.get(9)?, - output_cost_usd: row.get(10)?, - cache_read_cost_usd: row.get(11)?, - cache_creation_cost_usd: row.get(12)?, - total_cost_usd: row.get(13)?, - is_streaming: row.get::<_, i64>(14)? != 0, - latency_ms: row.get::<_, i64>(15)? as u64, - first_token_ms: row.get::<_, Option>(16)?.map(|v| v as u64), - duration_ms: row.get::<_, Option>(17)?.map(|v| v as u64), - status_code: row.get::<_, i64>(18)? as u16, - error_message: row.get(19)?, - created_at: row.get(20)?, + request_model: row.get(5)?, + cost_multiplier: row.get::<_, Option>(6)?.unwrap_or_else(|| "1".to_string()), + input_tokens: row.get::<_, i64>(7)? as u32, + output_tokens: row.get::<_, i64>(8)? as u32, + cache_read_tokens: row.get::<_, i64>(9)? as u32, + cache_creation_tokens: row.get::<_, i64>(10)? as u32, + input_cost_usd: row.get(11)?, + output_cost_usd: row.get(12)?, + cache_read_cost_usd: row.get(13)?, + cache_creation_cost_usd: row.get(14)?, + total_cost_usd: row.get(15)?, + is_streaming: row.get::<_, i64>(16)? != 0, + latency_ms: row.get::<_, i64>(17)? as u64, + first_token_ms: row.get::<_, Option>(18)?.map(|v| v as u64), + duration_ms: row.get::<_, Option>(19)?.map(|v| v as u64), + status_code: row.get::<_, i64>(20)? as u16, + error_message: row.get(21)?, + created_at: row.get(22)?, }) })?; @@ -511,6 +517,7 @@ impl Database { let result = conn.query_row( "SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model, + l.request_model, l.cost_multiplier, input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd, is_streaming, latency_ms, first_token_ms, duration_ms, @@ -526,22 +533,24 @@ impl Database { provider_name: row.get(2)?, app_type: row.get(3)?, model: row.get(4)?, - input_tokens: row.get::<_, i64>(5)? as u32, - output_tokens: row.get::<_, i64>(6)? as u32, - cache_read_tokens: row.get::<_, i64>(7)? as u32, - cache_creation_tokens: row.get::<_, i64>(8)? as u32, - input_cost_usd: row.get(9)?, - output_cost_usd: row.get(10)?, - cache_read_cost_usd: row.get(11)?, - cache_creation_cost_usd: row.get(12)?, - total_cost_usd: row.get(13)?, - is_streaming: row.get::<_, i64>(14)? != 0, - latency_ms: row.get::<_, i64>(15)? as u64, - first_token_ms: row.get::<_, Option>(16)?.map(|v| v as u64), - duration_ms: row.get::<_, Option>(17)?.map(|v| v as u64), - status_code: row.get::<_, i64>(18)? as u16, - error_message: row.get(19)?, - created_at: row.get(20)?, + request_model: row.get(5)?, + cost_multiplier: row.get::<_, Option>(6)?.unwrap_or_else(|| "1".to_string()), + input_tokens: row.get::<_, i64>(7)? as u32, + output_tokens: row.get::<_, i64>(8)? as u32, + cache_read_tokens: row.get::<_, i64>(9)? as u32, + cache_creation_tokens: row.get::<_, i64>(10)? as u32, + input_cost_usd: row.get(11)?, + output_cost_usd: row.get(12)?, + cache_read_cost_usd: row.get(13)?, + cache_creation_cost_usd: row.get(14)?, + total_cost_usd: row.get(15)?, + is_streaming: row.get::<_, i64>(16)? != 0, + latency_ms: row.get::<_, i64>(17)? as u64, + first_token_ms: row.get::<_, Option>(18)?.map(|v| v as u64), + duration_ms: row.get::<_, Option>(19)?.map(|v| v as u64), + status_code: row.get::<_, i64>(20)? as u16, + error_message: row.get(21)?, + created_at: row.get(22)?, }) }, );