mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-05-18 02:49:23 +08:00
fix(proxy): apply cost multiplier to total cost only
- Move multiplier calculation from per-item to total cost - Add resolve_pricing_config for provider-level override - Include request_model and cost_multiplier in usage logs - Return new fields in get_request_logs API
This commit is contained in:
@@ -22,9 +22,7 @@ use super::{
|
||||
};
|
||||
use crate::app_config::AppType;
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use rust_decimal::Decimal;
|
||||
use serde_json::{json, Value};
|
||||
use std::str::FromStr;
|
||||
|
||||
// ============================================================================
|
||||
// 健康检查和状态查询(简单端点)
|
||||
@@ -145,6 +143,7 @@ async fn handle_claude_transform(
|
||||
&provider_id,
|
||||
"claude",
|
||||
&model,
|
||||
&model,
|
||||
usage,
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -215,6 +214,7 @@ async fn handle_claude_transform(
|
||||
.unwrap_or("unknown");
|
||||
let latency_ms = ctx.latency_ms();
|
||||
|
||||
let request_model = ctx.request_model.clone();
|
||||
tokio::spawn({
|
||||
let state = state.clone();
|
||||
let provider_id = ctx.provider.id.clone();
|
||||
@@ -225,6 +225,7 @@ async fn handle_claude_transform(
|
||||
&provider_id,
|
||||
"claude",
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
None,
|
||||
@@ -441,6 +442,7 @@ async fn log_usage(
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: u64,
|
||||
first_token_ms: Option<u64>,
|
||||
@@ -450,26 +452,13 @@ async fn log_usage(
|
||||
use super::usage::logger::UsageLogger;
|
||||
|
||||
let logger = UsageLogger::new(&state.db);
|
||||
|
||||
// 获取 provider 的 cost_multiplier
|
||||
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
}
|
||||
_ => Decimal::from(1),
|
||||
let (multiplier, pricing_model_source) = logger
|
||||
.resolve_pricing_config(provider_id, app_type)
|
||||
.await;
|
||||
let pricing_model = if pricing_model_source == "request" {
|
||||
request_model
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
@@ -479,6 +468,8 @@ async fn log_usage(
|
||||
provider_id.to_string(),
|
||||
app_type.to_string(),
|
||||
model.to_string(),
|
||||
request_model.to_string(),
|
||||
pricing_model.to_string(),
|
||||
usage,
|
||||
multiplier,
|
||||
latency_ms,
|
||||
|
||||
@@ -13,10 +13,8 @@ use axum::response::{IntoResponse, Response};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use reqwest::header::HeaderMap;
|
||||
use rust_decimal::Decimal;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
str::FromStr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@@ -128,7 +126,15 @@ pub async fn handle_non_streaming(
|
||||
ctx.request_model.clone()
|
||||
};
|
||||
|
||||
spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false);
|
||||
spawn_log_usage(
|
||||
state,
|
||||
ctx,
|
||||
usage,
|
||||
&model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
} else {
|
||||
let model = json_value
|
||||
.get("model")
|
||||
@@ -140,6 +146,7 @@ pub async fn handle_non_streaming(
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
@@ -159,6 +166,7 @@ pub async fn handle_non_streaming(
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&ctx.request_model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
@@ -293,6 +301,7 @@ fn create_usage_collector(
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
let request_model = request_model.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -300,6 +309,7 @@ fn create_usage_collector(
|
||||
&provider_id,
|
||||
app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -315,6 +325,7 @@ fn create_usage_collector(
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
let request_model = request_model.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -322,6 +333,7 @@ fn create_usage_collector(
|
||||
&provider_id,
|
||||
app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
TokenUsage::default(),
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -342,6 +354,7 @@ fn spawn_log_usage(
|
||||
ctx: &RequestContext,
|
||||
usage: TokenUsage,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
status_code: u16,
|
||||
is_streaming: bool,
|
||||
) {
|
||||
@@ -349,6 +362,7 @@ fn spawn_log_usage(
|
||||
let provider_id = ctx.provider.id.clone();
|
||||
let app_type_str = ctx.app_type_str.to_string();
|
||||
let model = model.to_string();
|
||||
let request_model = request_model.to_string();
|
||||
let latency_ms = ctx.latency_ms();
|
||||
let session_id = ctx.session_id.clone();
|
||||
|
||||
@@ -358,6 +372,7 @@ fn spawn_log_usage(
|
||||
&provider_id,
|
||||
&app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
None,
|
||||
@@ -376,6 +391,7 @@ async fn log_usage_internal(
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: u64,
|
||||
first_token_ms: Option<u64>,
|
||||
@@ -386,26 +402,13 @@ async fn log_usage_internal(
|
||||
use super::usage::logger::UsageLogger;
|
||||
|
||||
let logger = UsageLogger::new(&state.db);
|
||||
|
||||
// 获取 provider 的 cost_multiplier
|
||||
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
}
|
||||
_ => Decimal::from(1),
|
||||
let (multiplier, pricing_model_source) = logger
|
||||
.resolve_pricing_config(provider_id, app_type)
|
||||
.await;
|
||||
let pricing_model = if pricing_model_source == "request" {
|
||||
request_model
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
@@ -424,6 +427,8 @@ async fn log_usage_internal(
|
||||
provider_id.to_string(),
|
||||
app_type.to_string(),
|
||||
model.to_string(),
|
||||
request_model.to_string(),
|
||||
pricing_model.to_string(),
|
||||
usage,
|
||||
multiplier,
|
||||
latency_ms,
|
||||
@@ -556,3 +561,185 @@ fn format_headers(headers: &HeaderMap) -> String {
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::provider::ProviderMeta;
|
||||
use crate::proxy::failover_switch::FailoverSwitchManager;
|
||||
use crate::proxy::provider_router::ProviderRouter;
|
||||
use crate::proxy::types::{ProxyConfig, ProxyStatus};
|
||||
use rust_decimal::Decimal;
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
fn build_state(db: Arc<Database>) -> ProxyState {
|
||||
ProxyState {
|
||||
db: db.clone(),
|
||||
config: Arc::new(RwLock::new(ProxyConfig::default())),
|
||||
status: Arc::new(RwLock::new(ProxyStatus::default())),
|
||||
start_time: Arc::new(RwLock::new(None)),
|
||||
current_providers: Arc::new(RwLock::new(HashMap::new())),
|
||||
provider_router: Arc::new(ProviderRouter::new(db.clone())),
|
||||
app_handle: None,
|
||||
failover_manager: Arc::new(FailoverSwitchManager::new(db)),
|
||||
}
|
||||
}
|
||||
|
||||
fn seed_pricing(db: &Database) -> Result<(), AppError> {
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
rusqlite::params!["resp-model", "Resp Model", "1.0", "0"],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
rusqlite::params!["req-model", "Req Model", "2.0", "0"],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_provider(
|
||||
db: &Database,
|
||||
id: &str,
|
||||
app_type: &str,
|
||||
meta: ProviderMeta,
|
||||
) -> Result<(), AppError> {
|
||||
let meta_json =
|
||||
serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?;
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
conn.execute(
|
||||
"INSERT INTO providers (id, app_type, name, settings_config, meta)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
rusqlite::params![id, app_type, "Test Provider", "{}", meta_json],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> {
|
||||
let db = Arc::new(Database::memory()?);
|
||||
let app_type = "claude";
|
||||
|
||||
db.set_default_cost_multiplier(app_type, "1.5").await?;
|
||||
db.set_pricing_model_source(app_type, "response").await?;
|
||||
seed_pricing(&db)?;
|
||||
|
||||
let mut meta = ProviderMeta::default();
|
||||
meta.cost_multiplier = Some("2".to_string());
|
||||
meta.pricing_model_source = Some("request".to_string());
|
||||
insert_provider(&db, "provider-1", app_type, meta)?;
|
||||
|
||||
let state = build_state(db.clone());
|
||||
let usage = TokenUsage {
|
||||
input_tokens: 1_000_000,
|
||||
output_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
cache_creation_tokens: 0,
|
||||
model: None,
|
||||
};
|
||||
|
||||
log_usage_internal(
|
||||
&state,
|
||||
"provider-1",
|
||||
app_type,
|
||||
"resp-model",
|
||||
"req-model",
|
||||
usage,
|
||||
10,
|
||||
None,
|
||||
false,
|
||||
200,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) =
|
||||
conn.query_row(
|
||||
"SELECT model, request_model, total_cost_usd, cost_multiplier
|
||||
FROM proxy_request_logs WHERE provider_id = ?1",
|
||||
["provider-1"],
|
||||
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
assert_eq!(model, "resp-model");
|
||||
assert_eq!(request_model, "req-model");
|
||||
assert_eq!(
|
||||
Decimal::from_str(&cost_multiplier).unwrap(),
|
||||
Decimal::from_str("2").unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Decimal::from_str(&total_cost).unwrap(),
|
||||
Decimal::from_str("4").unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> {
|
||||
let db = Arc::new(Database::memory()?);
|
||||
let app_type = "claude";
|
||||
|
||||
db.set_default_cost_multiplier(app_type, "1.5").await?;
|
||||
db.set_pricing_model_source(app_type, "response").await?;
|
||||
seed_pricing(&db)?;
|
||||
|
||||
let meta = ProviderMeta::default();
|
||||
insert_provider(&db, "provider-2", app_type, meta)?;
|
||||
|
||||
let state = build_state(db.clone());
|
||||
let usage = TokenUsage {
|
||||
input_tokens: 1_000_000,
|
||||
output_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
cache_creation_tokens: 0,
|
||||
model: None,
|
||||
};
|
||||
|
||||
log_usage_internal(
|
||||
&state,
|
||||
"provider-2",
|
||||
app_type,
|
||||
"resp-model",
|
||||
"req-model",
|
||||
usage,
|
||||
10,
|
||||
None,
|
||||
false,
|
||||
200,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let (total_cost, cost_multiplier): (String, String) = conn
|
||||
.query_row(
|
||||
"SELECT total_cost_usd, cost_multiplier
|
||||
FROM proxy_request_logs WHERE provider_id = ?1",
|
||||
["provider-2"],
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
assert_eq!(
|
||||
Decimal::from_str(&cost_multiplier).unwrap(),
|
||||
Decimal::from_str("1.5").unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Decimal::from_str(&total_cost).unwrap(),
|
||||
Decimal::from_str("1.5").unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ impl CostCalculator {
|
||||
/// - input_cost: (input_tokens - cache_read_tokens) × 输入价格
|
||||
/// - cache_read_cost: cache_read_tokens × 缓存读取价格
|
||||
/// - 这样避免缓存部分被重复计费
|
||||
/// - total_cost: 各项成本之和 × 倍率(倍率只作用于最终总价)
|
||||
pub fn calculate(
|
||||
usage: &TokenUsage,
|
||||
pricing: &ModelPricing,
|
||||
@@ -50,21 +51,20 @@ impl CostCalculator {
|
||||
// 计算实际需要按输入价格计费的 token 数(减去缓存命中部分)
|
||||
let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens);
|
||||
|
||||
let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
// 各项基础成本(不含倍率)
|
||||
let input_cost =
|
||||
Decimal::from(billable_input_tokens) * pricing.input_cost_per_million / million;
|
||||
let output_cost =
|
||||
Decimal::from(usage.output_tokens) * pricing.output_cost_per_million / million;
|
||||
let cache_read_cost =
|
||||
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million
|
||||
* cost_multiplier;
|
||||
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million;
|
||||
let cache_creation_cost = Decimal::from(usage.cache_creation_tokens)
|
||||
* pricing.cache_creation_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
/ million;
|
||||
|
||||
let total_cost = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
// 总成本 = 各项基础成本之和 × 倍率
|
||||
let base_total = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
let total_cost = base_total * cost_multiplier;
|
||||
|
||||
CostBreakdown {
|
||||
input_cost,
|
||||
@@ -151,8 +151,9 @@ mod tests {
|
||||
|
||||
let cost = CostCalculator::calculate(&usage, &pricing, multiplier);
|
||||
|
||||
// input: 1000 * 3.0 / 1M * 1.5 = 0.0045
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.0045").unwrap());
|
||||
// input_cost: 基础价格(不含倍率)= 1000 * 3.0 / 1M = 0.003
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap());
|
||||
// total_cost: 基础价格 × 倍率 = 0.003 * 1.5 = 0.0045
|
||||
assert_eq!(cost.total_cost, Decimal::from_str("0.0045").unwrap());
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::services::usage_stats::find_model_pricing_row;
|
||||
use rust_decimal::Decimal;
|
||||
use std::time::SystemTime;
|
||||
use std::{str::FromStr, time::SystemTime};
|
||||
|
||||
/// 请求日志
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -15,6 +15,7 @@ pub struct RequestLog {
|
||||
pub provider_id: String,
|
||||
pub app_type: String,
|
||||
pub model: String,
|
||||
pub request_model: String,
|
||||
pub usage: TokenUsage,
|
||||
pub cost: Option<CostBreakdown>,
|
||||
pub latency_ms: u64,
|
||||
@@ -73,17 +74,18 @@ impl<'a> UsageLogger<'a> {
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO proxy_request_logs (
|
||||
request_id, provider_id, app_type, model,
|
||||
request_id, provider_id, app_type, model, request_model,
|
||||
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
|
||||
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
|
||||
latency_ms, first_token_ms, status_code, error_message, session_id,
|
||||
provider_type, is_streaming, cost_multiplier, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)",
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23)",
|
||||
rusqlite::params![
|
||||
log.request_id,
|
||||
log.provider_id,
|
||||
log.app_type,
|
||||
log.model,
|
||||
log.request_model,
|
||||
log.usage.input_tokens,
|
||||
log.usage.output_tokens,
|
||||
log.usage.cache_read_tokens,
|
||||
@@ -123,11 +125,13 @@ impl<'a> UsageLogger<'a> {
|
||||
error_message: String,
|
||||
latency_ms: u64,
|
||||
) -> Result<(), AppError> {
|
||||
let request_model = model.clone();
|
||||
let log = RequestLog {
|
||||
request_id,
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage: TokenUsage::default(),
|
||||
cost: None,
|
||||
latency_ms,
|
||||
@@ -160,11 +164,13 @@ impl<'a> UsageLogger<'a> {
|
||||
session_id: Option<String>,
|
||||
provider_type: Option<String>,
|
||||
) -> Result<(), AppError> {
|
||||
let request_model = model.clone();
|
||||
let log = RequestLog {
|
||||
request_id,
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage: TokenUsage::default(),
|
||||
cost: None,
|
||||
latency_ms,
|
||||
@@ -194,6 +200,96 @@ impl<'a> UsageLogger<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取有效的倍率与计费模式来源(供应商优先,未配置则回退全局默认)
|
||||
pub async fn resolve_pricing_config(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
) -> (Decimal, String) {
|
||||
let default_multiplier_raw = match self.db.get_default_cost_multiplier(app_type).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 获取默认倍率失败 (app_type={app_type}): {e}"
|
||||
);
|
||||
"1".to_string()
|
||||
}
|
||||
};
|
||||
let default_multiplier = match Decimal::from_str(&default_multiplier_raw) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 默认倍率解析失败 (app_type={app_type}): {default_multiplier_raw} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
}
|
||||
};
|
||||
|
||||
let default_pricing_source_raw = match self.db.get_pricing_model_source(app_type).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 获取默认计费模式失败 (app_type={app_type}): {e}"
|
||||
);
|
||||
"response".to_string()
|
||||
}
|
||||
};
|
||||
let default_pricing_source = if matches!(
|
||||
default_pricing_source_raw.as_str(),
|
||||
"response" | "request"
|
||||
) {
|
||||
default_pricing_source_raw
|
||||
} else {
|
||||
log::warn!(
|
||||
"[USG-003] 默认计费模式无效 (app_type={app_type}): {default_pricing_source_raw}"
|
||||
);
|
||||
"response".to_string()
|
||||
};
|
||||
|
||||
let provider = self
|
||||
.db
|
||||
.get_provider_by_id(provider_id, app_type)
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let (provider_multiplier, provider_pricing_source) = provider
|
||||
.as_ref()
|
||||
.and_then(|p| p.meta.as_ref())
|
||||
.map(|meta| {
|
||||
(
|
||||
meta.cost_multiplier.as_deref(),
|
||||
meta.pricing_model_source.as_deref(),
|
||||
)
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
|
||||
let cost_multiplier = match provider_multiplier {
|
||||
Some(value) => match Decimal::from_str(value) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 供应商倍率解析失败 (provider_id={provider_id}): {value} - {e}"
|
||||
);
|
||||
default_multiplier
|
||||
}
|
||||
},
|
||||
None => default_multiplier,
|
||||
};
|
||||
|
||||
let pricing_model_source = match provider_pricing_source {
|
||||
Some(value) if matches!(value, "response" | "request") => value.to_string(),
|
||||
Some(value) => {
|
||||
log::warn!(
|
||||
"[USG-003] 供应商计费模式无效 (provider_id={provider_id}): {value}"
|
||||
);
|
||||
default_pricing_source.clone()
|
||||
}
|
||||
None => default_pricing_source.clone(),
|
||||
};
|
||||
|
||||
(cost_multiplier, pricing_model_source)
|
||||
}
|
||||
|
||||
/// 计算并记录请求
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn log_with_calculation(
|
||||
@@ -202,6 +298,8 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_id: String,
|
||||
app_type: String,
|
||||
model: String,
|
||||
request_model: String,
|
||||
pricing_model: String,
|
||||
usage: TokenUsage,
|
||||
cost_multiplier: Decimal,
|
||||
latency_ms: u64,
|
||||
@@ -211,10 +309,12 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_type: Option<String>,
|
||||
is_streaming: bool,
|
||||
) -> Result<(), AppError> {
|
||||
let pricing = self.get_model_pricing(&model)?;
|
||||
let pricing = self.get_model_pricing(&pricing_model)?;
|
||||
|
||||
if pricing.is_none() {
|
||||
log::warn!("[USG-002] 模型定价未找到,成本将记录为 0");
|
||||
log::warn!(
|
||||
"[USG-002] 模型定价未找到,成本将记录为 0: {pricing_model}"
|
||||
);
|
||||
}
|
||||
|
||||
let cost = CostCalculator::try_calculate(&usage, pricing.as_ref(), cost_multiplier);
|
||||
@@ -224,6 +324,7 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage,
|
||||
cost,
|
||||
latency_ms,
|
||||
@@ -274,6 +375,8 @@ mod tests {
|
||||
"provider-1".to_string(),
|
||||
"claude".to_string(),
|
||||
"test-model".to_string(),
|
||||
"req-model".to_string(),
|
||||
"test-model".to_string(),
|
||||
usage,
|
||||
Decimal::from(1),
|
||||
100,
|
||||
@@ -286,14 +389,15 @@ mod tests {
|
||||
|
||||
// 验证记录已插入
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let count: i64 = conn
|
||||
let (count, request_model): (i64, String) = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM proxy_request_logs WHERE request_id = 'req-123'",
|
||||
"SELECT COUNT(*), request_model FROM proxy_request_logs WHERE request_id = 'req-123'",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
assert_eq!(request_model, "req-model");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -94,6 +94,9 @@ pub struct RequestLogDetail {
|
||||
pub provider_name: Option<String>,
|
||||
pub app_type: String,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request_model: Option<String>,
|
||||
pub cost_multiplier: String,
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
@@ -140,7 +143,7 @@ impl Database {
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
"SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -218,7 +221,7 @@ impl Database {
|
||||
}
|
||||
|
||||
let sql = "
|
||||
SELECT
|
||||
SELECT
|
||||
CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
@@ -295,7 +298,7 @@ impl Database {
|
||||
pub fn get_provider_stats(&self) -> Result<Vec<ProviderStats>, AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
let sql = "SELECT
|
||||
let sql = "SELECT
|
||||
l.provider_id,
|
||||
p.name as provider_name,
|
||||
COUNT(*) as request_count,
|
||||
@@ -343,7 +346,7 @@ impl Database {
|
||||
pub fn get_model_stats(&self) -> Result<Vec<ModelStats>, AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
let sql = "SELECT
|
||||
let sql = "SELECT
|
||||
model,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
|
||||
@@ -424,7 +427,7 @@ impl Database {
|
||||
|
||||
// 获取总数
|
||||
let count_sql = format!(
|
||||
"SELECT COUNT(*) FROM proxy_request_logs l
|
||||
"SELECT COUNT(*) FROM proxy_request_logs l
|
||||
LEFT JOIN providers p ON l.provider_id = p.id AND l.app_type = p.app_type
|
||||
{where_clause}"
|
||||
);
|
||||
@@ -440,6 +443,7 @@ impl Database {
|
||||
|
||||
let sql = format!(
|
||||
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
|
||||
l.request_model, l.cost_multiplier,
|
||||
l.input_tokens, l.output_tokens, l.cache_read_tokens, l.cache_creation_tokens,
|
||||
l.input_cost_usd, l.output_cost_usd, l.cache_read_cost_usd, l.cache_creation_cost_usd, l.total_cost_usd,
|
||||
l.is_streaming, l.latency_ms, l.first_token_ms, l.duration_ms,
|
||||
@@ -460,22 +464,24 @@ impl Database {
|
||||
provider_name: row.get(2)?,
|
||||
app_type: row.get(3)?,
|
||||
model: row.get(4)?,
|
||||
input_tokens: row.get::<_, i64>(5)? as u32,
|
||||
output_tokens: row.get::<_, i64>(6)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(7)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
|
||||
input_cost_usd: row.get(9)?,
|
||||
output_cost_usd: row.get(10)?,
|
||||
cache_read_cost_usd: row.get(11)?,
|
||||
cache_creation_cost_usd: row.get(12)?,
|
||||
total_cost_usd: row.get(13)?,
|
||||
is_streaming: row.get::<_, i64>(14)? != 0,
|
||||
latency_ms: row.get::<_, i64>(15)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(18)? as u16,
|
||||
error_message: row.get(19)?,
|
||||
created_at: row.get(20)?,
|
||||
request_model: row.get(5)?,
|
||||
cost_multiplier: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "1".to_string()),
|
||||
input_tokens: row.get::<_, i64>(7)? as u32,
|
||||
output_tokens: row.get::<_, i64>(8)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(9)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
|
||||
input_cost_usd: row.get(11)?,
|
||||
output_cost_usd: row.get(12)?,
|
||||
cache_read_cost_usd: row.get(13)?,
|
||||
cache_creation_cost_usd: row.get(14)?,
|
||||
total_cost_usd: row.get(15)?,
|
||||
is_streaming: row.get::<_, i64>(16)? != 0,
|
||||
latency_ms: row.get::<_, i64>(17)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(20)? as u16,
|
||||
error_message: row.get(21)?,
|
||||
created_at: row.get(22)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -511,6 +517,7 @@ impl Database {
|
||||
|
||||
let result = conn.query_row(
|
||||
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
|
||||
l.request_model, l.cost_multiplier,
|
||||
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
|
||||
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
|
||||
is_streaming, latency_ms, first_token_ms, duration_ms,
|
||||
@@ -526,22 +533,24 @@ impl Database {
|
||||
provider_name: row.get(2)?,
|
||||
app_type: row.get(3)?,
|
||||
model: row.get(4)?,
|
||||
input_tokens: row.get::<_, i64>(5)? as u32,
|
||||
output_tokens: row.get::<_, i64>(6)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(7)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
|
||||
input_cost_usd: row.get(9)?,
|
||||
output_cost_usd: row.get(10)?,
|
||||
cache_read_cost_usd: row.get(11)?,
|
||||
cache_creation_cost_usd: row.get(12)?,
|
||||
total_cost_usd: row.get(13)?,
|
||||
is_streaming: row.get::<_, i64>(14)? != 0,
|
||||
latency_ms: row.get::<_, i64>(15)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(18)? as u16,
|
||||
error_message: row.get(19)?,
|
||||
created_at: row.get(20)?,
|
||||
request_model: row.get(5)?,
|
||||
cost_multiplier: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "1".to_string()),
|
||||
input_tokens: row.get::<_, i64>(7)? as u32,
|
||||
output_tokens: row.get::<_, i64>(8)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(9)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
|
||||
input_cost_usd: row.get(11)?,
|
||||
output_cost_usd: row.get(12)?,
|
||||
cache_read_cost_usd: row.get(13)?,
|
||||
cache_creation_cost_usd: row.get(14)?,
|
||||
total_cost_usd: row.get(15)?,
|
||||
is_streaming: row.get::<_, i64>(16)? != 0,
|
||||
latency_ms: row.get::<_, i64>(17)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(20)? as u16,
|
||||
error_message: row.get(21)?,
|
||||
created_at: row.get(22)?,
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user