fix(proxy): apply cost multiplier to total cost only

- Move multiplier calculation from per-item to total cost
- Add resolve_pricing_config for provider-level override
- Include request_model and cost_multiplier in usage logs
- Return new fields in get_request_logs API
This commit is contained in:
YoVinchen
2026-01-26 01:40:18 +08:00
parent bdef49fe0f
commit 63b874aff1
5 changed files with 395 additions and 103 deletions
+13 -22
View File
@@ -22,9 +22,7 @@ use super::{
};
use crate::app_config::AppType;
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
use rust_decimal::Decimal;
use serde_json::{json, Value};
use std::str::FromStr;
// ============================================================================
// 健康检查和状态查询(简单端点)
@@ -145,6 +143,7 @@ async fn handle_claude_transform(
&provider_id,
"claude",
&model,
&model,
usage,
latency_ms,
first_token_ms,
@@ -215,6 +214,7 @@ async fn handle_claude_transform(
.unwrap_or("unknown");
let latency_ms = ctx.latency_ms();
let request_model = ctx.request_model.clone();
tokio::spawn({
let state = state.clone();
let provider_id = ctx.provider.id.clone();
@@ -225,6 +225,7 @@ async fn handle_claude_transform(
&provider_id,
"claude",
&model,
&request_model,
usage,
latency_ms,
None,
@@ -441,6 +442,7 @@ async fn log_usage(
provider_id: &str,
app_type: &str,
model: &str,
request_model: &str,
usage: TokenUsage,
latency_ms: u64,
first_token_ms: Option<u64>,
@@ -450,26 +452,13 @@ async fn log_usage(
use super::usage::logger::UsageLogger;
let logger = UsageLogger::new(&state.db);
// 获取 provider 的 cost_multiplier
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
Ok(Some(p)) => {
if let Some(meta) = p.meta {
if let Some(cm) = meta.cost_multiplier {
Decimal::from_str(&cm).unwrap_or_else(|e| {
log::warn!(
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
);
Decimal::from(1)
})
} else {
Decimal::from(1)
}
} else {
Decimal::from(1)
}
}
_ => Decimal::from(1),
let (multiplier, pricing_model_source) = logger
.resolve_pricing_config(provider_id, app_type)
.await;
let pricing_model = if pricing_model_source == "request" {
request_model
} else {
model
};
let request_id = uuid::Uuid::new_v4().to_string();
@@ -479,6 +468,8 @@ async fn log_usage(
provider_id.to_string(),
app_type.to_string(),
model.to_string(),
request_model.to_string(),
pricing_model.to_string(),
usage,
multiplier,
latency_ms,
+210 -23
View File
@@ -13,10 +13,8 @@ use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use reqwest::header::HeaderMap;
use rust_decimal::Decimal;
use serde_json::Value;
use std::{
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
@@ -128,7 +126,15 @@ pub async fn handle_non_streaming(
ctx.request_model.clone()
};
spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false);
spawn_log_usage(
state,
ctx,
usage,
&model,
&ctx.request_model,
status.as_u16(),
false,
);
} else {
let model = json_value
.get("model")
@@ -140,6 +146,7 @@ pub async fn handle_non_streaming(
ctx,
TokenUsage::default(),
&model,
&ctx.request_model,
status.as_u16(),
false,
);
@@ -159,6 +166,7 @@ pub async fn handle_non_streaming(
ctx,
TokenUsage::default(),
&ctx.request_model,
&ctx.request_model,
status.as_u16(),
false,
);
@@ -293,6 +301,7 @@ fn create_usage_collector(
let state = state.clone();
let provider_id = provider_id.clone();
let session_id = session_id.clone();
let request_model = request_model.clone();
tokio::spawn(async move {
log_usage_internal(
@@ -300,6 +309,7 @@ fn create_usage_collector(
&provider_id,
app_type_str,
&model,
&request_model,
usage,
latency_ms,
first_token_ms,
@@ -315,6 +325,7 @@ fn create_usage_collector(
let state = state.clone();
let provider_id = provider_id.clone();
let session_id = session_id.clone();
let request_model = request_model.clone();
tokio::spawn(async move {
log_usage_internal(
@@ -322,6 +333,7 @@ fn create_usage_collector(
&provider_id,
app_type_str,
&model,
&request_model,
TokenUsage::default(),
latency_ms,
first_token_ms,
@@ -342,6 +354,7 @@ fn spawn_log_usage(
ctx: &RequestContext,
usage: TokenUsage,
model: &str,
request_model: &str,
status_code: u16,
is_streaming: bool,
) {
@@ -349,6 +362,7 @@ fn spawn_log_usage(
let provider_id = ctx.provider.id.clone();
let app_type_str = ctx.app_type_str.to_string();
let model = model.to_string();
let request_model = request_model.to_string();
let latency_ms = ctx.latency_ms();
let session_id = ctx.session_id.clone();
@@ -358,6 +372,7 @@ fn spawn_log_usage(
&provider_id,
&app_type_str,
&model,
&request_model,
usage,
latency_ms,
None,
@@ -376,6 +391,7 @@ async fn log_usage_internal(
provider_id: &str,
app_type: &str,
model: &str,
request_model: &str,
usage: TokenUsage,
latency_ms: u64,
first_token_ms: Option<u64>,
@@ -386,26 +402,13 @@ async fn log_usage_internal(
use super::usage::logger::UsageLogger;
let logger = UsageLogger::new(&state.db);
// 获取 provider 的 cost_multiplier
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
Ok(Some(p)) => {
if let Some(meta) = p.meta {
if let Some(cm) = meta.cost_multiplier {
Decimal::from_str(&cm).unwrap_or_else(|e| {
log::warn!(
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
);
Decimal::from(1)
})
} else {
Decimal::from(1)
}
} else {
Decimal::from(1)
}
}
_ => Decimal::from(1),
let (multiplier, pricing_model_source) = logger
.resolve_pricing_config(provider_id, app_type)
.await;
let pricing_model = if pricing_model_source == "request" {
request_model
} else {
model
};
let request_id = uuid::Uuid::new_v4().to_string();
@@ -424,6 +427,8 @@ async fn log_usage_internal(
provider_id.to_string(),
app_type.to_string(),
model.to_string(),
request_model.to_string(),
pricing_model.to_string(),
usage,
multiplier,
latency_ms,
@@ -556,3 +561,185 @@ fn format_headers(headers: &HeaderMap) -> String {
.collect::<Vec<_>>()
.join(", ")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::database::Database;
use crate::error::AppError;
use crate::provider::ProviderMeta;
use crate::proxy::failover_switch::FailoverSwitchManager;
use crate::proxy::provider_router::ProviderRouter;
use crate::proxy::types::{ProxyConfig, ProxyStatus};
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
fn build_state(db: Arc<Database>) -> ProxyState {
ProxyState {
db: db.clone(),
config: Arc::new(RwLock::new(ProxyConfig::default())),
status: Arc::new(RwLock::new(ProxyStatus::default())),
start_time: Arc::new(RwLock::new(None)),
current_providers: Arc::new(RwLock::new(HashMap::new())),
provider_router: Arc::new(ProviderRouter::new(db.clone())),
app_handle: None,
failover_manager: Arc::new(FailoverSwitchManager::new(db)),
}
}
fn seed_pricing(db: &Database) -> Result<(), AppError> {
let conn = crate::database::lock_conn!(db.conn);
conn.execute(
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
VALUES (?1, ?2, ?3, ?4)",
rusqlite::params!["resp-model", "Resp Model", "1.0", "0"],
)
.map_err(|e| AppError::Database(e.to_string()))?;
conn.execute(
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
VALUES (?1, ?2, ?3, ?4)",
rusqlite::params!["req-model", "Req Model", "2.0", "0"],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
fn insert_provider(
db: &Database,
id: &str,
app_type: &str,
meta: ProviderMeta,
) -> Result<(), AppError> {
let meta_json =
serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?;
let conn = crate::database::lock_conn!(db.conn);
conn.execute(
"INSERT INTO providers (id, app_type, name, settings_config, meta)
VALUES (?1, ?2, ?3, ?4, ?5)",
rusqlite::params![id, app_type, "Test Provider", "{}", meta_json],
)
.map_err(|e| AppError::Database(e.to_string()))?;
Ok(())
}
#[tokio::test]
async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> {
let db = Arc::new(Database::memory()?);
let app_type = "claude";
db.set_default_cost_multiplier(app_type, "1.5").await?;
db.set_pricing_model_source(app_type, "response").await?;
seed_pricing(&db)?;
let mut meta = ProviderMeta::default();
meta.cost_multiplier = Some("2".to_string());
meta.pricing_model_source = Some("request".to_string());
insert_provider(&db, "provider-1", app_type, meta)?;
let state = build_state(db.clone());
let usage = TokenUsage {
input_tokens: 1_000_000,
output_tokens: 0,
cache_read_tokens: 0,
cache_creation_tokens: 0,
model: None,
};
log_usage_internal(
&state,
"provider-1",
app_type,
"resp-model",
"req-model",
usage,
10,
None,
false,
200,
None,
)
.await;
let conn = crate::database::lock_conn!(db.conn);
let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) =
conn.query_row(
"SELECT model, request_model, total_cost_usd, cost_multiplier
FROM proxy_request_logs WHERE provider_id = ?1",
["provider-1"],
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
)
.map_err(|e| AppError::Database(e.to_string()))?;
assert_eq!(model, "resp-model");
assert_eq!(request_model, "req-model");
assert_eq!(
Decimal::from_str(&cost_multiplier).unwrap(),
Decimal::from_str("2").unwrap()
);
assert_eq!(
Decimal::from_str(&total_cost).unwrap(),
Decimal::from_str("4").unwrap()
);
Ok(())
}
#[tokio::test]
async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> {
let db = Arc::new(Database::memory()?);
let app_type = "claude";
db.set_default_cost_multiplier(app_type, "1.5").await?;
db.set_pricing_model_source(app_type, "response").await?;
seed_pricing(&db)?;
let meta = ProviderMeta::default();
insert_provider(&db, "provider-2", app_type, meta)?;
let state = build_state(db.clone());
let usage = TokenUsage {
input_tokens: 1_000_000,
output_tokens: 0,
cache_read_tokens: 0,
cache_creation_tokens: 0,
model: None,
};
log_usage_internal(
&state,
"provider-2",
app_type,
"resp-model",
"req-model",
usage,
10,
None,
false,
200,
None,
)
.await;
let conn = crate::database::lock_conn!(db.conn);
let (total_cost, cost_multiplier): (String, String) = conn
.query_row(
"SELECT total_cost_usd, cost_multiplier
FROM proxy_request_logs WHERE provider_id = ?1",
["provider-2"],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.map_err(|e| AppError::Database(e.to_string()))?;
assert_eq!(
Decimal::from_str(&cost_multiplier).unwrap(),
Decimal::from_str("1.5").unwrap()
);
assert_eq!(
Decimal::from_str(&total_cost).unwrap(),
Decimal::from_str("1.5").unwrap()
);
Ok(())
}
}
+14 -13
View File
@@ -40,6 +40,7 @@ impl CostCalculator {
/// - input_cost: (input_tokens - cache_read_tokens) × 输入价格
/// - cache_read_cost: cache_read_tokens × 缓存读取价格
/// - 这样避免缓存部分被重复计费
/// - total_cost: 各项成本之和 × 倍率(倍率只作用于最终总价)
pub fn calculate(
usage: &TokenUsage,
pricing: &ModelPricing,
@@ -50,21 +51,20 @@ impl CostCalculator {
// 计算实际需要按输入价格计费的 token 数(减去缓存命中部分)
let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens);
let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million
/ million
* cost_multiplier;
let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million
/ million
* cost_multiplier;
// 各项基础成本(不含倍率)
let input_cost =
Decimal::from(billable_input_tokens) * pricing.input_cost_per_million / million;
let output_cost =
Decimal::from(usage.output_tokens) * pricing.output_cost_per_million / million;
let cache_read_cost =
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million
* cost_multiplier;
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million;
let cache_creation_cost = Decimal::from(usage.cache_creation_tokens)
* pricing.cache_creation_cost_per_million
/ million
* cost_multiplier;
/ million;
let total_cost = input_cost + output_cost + cache_read_cost + cache_creation_cost;
// 总成本 = 各项基础成本之和 × 倍率
let base_total = input_cost + output_cost + cache_read_cost + cache_creation_cost;
let total_cost = base_total * cost_multiplier;
CostBreakdown {
input_cost,
@@ -151,8 +151,9 @@ mod tests {
let cost = CostCalculator::calculate(&usage, &pricing, multiplier);
// input: 1000 * 3.0 / 1M * 1.5 = 0.0045
assert_eq!(cost.input_cost, Decimal::from_str("0.0045").unwrap());
// input_cost: 基础价格(不含倍率)= 1000 * 3.0 / 1M = 0.003
assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap());
// total_cost: 基础价格 × 倍率 = 0.003 * 1.5 = 0.0045
assert_eq!(cost.total_cost, Decimal::from_str("0.0045").unwrap());
}
+112 -8
View File
@@ -6,7 +6,7 @@ use crate::database::Database;
use crate::error::AppError;
use crate::services::usage_stats::find_model_pricing_row;
use rust_decimal::Decimal;
use std::time::SystemTime;
use std::{str::FromStr, time::SystemTime};
/// 请求日志
#[derive(Debug, Clone)]
@@ -15,6 +15,7 @@ pub struct RequestLog {
pub provider_id: String,
pub app_type: String,
pub model: String,
pub request_model: String,
pub usage: TokenUsage,
pub cost: Option<CostBreakdown>,
pub latency_ms: u64,
@@ -73,17 +74,18 @@ impl<'a> UsageLogger<'a> {
conn.execute(
"INSERT INTO proxy_request_logs (
request_id, provider_id, app_type, model,
request_id, provider_id, app_type, model, request_model,
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
latency_ms, first_token_ms, status_code, error_message, session_id,
provider_type, is_streaming, cost_multiplier, created_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)",
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23)",
rusqlite::params![
log.request_id,
log.provider_id,
log.app_type,
log.model,
log.request_model,
log.usage.input_tokens,
log.usage.output_tokens,
log.usage.cache_read_tokens,
@@ -123,11 +125,13 @@ impl<'a> UsageLogger<'a> {
error_message: String,
latency_ms: u64,
) -> Result<(), AppError> {
let request_model = model.clone();
let log = RequestLog {
request_id,
provider_id,
app_type,
model,
request_model,
usage: TokenUsage::default(),
cost: None,
latency_ms,
@@ -160,11 +164,13 @@ impl<'a> UsageLogger<'a> {
session_id: Option<String>,
provider_type: Option<String>,
) -> Result<(), AppError> {
let request_model = model.clone();
let log = RequestLog {
request_id,
provider_id,
app_type,
model,
request_model,
usage: TokenUsage::default(),
cost: None,
latency_ms,
@@ -194,6 +200,96 @@ impl<'a> UsageLogger<'a> {
}
}
/// 获取有效的倍率与计费模式来源(供应商优先,未配置则回退全局默认)
pub async fn resolve_pricing_config(
&self,
provider_id: &str,
app_type: &str,
) -> (Decimal, String) {
let default_multiplier_raw = match self.db.get_default_cost_multiplier(app_type).await {
Ok(value) => value,
Err(e) => {
log::warn!(
"[USG-003] 获取默认倍率失败 (app_type={app_type}): {e}"
);
"1".to_string()
}
};
let default_multiplier = match Decimal::from_str(&default_multiplier_raw) {
Ok(value) => value,
Err(e) => {
log::warn!(
"[USG-003] 默认倍率解析失败 (app_type={app_type}): {default_multiplier_raw} - {e}"
);
Decimal::from(1)
}
};
let default_pricing_source_raw = match self.db.get_pricing_model_source(app_type).await {
Ok(value) => value,
Err(e) => {
log::warn!(
"[USG-003] 获取默认计费模式失败 (app_type={app_type}): {e}"
);
"response".to_string()
}
};
let default_pricing_source = if matches!(
default_pricing_source_raw.as_str(),
"response" | "request"
) {
default_pricing_source_raw
} else {
log::warn!(
"[USG-003] 默认计费模式无效 (app_type={app_type}): {default_pricing_source_raw}"
);
"response".to_string()
};
let provider = self
.db
.get_provider_by_id(provider_id, app_type)
.ok()
.flatten();
let (provider_multiplier, provider_pricing_source) = provider
.as_ref()
.and_then(|p| p.meta.as_ref())
.map(|meta| {
(
meta.cost_multiplier.as_deref(),
meta.pricing_model_source.as_deref(),
)
})
.unwrap_or((None, None));
let cost_multiplier = match provider_multiplier {
Some(value) => match Decimal::from_str(value) {
Ok(parsed) => parsed,
Err(e) => {
log::warn!(
"[USG-003] 供应商倍率解析失败 (provider_id={provider_id}): {value} - {e}"
);
default_multiplier
}
},
None => default_multiplier,
};
let pricing_model_source = match provider_pricing_source {
Some(value) if matches!(value, "response" | "request") => value.to_string(),
Some(value) => {
log::warn!(
"[USG-003] 供应商计费模式无效 (provider_id={provider_id}): {value}"
);
default_pricing_source.clone()
}
None => default_pricing_source.clone(),
};
(cost_multiplier, pricing_model_source)
}
/// 计算并记录请求
#[allow(clippy::too_many_arguments)]
pub fn log_with_calculation(
@@ -202,6 +298,8 @@ impl<'a> UsageLogger<'a> {
provider_id: String,
app_type: String,
model: String,
request_model: String,
pricing_model: String,
usage: TokenUsage,
cost_multiplier: Decimal,
latency_ms: u64,
@@ -211,10 +309,12 @@ impl<'a> UsageLogger<'a> {
provider_type: Option<String>,
is_streaming: bool,
) -> Result<(), AppError> {
let pricing = self.get_model_pricing(&model)?;
let pricing = self.get_model_pricing(&pricing_model)?;
if pricing.is_none() {
log::warn!("[USG-002] 模型定价未找到,成本将记录为 0");
log::warn!(
"[USG-002] 模型定价未找到,成本将记录为 0: {pricing_model}"
);
}
let cost = CostCalculator::try_calculate(&usage, pricing.as_ref(), cost_multiplier);
@@ -224,6 +324,7 @@ impl<'a> UsageLogger<'a> {
provider_id,
app_type,
model,
request_model,
usage,
cost,
latency_ms,
@@ -274,6 +375,8 @@ mod tests {
"provider-1".to_string(),
"claude".to_string(),
"test-model".to_string(),
"req-model".to_string(),
"test-model".to_string(),
usage,
Decimal::from(1),
100,
@@ -286,14 +389,15 @@ mod tests {
// 验证记录已插入
let conn = crate::database::lock_conn!(db.conn);
let count: i64 = conn
let (count, request_model): (i64, String) = conn
.query_row(
"SELECT COUNT(*) FROM proxy_request_logs WHERE request_id = 'req-123'",
"SELECT COUNT(*), request_model FROM proxy_request_logs WHERE request_id = 'req-123'",
[],
|row| row.get(0),
|row| Ok((row.get(0)?, row.get(1)?)),
)
.unwrap();
assert_eq!(count, 1);
assert_eq!(request_model, "req-model");
Ok(())
}
+46 -37
View File
@@ -94,6 +94,9 @@ pub struct RequestLogDetail {
pub provider_name: Option<String>,
pub app_type: String,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_model: Option<String>,
pub cost_multiplier: String,
pub input_tokens: u32,
pub output_tokens: u32,
pub cache_read_tokens: u32,
@@ -140,7 +143,7 @@ impl Database {
};
let sql = format!(
"SELECT
"SELECT
COUNT(*) as total_requests,
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -218,7 +221,7 @@ impl Database {
}
let sql = "
SELECT
SELECT
CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx,
COUNT(*) as request_count,
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
@@ -295,7 +298,7 @@ impl Database {
pub fn get_provider_stats(&self) -> Result<Vec<ProviderStats>, AppError> {
let conn = lock_conn!(self.conn);
let sql = "SELECT
let sql = "SELECT
l.provider_id,
p.name as provider_name,
COUNT(*) as request_count,
@@ -343,7 +346,7 @@ impl Database {
pub fn get_model_stats(&self) -> Result<Vec<ModelStats>, AppError> {
let conn = lock_conn!(self.conn);
let sql = "SELECT
let sql = "SELECT
model,
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
@@ -424,7 +427,7 @@ impl Database {
// 获取总数
let count_sql = format!(
"SELECT COUNT(*) FROM proxy_request_logs l
"SELECT COUNT(*) FROM proxy_request_logs l
LEFT JOIN providers p ON l.provider_id = p.id AND l.app_type = p.app_type
{where_clause}"
);
@@ -440,6 +443,7 @@ impl Database {
let sql = format!(
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
l.request_model, l.cost_multiplier,
l.input_tokens, l.output_tokens, l.cache_read_tokens, l.cache_creation_tokens,
l.input_cost_usd, l.output_cost_usd, l.cache_read_cost_usd, l.cache_creation_cost_usd, l.total_cost_usd,
l.is_streaming, l.latency_ms, l.first_token_ms, l.duration_ms,
@@ -460,22 +464,24 @@ impl Database {
provider_name: row.get(2)?,
app_type: row.get(3)?,
model: row.get(4)?,
input_tokens: row.get::<_, i64>(5)? as u32,
output_tokens: row.get::<_, i64>(6)? as u32,
cache_read_tokens: row.get::<_, i64>(7)? as u32,
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
input_cost_usd: row.get(9)?,
output_cost_usd: row.get(10)?,
cache_read_cost_usd: row.get(11)?,
cache_creation_cost_usd: row.get(12)?,
total_cost_usd: row.get(13)?,
is_streaming: row.get::<_, i64>(14)? != 0,
latency_ms: row.get::<_, i64>(15)? as u64,
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
status_code: row.get::<_, i64>(18)? as u16,
error_message: row.get(19)?,
created_at: row.get(20)?,
request_model: row.get(5)?,
cost_multiplier: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "1".to_string()),
input_tokens: row.get::<_, i64>(7)? as u32,
output_tokens: row.get::<_, i64>(8)? as u32,
cache_read_tokens: row.get::<_, i64>(9)? as u32,
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
input_cost_usd: row.get(11)?,
output_cost_usd: row.get(12)?,
cache_read_cost_usd: row.get(13)?,
cache_creation_cost_usd: row.get(14)?,
total_cost_usd: row.get(15)?,
is_streaming: row.get::<_, i64>(16)? != 0,
latency_ms: row.get::<_, i64>(17)? as u64,
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
status_code: row.get::<_, i64>(20)? as u16,
error_message: row.get(21)?,
created_at: row.get(22)?,
})
})?;
@@ -511,6 +517,7 @@ impl Database {
let result = conn.query_row(
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
l.request_model, l.cost_multiplier,
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
is_streaming, latency_ms, first_token_ms, duration_ms,
@@ -526,22 +533,24 @@ impl Database {
provider_name: row.get(2)?,
app_type: row.get(3)?,
model: row.get(4)?,
input_tokens: row.get::<_, i64>(5)? as u32,
output_tokens: row.get::<_, i64>(6)? as u32,
cache_read_tokens: row.get::<_, i64>(7)? as u32,
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
input_cost_usd: row.get(9)?,
output_cost_usd: row.get(10)?,
cache_read_cost_usd: row.get(11)?,
cache_creation_cost_usd: row.get(12)?,
total_cost_usd: row.get(13)?,
is_streaming: row.get::<_, i64>(14)? != 0,
latency_ms: row.get::<_, i64>(15)? as u64,
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
status_code: row.get::<_, i64>(18)? as u16,
error_message: row.get(19)?,
created_at: row.get(20)?,
request_model: row.get(5)?,
cost_multiplier: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "1".to_string()),
input_tokens: row.get::<_, i64>(7)? as u32,
output_tokens: row.get::<_, i64>(8)? as u32,
cache_read_tokens: row.get::<_, i64>(9)? as u32,
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
input_cost_usd: row.get(11)?,
output_cost_usd: row.get(12)?,
cache_read_cost_usd: row.get(13)?,
cache_creation_cost_usd: row.get(14)?,
total_cost_usd: row.get(15)?,
is_streaming: row.get::<_, i64>(16)? != 0,
latency_ms: row.get::<_, i64>(17)? as u64,
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
status_code: row.get::<_, i64>(20)? as u16,
error_message: row.get(21)?,
created_at: row.get(22)?,
})
},
);