mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-23 17:45:28 +08:00
Compare commits
9 Commits
feat/proxy
...
feat/prici
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eec05e704f | ||
|
|
8c20c57b29 | ||
|
|
5ef6f78fd0 | ||
|
|
44fba676d5 | ||
|
|
e6f91541a3 | ||
|
|
828f839083 | ||
|
|
63b874aff1 | ||
|
|
bdef49fe0f | ||
|
|
bba6524979 |
@@ -2,6 +2,7 @@
|
||||
//!
|
||||
//! 提供前端调用的 API 接口
|
||||
|
||||
use crate::error::AppError;
|
||||
use crate::proxy::types::*;
|
||||
use crate::proxy::{CircuitBreakerConfig, CircuitBreakerStats};
|
||||
use crate::store::AppState;
|
||||
@@ -119,6 +120,120 @@ pub async fn update_proxy_config_for_app(
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
async fn get_default_cost_multiplier_internal(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
) -> Result<String, AppError> {
|
||||
let db = &state.db;
|
||||
db.get_default_cost_multiplier(app_type).await
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
pub async fn get_default_cost_multiplier_test_hook(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
) -> Result<String, AppError> {
|
||||
get_default_cost_multiplier_internal(state, app_type).await
|
||||
}
|
||||
|
||||
/// 获取默认成本倍率
|
||||
#[tauri::command]
|
||||
pub async fn get_default_cost_multiplier(
|
||||
state: tauri::State<'_, AppState>,
|
||||
app_type: String,
|
||||
) -> Result<String, String> {
|
||||
get_default_cost_multiplier_internal(&state, &app_type)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
async fn set_default_cost_multiplier_internal(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let db = &state.db;
|
||||
db.set_default_cost_multiplier(app_type, value).await
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
pub async fn set_default_cost_multiplier_test_hook(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
set_default_cost_multiplier_internal(state, app_type, value).await
|
||||
}
|
||||
|
||||
/// 设置默认成本倍率
|
||||
#[tauri::command]
|
||||
pub async fn set_default_cost_multiplier(
|
||||
state: tauri::State<'_, AppState>,
|
||||
app_type: String,
|
||||
value: String,
|
||||
) -> Result<(), String> {
|
||||
set_default_cost_multiplier_internal(&state, &app_type, &value)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
async fn get_pricing_model_source_internal(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
) -> Result<String, AppError> {
|
||||
let db = &state.db;
|
||||
db.get_pricing_model_source(app_type).await
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
pub async fn get_pricing_model_source_test_hook(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
) -> Result<String, AppError> {
|
||||
get_pricing_model_source_internal(state, app_type).await
|
||||
}
|
||||
|
||||
/// 获取计费模式来源
|
||||
#[tauri::command]
|
||||
pub async fn get_pricing_model_source(
|
||||
state: tauri::State<'_, AppState>,
|
||||
app_type: String,
|
||||
) -> Result<String, String> {
|
||||
get_pricing_model_source_internal(&state, &app_type)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
async fn set_pricing_model_source_internal(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let db = &state.db;
|
||||
db.set_pricing_model_source(app_type, value).await
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "test-hooks"), doc(hidden))]
|
||||
pub async fn set_pricing_model_source_test_hook(
|
||||
state: &AppState,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
set_pricing_model_source_internal(state, app_type, value).await
|
||||
}
|
||||
|
||||
/// 设置计费模式来源
|
||||
#[tauri::command]
|
||||
pub async fn set_pricing_model_source(
|
||||
state: tauri::State<'_, AppState>,
|
||||
app_type: String,
|
||||
value: String,
|
||||
) -> Result<(), String> {
|
||||
set_pricing_model_source_internal(&state, &app_type, &value)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 检查代理服务器是否正在运行
|
||||
#[tauri::command]
|
||||
pub async fn is_proxy_running(state: tauri::State<'_, AppState>) -> Result<bool, String> {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
use crate::error::AppError;
|
||||
use crate::proxy::types::*;
|
||||
use rust_decimal::Decimal;
|
||||
|
||||
use super::super::{lock_conn, Database};
|
||||
|
||||
@@ -75,6 +76,117 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取默认成本倍率
|
||||
pub async fn get_default_cost_multiplier(&self, app_type: &str) -> Result<String, AppError> {
|
||||
let result = {
|
||||
let conn = lock_conn!(self.conn);
|
||||
conn.query_row(
|
||||
"SELECT default_cost_multiplier FROM proxy_config WHERE app_type = ?1",
|
||||
[app_type],
|
||||
|row| row.get(0),
|
||||
)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(value) => Ok(value),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => {
|
||||
self.init_proxy_config_rows().await?;
|
||||
Ok("1".to_string())
|
||||
}
|
||||
Err(e) => Err(AppError::Database(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置默认成本倍率
|
||||
pub async fn set_default_cost_multiplier(
|
||||
&self,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(AppError::localized(
|
||||
"error.multiplierEmpty",
|
||||
"倍率不能为空",
|
||||
"Multiplier cannot be empty",
|
||||
));
|
||||
}
|
||||
trimmed.parse::<Decimal>().map_err(|e| {
|
||||
AppError::localized(
|
||||
"error.invalidMultiplier",
|
||||
format!("无效倍率: {value} - {e}"),
|
||||
format!("Invalid multiplier: {value} - {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 确保行存在
|
||||
self.ensure_proxy_config_row_exists(app_type)?;
|
||||
|
||||
let conn = lock_conn!(self.conn);
|
||||
conn.execute(
|
||||
"UPDATE proxy_config SET
|
||||
default_cost_multiplier = ?2,
|
||||
updated_at = datetime('now')
|
||||
WHERE app_type = ?1",
|
||||
rusqlite::params![app_type, trimmed],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取计费模式来源
|
||||
pub async fn get_pricing_model_source(&self, app_type: &str) -> Result<String, AppError> {
|
||||
let result = {
|
||||
let conn = lock_conn!(self.conn);
|
||||
conn.query_row(
|
||||
"SELECT pricing_model_source FROM proxy_config WHERE app_type = ?1",
|
||||
[app_type],
|
||||
|row| row.get(0),
|
||||
)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(value) => Ok(value),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => {
|
||||
self.init_proxy_config_rows().await?;
|
||||
Ok("response".to_string())
|
||||
}
|
||||
Err(e) => Err(AppError::Database(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置计费模式来源
|
||||
pub async fn set_pricing_model_source(
|
||||
&self,
|
||||
app_type: &str,
|
||||
value: &str,
|
||||
) -> Result<(), AppError> {
|
||||
let trimmed = value.trim();
|
||||
if !matches!(trimmed, "response" | "request") {
|
||||
return Err(AppError::localized(
|
||||
"error.invalidPricingMode",
|
||||
format!("无效计费模式: {value}"),
|
||||
format!("Invalid pricing mode: {value}"),
|
||||
));
|
||||
}
|
||||
|
||||
// 确保行存在
|
||||
self.ensure_proxy_config_row_exists(app_type)?;
|
||||
|
||||
let conn = lock_conn!(self.conn);
|
||||
conn.execute(
|
||||
"UPDATE proxy_config SET
|
||||
pricing_model_source = ?2,
|
||||
updated_at = datetime('now')
|
||||
WHERE app_type = ?1",
|
||||
rusqlite::params![app_type, trimmed],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取应用级代理配置
|
||||
pub async fn get_proxy_config_for_app(
|
||||
&self,
|
||||
@@ -177,17 +289,90 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 确保指定 app_type 的 proxy_config 行存在(同步版本,用于 set_* 函数)
|
||||
///
|
||||
/// 使用与 schema.rs seed 相同的 per-app 默认值
|
||||
fn ensure_proxy_config_row_exists(&self, app_type: &str) -> Result<(), AppError> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| AppError::Lock(e.to_string()))?;
|
||||
|
||||
// 根据 app_type 使用不同的默认值(与 schema.rs seed 保持一致)
|
||||
let (retries, fb_timeout, idle_timeout, cb_fail, cb_succ, cb_timeout, cb_rate, cb_min) =
|
||||
match app_type {
|
||||
"claude" => (6, 90, 180, 8, 3, 90, 0.7, 15),
|
||||
"codex" => (3, 60, 120, 4, 2, 60, 0.6, 10),
|
||||
"gemini" => (5, 60, 120, 4, 2, 60, 0.6, 10),
|
||||
_ => (3, 60, 120, 4, 2, 60, 0.6, 10), // 默认值
|
||||
};
|
||||
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO proxy_config (
|
||||
app_type, max_retries,
|
||||
streaming_first_byte_timeout, streaming_idle_timeout, non_streaming_timeout,
|
||||
circuit_failure_threshold, circuit_success_threshold, circuit_timeout_seconds,
|
||||
circuit_error_rate_threshold, circuit_min_requests
|
||||
) VALUES (?1, ?2, ?3, ?4, 600, ?5, ?6, ?7, ?8, ?9)",
|
||||
rusqlite::params![
|
||||
app_type,
|
||||
retries,
|
||||
fb_timeout,
|
||||
idle_timeout,
|
||||
cb_fail,
|
||||
cb_succ,
|
||||
cb_timeout,
|
||||
cb_rate,
|
||||
cb_min
|
||||
],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 初始化 proxy_config 表的三行数据
|
||||
///
|
||||
/// 使用与 schema.rs seed 相同的 per-app 默认值
|
||||
async fn init_proxy_config_rows(&self) -> Result<(), AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
for app_type in &["claude", "codex", "gemini"] {
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO proxy_config (app_type) VALUES (?1)",
|
||||
[app_type],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
}
|
||||
// 使用与 schema.rs seed 相同的 per-app 默认值
|
||||
// claude: 更激进的重试和超时配置
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO proxy_config (
|
||||
app_type, max_retries,
|
||||
streaming_first_byte_timeout, streaming_idle_timeout, non_streaming_timeout,
|
||||
circuit_failure_threshold, circuit_success_threshold, circuit_timeout_seconds,
|
||||
circuit_error_rate_threshold, circuit_min_requests
|
||||
) VALUES ('claude', 6, 90, 180, 600, 8, 3, 90, 0.7, 15)",
|
||||
[],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
// codex: 默认配置
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO proxy_config (
|
||||
app_type, max_retries,
|
||||
streaming_first_byte_timeout, streaming_idle_timeout, non_streaming_timeout,
|
||||
circuit_failure_threshold, circuit_success_threshold, circuit_timeout_seconds,
|
||||
circuit_error_rate_threshold, circuit_min_requests
|
||||
) VALUES ('codex', 3, 60, 120, 600, 4, 2, 60, 0.6, 10)",
|
||||
[],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
// gemini: 稍高的重试次数
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO proxy_config (
|
||||
app_type, max_retries,
|
||||
streaming_first_byte_timeout, streaming_idle_timeout, non_streaming_timeout,
|
||||
circuit_failure_threshold, circuit_success_threshold, circuit_timeout_seconds,
|
||||
circuit_error_rate_threshold, circuit_min_requests
|
||||
) VALUES ('gemini', 5, 60, 120, 600, 4, 2, 60, 0.6, 10)",
|
||||
[],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -662,3 +847,58 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_cost_multiplier_round_trip() -> Result<(), AppError> {
|
||||
let db = Database::memory()?;
|
||||
|
||||
let default = db.get_default_cost_multiplier("claude").await?;
|
||||
assert_eq!(default, "1");
|
||||
|
||||
db.set_default_cost_multiplier("claude", "1.5").await?;
|
||||
let updated = db.get_default_cost_multiplier("claude").await?;
|
||||
assert_eq!(updated, "1.5");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_default_cost_multiplier_validation() -> Result<(), AppError> {
|
||||
let db = Database::memory()?;
|
||||
|
||||
let err = db
|
||||
.set_default_cost_multiplier("claude", "not-a-number")
|
||||
.await
|
||||
.unwrap_err();
|
||||
// AppError::localized returns AppError::Localized variant
|
||||
assert!(matches!(err, AppError::Localized { key: "error.invalidMultiplier", .. }));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pricing_model_source_round_trip_and_validation() -> Result<(), AppError> {
|
||||
let db = Database::memory()?;
|
||||
|
||||
let default = db.get_pricing_model_source("claude").await?;
|
||||
assert_eq!(default, "response");
|
||||
|
||||
db.set_pricing_model_source("claude", "request").await?;
|
||||
let updated = db.get_pricing_model_source("claude").await?;
|
||||
assert_eq!(updated, "request");
|
||||
|
||||
let err = db
|
||||
.set_pricing_model_source("claude", "invalid")
|
||||
.await
|
||||
.unwrap_err();
|
||||
// AppError::localized returns AppError::Localized variant
|
||||
assert!(matches!(err, AppError::Localized { key: "error.invalidPricingMode", .. }));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ const DB_BACKUP_RETAIN: usize = 10;
|
||||
|
||||
/// 当前 Schema 版本号
|
||||
/// 每次修改表结构时递增,并在 schema.rs 中添加相应的迁移逻辑
|
||||
pub(crate) const SCHEMA_VERSION: i32 = 4;
|
||||
pub(crate) const SCHEMA_VERSION: i32 = 5;
|
||||
|
||||
/// 安全地序列化 JSON,避免 unwrap panic
|
||||
pub(crate) fn to_json_string<T: Serialize>(value: &T) -> Result<String, AppError> {
|
||||
|
||||
@@ -120,6 +120,8 @@ impl Database {
|
||||
circuit_failure_threshold INTEGER NOT NULL DEFAULT 4, circuit_success_threshold INTEGER NOT NULL DEFAULT 2,
|
||||
circuit_timeout_seconds INTEGER NOT NULL DEFAULT 60, circuit_error_rate_threshold REAL NOT NULL DEFAULT 0.6,
|
||||
circuit_min_requests INTEGER NOT NULL DEFAULT 10,
|
||||
default_cost_multiplier TEXT NOT NULL DEFAULT '1',
|
||||
pricing_model_source TEXT NOT NULL DEFAULT 'response',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')), updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)", []).map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
@@ -170,6 +172,7 @@ impl Database {
|
||||
// 10. Proxy Request Logs 表
|
||||
conn.execute("CREATE TABLE IF NOT EXISTS proxy_request_logs (
|
||||
request_id TEXT PRIMARY KEY, provider_id TEXT NOT NULL, app_type TEXT NOT NULL, model TEXT NOT NULL,
|
||||
request_model TEXT,
|
||||
input_tokens INTEGER NOT NULL DEFAULT 0, output_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
cache_read_tokens INTEGER NOT NULL DEFAULT 0, cache_creation_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
input_cost_usd TEXT NOT NULL DEFAULT '0', output_cost_usd TEXT NOT NULL DEFAULT '0',
|
||||
@@ -352,6 +355,11 @@ impl Database {
|
||||
Self::migrate_v3_to_v4(conn)?;
|
||||
Self::set_user_version(conn, 4)?;
|
||||
}
|
||||
4 => {
|
||||
log::info!("迁移数据库从 v4 到 v5(计费模式支持)");
|
||||
Self::migrate_v4_to_v5(conn)?;
|
||||
Self::set_user_version(conn, 5)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(AppError::Database(format!(
|
||||
"未知的数据库版本 {version},无法迁移到 {SCHEMA_VERSION}"
|
||||
@@ -521,6 +529,7 @@ impl Database {
|
||||
// proxy_request_logs 表
|
||||
conn.execute("CREATE TABLE IF NOT EXISTS proxy_request_logs (
|
||||
request_id TEXT PRIMARY KEY, provider_id TEXT NOT NULL, app_type TEXT NOT NULL, model TEXT NOT NULL,
|
||||
request_model TEXT,
|
||||
input_tokens INTEGER NOT NULL DEFAULT 0, output_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
cache_read_tokens INTEGER NOT NULL DEFAULT 0, cache_creation_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
input_cost_usd TEXT NOT NULL DEFAULT '0', output_cost_usd TEXT NOT NULL DEFAULT '0',
|
||||
@@ -677,6 +686,8 @@ impl Database {
|
||||
circuit_failure_threshold INTEGER NOT NULL DEFAULT 4, circuit_success_threshold INTEGER NOT NULL DEFAULT 2,
|
||||
circuit_timeout_seconds INTEGER NOT NULL DEFAULT 60, circuit_error_rate_threshold REAL NOT NULL DEFAULT 0.6,
|
||||
circuit_min_requests INTEGER NOT NULL DEFAULT 10,
|
||||
default_cost_multiplier TEXT NOT NULL DEFAULT '1',
|
||||
pricing_model_source TEXT NOT NULL DEFAULT 'response',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')), updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)", [])?;
|
||||
|
||||
@@ -879,6 +890,30 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// v4 -> v5 迁移:新增计费模式配置与请求模型字段
|
||||
fn migrate_v4_to_v5(conn: &Connection) -> Result<(), AppError> {
|
||||
if Self::table_exists(conn, "proxy_config")? {
|
||||
Self::add_column_if_missing(
|
||||
conn,
|
||||
"proxy_config",
|
||||
"default_cost_multiplier",
|
||||
"TEXT NOT NULL DEFAULT '1'",
|
||||
)?;
|
||||
Self::add_column_if_missing(
|
||||
conn,
|
||||
"proxy_config",
|
||||
"pricing_model_source",
|
||||
"TEXT NOT NULL DEFAULT 'response'",
|
||||
)?;
|
||||
}
|
||||
if Self::table_exists(conn, "proxy_request_logs")? {
|
||||
Self::add_column_if_missing(conn, "proxy_request_logs", "request_model", "TEXT")?;
|
||||
}
|
||||
|
||||
log::info!("v4 -> v5 迁移完成:已添加计费模式与请求模型字段");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 插入默认模型定价数据
|
||||
/// 格式: (model_id, display_name, input, output, cache_read, cache_creation)
|
||||
/// 注意: model_id 使用短横线格式(如 claude-haiku-4-5),与 API 返回的模型名称标准化后一致
|
||||
|
||||
@@ -151,7 +151,7 @@ fn normalize_default(default: &Option<String>) -> Option<String> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_sets_user_version_when_missing() {
|
||||
fn schema_migration_sets_user_version_when_missing() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
|
||||
Database::create_tables_on_conn(&conn).expect("create tables");
|
||||
@@ -169,7 +169,7 @@ fn migration_sets_user_version_when_missing() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_rejects_future_version() {
|
||||
fn schema_migration_rejects_future_version() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
Database::create_tables_on_conn(&conn).expect("create tables");
|
||||
Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version");
|
||||
@@ -183,7 +183,7 @@ fn migration_rejects_future_version() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_adds_missing_columns_for_providers() {
|
||||
fn schema_migration_adds_missing_columns_for_providers() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
|
||||
// 创建旧版 providers 表,缺少新增列
|
||||
@@ -224,7 +224,7 @@ fn migration_adds_missing_columns_for_providers() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn migration_aligns_column_defaults_and_types() {
|
||||
fn schema_migration_aligns_column_defaults_and_types() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
conn.execute_batch(LEGACY_SCHEMA_SQL)
|
||||
.expect("seed old schema");
|
||||
@@ -268,7 +268,67 @@ fn migration_aligns_column_defaults_and_types() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_tables_repairs_legacy_proxy_config_singleton_to_per_app() {
|
||||
fn schema_create_tables_include_pricing_model_columns() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
Database::create_tables_on_conn(&conn).expect("create tables");
|
||||
|
||||
let multiplier = get_column_info(&conn, "proxy_config", "default_cost_multiplier");
|
||||
assert_eq!(multiplier.r#type, "TEXT");
|
||||
assert_eq!(multiplier.notnull, 1);
|
||||
assert_eq!(normalize_default(&multiplier.default).as_deref(), Some("1"));
|
||||
|
||||
let pricing_source = get_column_info(&conn, "proxy_config", "pricing_model_source");
|
||||
assert_eq!(pricing_source.r#type, "TEXT");
|
||||
assert_eq!(pricing_source.notnull, 1);
|
||||
assert_eq!(
|
||||
normalize_default(&pricing_source.default).as_deref(),
|
||||
Some("response")
|
||||
);
|
||||
|
||||
let request_model = get_column_info(&conn, "proxy_request_logs", "request_model");
|
||||
assert_eq!(request_model.r#type, "TEXT");
|
||||
assert_eq!(request_model.notnull, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_migration_v4_adds_pricing_model_columns() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE proxy_config (app_type TEXT PRIMARY KEY);
|
||||
CREATE TABLE proxy_request_logs (request_id TEXT PRIMARY KEY, model TEXT NOT NULL);
|
||||
"#,
|
||||
)
|
||||
.expect("seed v4 schema");
|
||||
|
||||
Database::set_user_version(&conn, 4).expect("set user_version=4");
|
||||
Database::apply_schema_migrations_on_conn(&conn).expect("apply migrations");
|
||||
|
||||
let multiplier = get_column_info(&conn, "proxy_config", "default_cost_multiplier");
|
||||
assert_eq!(multiplier.r#type, "TEXT");
|
||||
assert_eq!(multiplier.notnull, 1);
|
||||
assert_eq!(normalize_default(&multiplier.default).as_deref(), Some("1"));
|
||||
|
||||
let pricing_source = get_column_info(&conn, "proxy_config", "pricing_model_source");
|
||||
assert_eq!(pricing_source.r#type, "TEXT");
|
||||
assert_eq!(pricing_source.notnull, 1);
|
||||
assert_eq!(
|
||||
normalize_default(&pricing_source.default).as_deref(),
|
||||
Some("response")
|
||||
);
|
||||
|
||||
let request_model = get_column_info(&conn, "proxy_request_logs", "request_model");
|
||||
assert_eq!(request_model.r#type, "TEXT");
|
||||
assert_eq!(request_model.notnull, 0);
|
||||
|
||||
assert_eq!(
|
||||
Database::get_user_version(&conn).expect("version after migration"),
|
||||
SCHEMA_VERSION
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_create_tables_repairs_legacy_proxy_config_singleton_to_per_app() {
|
||||
let conn = Connection::open_in_memory().expect("open memory db");
|
||||
|
||||
// 模拟测试版 v2:user_version=2,但 proxy_config 仍是单例结构(无 app_type)
|
||||
@@ -433,7 +493,7 @@ fn migration_from_v3_8_schema_v1_to_current_schema_v3() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dry_run_does_not_write_to_disk() {
|
||||
fn schema_dry_run_does_not_write_to_disk() {
|
||||
// Create minimal valid config for migration
|
||||
let mut apps = HashMap::new();
|
||||
apps.insert("claude".to_string(), ProviderManager::default());
|
||||
@@ -507,7 +567,7 @@ fn dry_run_validates_schema_compatibility() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_pricing_is_seeded_on_init() {
|
||||
fn schema_model_pricing_is_seeded_on_init() {
|
||||
let db = Database::memory().expect("create memory db");
|
||||
|
||||
let conn = db.conn.lock().expect("lock conn");
|
||||
|
||||
@@ -895,6 +895,10 @@ pub fn run() {
|
||||
commands::update_global_proxy_config,
|
||||
commands::get_proxy_config_for_app,
|
||||
commands::update_proxy_config_for_app,
|
||||
commands::get_default_cost_multiplier,
|
||||
commands::set_default_cost_multiplier,
|
||||
commands::get_pricing_model_source,
|
||||
commands::set_pricing_model_source,
|
||||
commands::is_proxy_running,
|
||||
commands::is_live_takeover_active,
|
||||
commands::switch_proxy_provider,
|
||||
|
||||
@@ -215,6 +215,9 @@ pub struct ProviderMeta {
|
||||
/// 成本倍数(用于计算实际成本)
|
||||
#[serde(rename = "costMultiplier", skip_serializing_if = "Option::is_none")]
|
||||
pub cost_multiplier: Option<String>,
|
||||
/// 计费模式来源(response/request)
|
||||
#[serde(rename = "pricingModelSource", skip_serializing_if = "Option::is_none")]
|
||||
pub pricing_model_source: Option<String>,
|
||||
/// 每日消费限额(USD)
|
||||
#[serde(rename = "limitDailyUsd", skip_serializing_if = "Option::is_none")]
|
||||
pub limit_daily_usd: Option<String>,
|
||||
@@ -614,3 +617,267 @@ pub struct OpenCodeModelLimit {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output: Option<u64>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
ClaudeModelConfig, CodexModelConfig, GeminiModelConfig, OpenCodeProviderConfig, Provider,
|
||||
ProviderManager, ProviderMeta, UniversalProvider,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn provider_meta_serializes_pricing_model_source() {
|
||||
let mut meta = ProviderMeta::default();
|
||||
meta.pricing_model_source = Some("response".to_string());
|
||||
|
||||
let value = serde_json::to_value(&meta).expect("serialize ProviderMeta");
|
||||
|
||||
assert_eq!(
|
||||
value
|
||||
.get("pricingModelSource")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("response")
|
||||
);
|
||||
assert!(value.get("pricing_model_source").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_meta_omits_pricing_model_source_when_none() {
|
||||
let meta = ProviderMeta::default();
|
||||
let value = serde_json::to_value(&meta).expect("serialize ProviderMeta");
|
||||
|
||||
assert!(value.get("pricingModelSource").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_with_id_populates_defaults() {
|
||||
let settings_config = json!({
|
||||
"env": { "API_KEY": "test" }
|
||||
});
|
||||
let provider = Provider::with_id(
|
||||
"provider-1".to_string(),
|
||||
"Provider".to_string(),
|
||||
settings_config.clone(),
|
||||
Some("https://example.com".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(provider.id, "provider-1");
|
||||
assert_eq!(provider.name, "Provider");
|
||||
assert_eq!(provider.settings_config, settings_config);
|
||||
assert_eq!(provider.website_url.as_deref(), Some("https://example.com"));
|
||||
assert!(provider.category.is_none());
|
||||
assert!(provider.created_at.is_none());
|
||||
assert!(provider.sort_index.is_none());
|
||||
assert!(provider.notes.is_none());
|
||||
assert!(provider.meta.is_none());
|
||||
assert!(provider.icon.is_none());
|
||||
assert!(provider.icon_color.is_none());
|
||||
assert!(!provider.in_failover_queue);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_manager_get_all_providers_returns_map() {
|
||||
let mut manager = ProviderManager::default();
|
||||
let provider = Provider::with_id(
|
||||
"provider-1".to_string(),
|
||||
"Provider".to_string(),
|
||||
json!({ "env": {} }),
|
||||
None,
|
||||
);
|
||||
manager.providers.insert("provider-1".to_string(), provider);
|
||||
|
||||
assert_eq!(manager.get_all_providers().len(), 1);
|
||||
assert!(manager.get_all_providers().contains_key("provider-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_claude_provider_uses_models() {
|
||||
let mut universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
universal.apps.claude = true;
|
||||
universal.models.claude = Some(ClaudeModelConfig {
|
||||
model: Some("claude-main".to_string()),
|
||||
haiku_model: Some("claude-haiku".to_string()),
|
||||
sonnet_model: Some("claude-sonnet".to_string()),
|
||||
opus_model: Some("claude-opus".to_string()),
|
||||
});
|
||||
|
||||
let provider = universal.to_claude_provider().expect("claude provider");
|
||||
|
||||
assert_eq!(provider.id, "universal-claude-u1");
|
||||
assert_eq!(provider.name, "Universal");
|
||||
assert_eq!(provider.category.as_deref(), Some("aggregator"));
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("claude-main")
|
||||
);
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_DEFAULT_HAIKU_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("claude-haiku")
|
||||
);
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_DEFAULT_SONNET_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("claude-sonnet")
|
||||
);
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_DEFAULT_OPUS_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("claude-opus")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_claude_provider_disabled_returns_none() {
|
||||
let universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
|
||||
assert!(universal.to_claude_provider().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_codex_provider_appends_v1() {
|
||||
let mut universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
universal.apps.codex = true;
|
||||
universal.models.codex = Some(CodexModelConfig {
|
||||
model: Some("gpt-4o-mini".to_string()),
|
||||
reasoning_effort: Some("low".to_string()),
|
||||
});
|
||||
|
||||
let provider = universal.to_codex_provider().expect("codex provider");
|
||||
let config = provider
|
||||
.settings_config
|
||||
.get("config")
|
||||
.and_then(|item| item.as_str())
|
||||
.expect("config toml");
|
||||
|
||||
assert!(config.contains("base_url = \"https://api.example.com/v1\""));
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/auth/OPENAI_API_KEY")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("api-key")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_codex_provider_keeps_v1_suffix() {
|
||||
let mut universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com/v1".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
universal.apps.codex = true;
|
||||
|
||||
let provider = universal.to_codex_provider().expect("codex provider");
|
||||
let config = provider
|
||||
.settings_config
|
||||
.get("config")
|
||||
.and_then(|item| item.as_str())
|
||||
.expect("config toml");
|
||||
|
||||
assert!(config.contains("base_url = \"https://api.example.com/v1\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_codex_provider_disabled_returns_none() {
|
||||
let universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
|
||||
assert!(universal.to_codex_provider().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_gemini_provider_defaults_model() {
|
||||
let mut universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
universal.apps.gemini = true;
|
||||
|
||||
let provider = universal.to_gemini_provider().expect("gemini provider");
|
||||
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/GEMINI_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("gemini-2.5-pro")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn universal_provider_to_gemini_provider_uses_model() {
|
||||
let mut universal = UniversalProvider::new(
|
||||
"u1".to_string(),
|
||||
"Universal".to_string(),
|
||||
"newapi".to_string(),
|
||||
"https://api.example.com".to_string(),
|
||||
"api-key".to_string(),
|
||||
);
|
||||
universal.apps.gemini = true;
|
||||
universal.models.gemini = Some(GeminiModelConfig {
|
||||
model: Some("gemini-custom".to_string()),
|
||||
});
|
||||
|
||||
let provider = universal.to_gemini_provider().expect("gemini provider");
|
||||
|
||||
assert_eq!(
|
||||
provider
|
||||
.settings_config
|
||||
.pointer("/env/GEMINI_MODEL")
|
||||
.and_then(|item| item.as_str()),
|
||||
Some("gemini-custom")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opencode_provider_config_defaults() {
|
||||
let config = OpenCodeProviderConfig::default();
|
||||
assert_eq!(config.npm, "@ai-sdk/openai-compatible");
|
||||
assert!(config.name.is_none());
|
||||
assert!(config.models.is_empty());
|
||||
assert!(config.options.base_url.is_none());
|
||||
assert!(config.options.api_key.is_none());
|
||||
assert!(config.options.headers.is_none());
|
||||
assert!(config.options.extra.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,9 +22,7 @@ use super::{
|
||||
};
|
||||
use crate::app_config::AppType;
|
||||
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};
|
||||
use rust_decimal::Decimal;
|
||||
use serde_json::{json, Value};
|
||||
use std::str::FromStr;
|
||||
|
||||
// ============================================================================
|
||||
// 健康检查和状态查询(简单端点)
|
||||
@@ -145,6 +143,7 @@ async fn handle_claude_transform(
|
||||
&provider_id,
|
||||
"claude",
|
||||
&model,
|
||||
&model,
|
||||
usage,
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -215,6 +214,7 @@ async fn handle_claude_transform(
|
||||
.unwrap_or("unknown");
|
||||
let latency_ms = ctx.latency_ms();
|
||||
|
||||
let request_model = ctx.request_model.clone();
|
||||
tokio::spawn({
|
||||
let state = state.clone();
|
||||
let provider_id = ctx.provider.id.clone();
|
||||
@@ -225,6 +225,7 @@ async fn handle_claude_transform(
|
||||
&provider_id,
|
||||
"claude",
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
None,
|
||||
@@ -441,6 +442,7 @@ async fn log_usage(
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: u64,
|
||||
first_token_ms: Option<u64>,
|
||||
@@ -451,25 +453,12 @@ async fn log_usage(
|
||||
|
||||
let logger = UsageLogger::new(&state.db);
|
||||
|
||||
// 获取 provider 的 cost_multiplier
|
||||
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
}
|
||||
_ => Decimal::from(1),
|
||||
let (multiplier, pricing_model_source) =
|
||||
logger.resolve_pricing_config(provider_id, app_type).await;
|
||||
let pricing_model = if pricing_model_source == "request" {
|
||||
request_model
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
@@ -479,6 +468,8 @@ async fn log_usage(
|
||||
provider_id.to_string(),
|
||||
app_type.to_string(),
|
||||
model.to_string(),
|
||||
request_model.to_string(),
|
||||
pricing_model.to_string(),
|
||||
usage,
|
||||
multiplier,
|
||||
latency_ms,
|
||||
|
||||
@@ -13,10 +13,8 @@ use axum::response::{IntoResponse, Response};
|
||||
use bytes::Bytes;
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use reqwest::header::HeaderMap;
|
||||
use rust_decimal::Decimal;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
str::FromStr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@@ -128,7 +126,15 @@ pub async fn handle_non_streaming(
|
||||
ctx.request_model.clone()
|
||||
};
|
||||
|
||||
spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false);
|
||||
spawn_log_usage(
|
||||
state,
|
||||
ctx,
|
||||
usage,
|
||||
&model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
} else {
|
||||
let model = json_value
|
||||
.get("model")
|
||||
@@ -140,6 +146,7 @@ pub async fn handle_non_streaming(
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
@@ -159,6 +166,7 @@ pub async fn handle_non_streaming(
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&ctx.request_model,
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
@@ -293,6 +301,7 @@ fn create_usage_collector(
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
let request_model = request_model.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -300,6 +309,7 @@ fn create_usage_collector(
|
||||
&provider_id,
|
||||
app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -315,6 +325,7 @@ fn create_usage_collector(
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
let request_model = request_model.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -322,6 +333,7 @@ fn create_usage_collector(
|
||||
&provider_id,
|
||||
app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
TokenUsage::default(),
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
@@ -342,6 +354,7 @@ fn spawn_log_usage(
|
||||
ctx: &RequestContext,
|
||||
usage: TokenUsage,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
status_code: u16,
|
||||
is_streaming: bool,
|
||||
) {
|
||||
@@ -349,6 +362,7 @@ fn spawn_log_usage(
|
||||
let provider_id = ctx.provider.id.clone();
|
||||
let app_type_str = ctx.app_type_str.to_string();
|
||||
let model = model.to_string();
|
||||
let request_model = request_model.to_string();
|
||||
let latency_ms = ctx.latency_ms();
|
||||
let session_id = ctx.session_id.clone();
|
||||
|
||||
@@ -358,6 +372,7 @@ fn spawn_log_usage(
|
||||
&provider_id,
|
||||
&app_type_str,
|
||||
&model,
|
||||
&request_model,
|
||||
usage,
|
||||
latency_ms,
|
||||
None,
|
||||
@@ -376,6 +391,7 @@ async fn log_usage_internal(
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
model: &str,
|
||||
request_model: &str,
|
||||
usage: TokenUsage,
|
||||
latency_ms: u64,
|
||||
first_token_ms: Option<u64>,
|
||||
@@ -386,26 +402,12 @@ async fn log_usage_internal(
|
||||
use super::usage::logger::UsageLogger;
|
||||
|
||||
let logger = UsageLogger::new(&state.db);
|
||||
|
||||
// 获取 provider 的 cost_multiplier
|
||||
let multiplier = match state.db.get_provider_by_id(provider_id, app_type) {
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
}
|
||||
_ => Decimal::from(1),
|
||||
let (multiplier, pricing_model_source) =
|
||||
logger.resolve_pricing_config(provider_id, app_type).await;
|
||||
let pricing_model = if pricing_model_source == "request" {
|
||||
request_model
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
@@ -424,6 +426,8 @@ async fn log_usage_internal(
|
||||
provider_id.to_string(),
|
||||
app_type.to_string(),
|
||||
model.to_string(),
|
||||
request_model.to_string(),
|
||||
pricing_model.to_string(),
|
||||
usage,
|
||||
multiplier,
|
||||
latency_ms,
|
||||
@@ -556,3 +560,185 @@ fn format_headers(headers: &HeaderMap) -> String {
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::provider::ProviderMeta;
|
||||
use crate::proxy::failover_switch::FailoverSwitchManager;
|
||||
use crate::proxy::provider_router::ProviderRouter;
|
||||
use crate::proxy::types::{ProxyConfig, ProxyStatus};
|
||||
use rust_decimal::Decimal;
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
fn build_state(db: Arc<Database>) -> ProxyState {
|
||||
ProxyState {
|
||||
db: db.clone(),
|
||||
config: Arc::new(RwLock::new(ProxyConfig::default())),
|
||||
status: Arc::new(RwLock::new(ProxyStatus::default())),
|
||||
start_time: Arc::new(RwLock::new(None)),
|
||||
current_providers: Arc::new(RwLock::new(HashMap::new())),
|
||||
provider_router: Arc::new(ProviderRouter::new(db.clone())),
|
||||
app_handle: None,
|
||||
failover_manager: Arc::new(FailoverSwitchManager::new(db)),
|
||||
}
|
||||
}
|
||||
|
||||
fn seed_pricing(db: &Database) -> Result<(), AppError> {
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
rusqlite::params!["resp-model", "Resp Model", "1.0", "0"],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO model_pricing (model_id, display_name, input_cost_per_million, output_cost_per_million)
|
||||
VALUES (?1, ?2, ?3, ?4)",
|
||||
rusqlite::params!["req-model", "Req Model", "2.0", "0"],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_provider(
|
||||
db: &Database,
|
||||
id: &str,
|
||||
app_type: &str,
|
||||
meta: ProviderMeta,
|
||||
) -> Result<(), AppError> {
|
||||
let meta_json =
|
||||
serde_json::to_string(&meta).map_err(|e| AppError::Database(e.to_string()))?;
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
conn.execute(
|
||||
"INSERT INTO providers (id, app_type, name, settings_config, meta)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||
rusqlite::params![id, app_type, "Test Provider", "{}", meta_json],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_log_usage_uses_provider_override_config() -> Result<(), AppError> {
|
||||
let db = Arc::new(Database::memory()?);
|
||||
let app_type = "claude";
|
||||
|
||||
db.set_default_cost_multiplier(app_type, "1.5").await?;
|
||||
db.set_pricing_model_source(app_type, "response").await?;
|
||||
seed_pricing(&db)?;
|
||||
|
||||
let mut meta = ProviderMeta::default();
|
||||
meta.cost_multiplier = Some("2".to_string());
|
||||
meta.pricing_model_source = Some("request".to_string());
|
||||
insert_provider(&db, "provider-1", app_type, meta)?;
|
||||
|
||||
let state = build_state(db.clone());
|
||||
let usage = TokenUsage {
|
||||
input_tokens: 1_000_000,
|
||||
output_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
cache_creation_tokens: 0,
|
||||
model: None,
|
||||
};
|
||||
|
||||
log_usage_internal(
|
||||
&state,
|
||||
"provider-1",
|
||||
app_type,
|
||||
"resp-model",
|
||||
"req-model",
|
||||
usage,
|
||||
10,
|
||||
None,
|
||||
false,
|
||||
200,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let (model, request_model, total_cost, cost_multiplier): (String, String, String, String) =
|
||||
conn.query_row(
|
||||
"SELECT model, request_model, total_cost_usd, cost_multiplier
|
||||
FROM proxy_request_logs WHERE provider_id = ?1",
|
||||
["provider-1"],
|
||||
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?)),
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
assert_eq!(model, "resp-model");
|
||||
assert_eq!(request_model, "req-model");
|
||||
assert_eq!(
|
||||
Decimal::from_str(&cost_multiplier).unwrap(),
|
||||
Decimal::from_str("2").unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Decimal::from_str(&total_cost).unwrap(),
|
||||
Decimal::from_str("4").unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_log_usage_falls_back_to_global_defaults() -> Result<(), AppError> {
|
||||
let db = Arc::new(Database::memory()?);
|
||||
let app_type = "claude";
|
||||
|
||||
db.set_default_cost_multiplier(app_type, "1.5").await?;
|
||||
db.set_pricing_model_source(app_type, "response").await?;
|
||||
seed_pricing(&db)?;
|
||||
|
||||
let meta = ProviderMeta::default();
|
||||
insert_provider(&db, "provider-2", app_type, meta)?;
|
||||
|
||||
let state = build_state(db.clone());
|
||||
let usage = TokenUsage {
|
||||
input_tokens: 1_000_000,
|
||||
output_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
cache_creation_tokens: 0,
|
||||
model: None,
|
||||
};
|
||||
|
||||
log_usage_internal(
|
||||
&state,
|
||||
"provider-2",
|
||||
app_type,
|
||||
"resp-model",
|
||||
"req-model",
|
||||
usage,
|
||||
10,
|
||||
None,
|
||||
false,
|
||||
200,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let (total_cost, cost_multiplier): (String, String) = conn
|
||||
.query_row(
|
||||
"SELECT total_cost_usd, cost_multiplier
|
||||
FROM proxy_request_logs WHERE provider_id = ?1",
|
||||
["provider-2"],
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
assert_eq!(
|
||||
Decimal::from_str(&cost_multiplier).unwrap(),
|
||||
Decimal::from_str("1.5").unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
Decimal::from_str(&total_cost).unwrap(),
|
||||
Decimal::from_str("1.5").unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ impl CostCalculator {
|
||||
/// - input_cost: (input_tokens - cache_read_tokens) × 输入价格
|
||||
/// - cache_read_cost: cache_read_tokens × 缓存读取价格
|
||||
/// - 这样避免缓存部分被重复计费
|
||||
/// - total_cost: 各项成本之和 × 倍率(倍率只作用于最终总价)
|
||||
pub fn calculate(
|
||||
usage: &TokenUsage,
|
||||
pricing: &ModelPricing,
|
||||
@@ -50,21 +51,20 @@ impl CostCalculator {
|
||||
// 计算实际需要按输入价格计费的 token 数(减去缓存命中部分)
|
||||
let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens);
|
||||
|
||||
let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
// 各项基础成本(不含倍率)
|
||||
let input_cost =
|
||||
Decimal::from(billable_input_tokens) * pricing.input_cost_per_million / million;
|
||||
let output_cost =
|
||||
Decimal::from(usage.output_tokens) * pricing.output_cost_per_million / million;
|
||||
let cache_read_cost =
|
||||
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million
|
||||
* cost_multiplier;
|
||||
Decimal::from(usage.cache_read_tokens) * pricing.cache_read_cost_per_million / million;
|
||||
let cache_creation_cost = Decimal::from(usage.cache_creation_tokens)
|
||||
* pricing.cache_creation_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
/ million;
|
||||
|
||||
let total_cost = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
// 总成本 = 各项基础成本之和 × 倍率
|
||||
let base_total = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
let total_cost = base_total * cost_multiplier;
|
||||
|
||||
CostBreakdown {
|
||||
input_cost,
|
||||
@@ -151,8 +151,9 @@ mod tests {
|
||||
|
||||
let cost = CostCalculator::calculate(&usage, &pricing, multiplier);
|
||||
|
||||
// input: 1000 * 3.0 / 1M * 1.5 = 0.0045
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.0045").unwrap());
|
||||
// input_cost: 基础价格(不含倍率)= 1000 * 3.0 / 1M = 0.003
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap());
|
||||
// total_cost: 基础价格 × 倍率 = 0.003 * 1.5 = 0.0045
|
||||
assert_eq!(cost.total_cost, Decimal::from_str("0.0045").unwrap());
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::database::Database;
|
||||
use crate::error::AppError;
|
||||
use crate::services::usage_stats::find_model_pricing_row;
|
||||
use rust_decimal::Decimal;
|
||||
use std::time::SystemTime;
|
||||
use std::{str::FromStr, time::SystemTime};
|
||||
|
||||
/// 请求日志
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -15,6 +15,7 @@ pub struct RequestLog {
|
||||
pub provider_id: String,
|
||||
pub app_type: String,
|
||||
pub model: String,
|
||||
pub request_model: String,
|
||||
pub usage: TokenUsage,
|
||||
pub cost: Option<CostBreakdown>,
|
||||
pub latency_ms: u64,
|
||||
@@ -73,17 +74,18 @@ impl<'a> UsageLogger<'a> {
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO proxy_request_logs (
|
||||
request_id, provider_id, app_type, model,
|
||||
request_id, provider_id, app_type, model, request_model,
|
||||
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
|
||||
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
|
||||
latency_ms, first_token_ms, status_code, error_message, session_id,
|
||||
provider_type, is_streaming, cost_multiplier, created_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22)",
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23)",
|
||||
rusqlite::params![
|
||||
log.request_id,
|
||||
log.provider_id,
|
||||
log.app_type,
|
||||
log.model,
|
||||
log.request_model,
|
||||
log.usage.input_tokens,
|
||||
log.usage.output_tokens,
|
||||
log.usage.cache_read_tokens,
|
||||
@@ -123,11 +125,13 @@ impl<'a> UsageLogger<'a> {
|
||||
error_message: String,
|
||||
latency_ms: u64,
|
||||
) -> Result<(), AppError> {
|
||||
let request_model = model.clone();
|
||||
let log = RequestLog {
|
||||
request_id,
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage: TokenUsage::default(),
|
||||
cost: None,
|
||||
latency_ms,
|
||||
@@ -160,11 +164,13 @@ impl<'a> UsageLogger<'a> {
|
||||
session_id: Option<String>,
|
||||
provider_type: Option<String>,
|
||||
) -> Result<(), AppError> {
|
||||
let request_model = model.clone();
|
||||
let log = RequestLog {
|
||||
request_id,
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage: TokenUsage::default(),
|
||||
cost: None,
|
||||
latency_ms,
|
||||
@@ -194,6 +200,88 @@ impl<'a> UsageLogger<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取有效的倍率与计费模式来源(供应商优先,未配置则回退全局默认)
|
||||
pub async fn resolve_pricing_config(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
app_type: &str,
|
||||
) -> (Decimal, String) {
|
||||
let default_multiplier_raw = match self.db.get_default_cost_multiplier(app_type).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!("[USG-003] 获取默认倍率失败 (app_type={app_type}): {e}");
|
||||
"1".to_string()
|
||||
}
|
||||
};
|
||||
let default_multiplier = match Decimal::from_str(&default_multiplier_raw) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 默认倍率解析失败 (app_type={app_type}): {default_multiplier_raw} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
}
|
||||
};
|
||||
|
||||
let default_pricing_source_raw = match self.db.get_pricing_model_source(app_type).await {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
log::warn!("[USG-003] 获取默认计费模式失败 (app_type={app_type}): {e}");
|
||||
"response".to_string()
|
||||
}
|
||||
};
|
||||
let default_pricing_source =
|
||||
if matches!(default_pricing_source_raw.as_str(), "response" | "request") {
|
||||
default_pricing_source_raw
|
||||
} else {
|
||||
log::warn!(
|
||||
"[USG-003] 默认计费模式无效 (app_type={app_type}): {default_pricing_source_raw}"
|
||||
);
|
||||
"response".to_string()
|
||||
};
|
||||
|
||||
let provider = self
|
||||
.db
|
||||
.get_provider_by_id(provider_id, app_type)
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let (provider_multiplier, provider_pricing_source) = provider
|
||||
.as_ref()
|
||||
.and_then(|p| p.meta.as_ref())
|
||||
.map(|meta| {
|
||||
(
|
||||
meta.cost_multiplier.as_deref(),
|
||||
meta.pricing_model_source.as_deref(),
|
||||
)
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
|
||||
let cost_multiplier = match provider_multiplier {
|
||||
Some(value) => match Decimal::from_str(value) {
|
||||
Ok(parsed) => parsed,
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"[USG-003] 供应商倍率解析失败 (provider_id={provider_id}): {value} - {e}"
|
||||
);
|
||||
default_multiplier
|
||||
}
|
||||
},
|
||||
None => default_multiplier,
|
||||
};
|
||||
|
||||
let pricing_model_source = match provider_pricing_source {
|
||||
Some(value) if matches!(value, "response" | "request") => value.to_string(),
|
||||
Some(value) => {
|
||||
log::warn!("[USG-003] 供应商计费模式无效 (provider_id={provider_id}): {value}");
|
||||
default_pricing_source.clone()
|
||||
}
|
||||
None => default_pricing_source.clone(),
|
||||
};
|
||||
|
||||
(cost_multiplier, pricing_model_source)
|
||||
}
|
||||
|
||||
/// 计算并记录请求
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn log_with_calculation(
|
||||
@@ -202,6 +290,8 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_id: String,
|
||||
app_type: String,
|
||||
model: String,
|
||||
request_model: String,
|
||||
pricing_model: String,
|
||||
usage: TokenUsage,
|
||||
cost_multiplier: Decimal,
|
||||
latency_ms: u64,
|
||||
@@ -211,10 +301,10 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_type: Option<String>,
|
||||
is_streaming: bool,
|
||||
) -> Result<(), AppError> {
|
||||
let pricing = self.get_model_pricing(&model)?;
|
||||
let pricing = self.get_model_pricing(&pricing_model)?;
|
||||
|
||||
if pricing.is_none() {
|
||||
log::warn!("[USG-002] 模型定价未找到,成本将记录为 0");
|
||||
log::warn!("[USG-002] 模型定价未找到,成本将记录为 0: {pricing_model}");
|
||||
}
|
||||
|
||||
let cost = CostCalculator::try_calculate(&usage, pricing.as_ref(), cost_multiplier);
|
||||
@@ -224,6 +314,7 @@ impl<'a> UsageLogger<'a> {
|
||||
provider_id,
|
||||
app_type,
|
||||
model,
|
||||
request_model,
|
||||
usage,
|
||||
cost,
|
||||
latency_ms,
|
||||
@@ -274,6 +365,8 @@ mod tests {
|
||||
"provider-1".to_string(),
|
||||
"claude".to_string(),
|
||||
"test-model".to_string(),
|
||||
"req-model".to_string(),
|
||||
"test-model".to_string(),
|
||||
usage,
|
||||
Decimal::from(1),
|
||||
100,
|
||||
@@ -286,14 +379,15 @@ mod tests {
|
||||
|
||||
// 验证记录已插入
|
||||
let conn = crate::database::lock_conn!(db.conn);
|
||||
let count: i64 = conn
|
||||
let (count, request_model): (i64, String) = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM proxy_request_logs WHERE request_id = 'req-123'",
|
||||
"SELECT COUNT(*), request_model FROM proxy_request_logs WHERE request_id = 'req-123'",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
|row| Ok((row.get(0)?, row.get(1)?)),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
assert_eq!(request_model, "req-model");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -94,6 +94,9 @@ pub struct RequestLogDetail {
|
||||
pub provider_name: Option<String>,
|
||||
pub app_type: String,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request_model: Option<String>,
|
||||
pub cost_multiplier: String,
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
@@ -140,7 +143,7 @@ impl Database {
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
"SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
@@ -218,7 +221,7 @@ impl Database {
|
||||
}
|
||||
|
||||
let sql = "
|
||||
SELECT
|
||||
SELECT
|
||||
CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
@@ -295,7 +298,7 @@ impl Database {
|
||||
pub fn get_provider_stats(&self) -> Result<Vec<ProviderStats>, AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
let sql = "SELECT
|
||||
let sql = "SELECT
|
||||
l.provider_id,
|
||||
p.name as provider_name,
|
||||
COUNT(*) as request_count,
|
||||
@@ -343,7 +346,7 @@ impl Database {
|
||||
pub fn get_model_stats(&self) -> Result<Vec<ModelStats>, AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
let sql = "SELECT
|
||||
let sql = "SELECT
|
||||
model,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
|
||||
@@ -424,7 +427,7 @@ impl Database {
|
||||
|
||||
// 获取总数
|
||||
let count_sql = format!(
|
||||
"SELECT COUNT(*) FROM proxy_request_logs l
|
||||
"SELECT COUNT(*) FROM proxy_request_logs l
|
||||
LEFT JOIN providers p ON l.provider_id = p.id AND l.app_type = p.app_type
|
||||
{where_clause}"
|
||||
);
|
||||
@@ -440,6 +443,7 @@ impl Database {
|
||||
|
||||
let sql = format!(
|
||||
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
|
||||
l.request_model, l.cost_multiplier,
|
||||
l.input_tokens, l.output_tokens, l.cache_read_tokens, l.cache_creation_tokens,
|
||||
l.input_cost_usd, l.output_cost_usd, l.cache_read_cost_usd, l.cache_creation_cost_usd, l.total_cost_usd,
|
||||
l.is_streaming, l.latency_ms, l.first_token_ms, l.duration_ms,
|
||||
@@ -460,22 +464,26 @@ impl Database {
|
||||
provider_name: row.get(2)?,
|
||||
app_type: row.get(3)?,
|
||||
model: row.get(4)?,
|
||||
input_tokens: row.get::<_, i64>(5)? as u32,
|
||||
output_tokens: row.get::<_, i64>(6)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(7)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
|
||||
input_cost_usd: row.get(9)?,
|
||||
output_cost_usd: row.get(10)?,
|
||||
cache_read_cost_usd: row.get(11)?,
|
||||
cache_creation_cost_usd: row.get(12)?,
|
||||
total_cost_usd: row.get(13)?,
|
||||
is_streaming: row.get::<_, i64>(14)? != 0,
|
||||
latency_ms: row.get::<_, i64>(15)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(18)? as u16,
|
||||
error_message: row.get(19)?,
|
||||
created_at: row.get(20)?,
|
||||
request_model: row.get(5)?,
|
||||
cost_multiplier: row
|
||||
.get::<_, Option<String>>(6)?
|
||||
.unwrap_or_else(|| "1".to_string()),
|
||||
input_tokens: row.get::<_, i64>(7)? as u32,
|
||||
output_tokens: row.get::<_, i64>(8)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(9)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
|
||||
input_cost_usd: row.get(11)?,
|
||||
output_cost_usd: row.get(12)?,
|
||||
cache_read_cost_usd: row.get(13)?,
|
||||
cache_creation_cost_usd: row.get(14)?,
|
||||
total_cost_usd: row.get(15)?,
|
||||
is_streaming: row.get::<_, i64>(16)? != 0,
|
||||
latency_ms: row.get::<_, i64>(17)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(20)? as u16,
|
||||
error_message: row.get(21)?,
|
||||
created_at: row.get(22)?,
|
||||
})
|
||||
})?;
|
||||
|
||||
@@ -511,6 +519,7 @@ impl Database {
|
||||
|
||||
let result = conn.query_row(
|
||||
"SELECT l.request_id, l.provider_id, p.name as provider_name, l.app_type, l.model,
|
||||
l.request_model, l.cost_multiplier,
|
||||
input_tokens, output_tokens, cache_read_tokens, cache_creation_tokens,
|
||||
input_cost_usd, output_cost_usd, cache_read_cost_usd, cache_creation_cost_usd, total_cost_usd,
|
||||
is_streaming, latency_ms, first_token_ms, duration_ms,
|
||||
@@ -526,22 +535,24 @@ impl Database {
|
||||
provider_name: row.get(2)?,
|
||||
app_type: row.get(3)?,
|
||||
model: row.get(4)?,
|
||||
input_tokens: row.get::<_, i64>(5)? as u32,
|
||||
output_tokens: row.get::<_, i64>(6)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(7)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(8)? as u32,
|
||||
input_cost_usd: row.get(9)?,
|
||||
output_cost_usd: row.get(10)?,
|
||||
cache_read_cost_usd: row.get(11)?,
|
||||
cache_creation_cost_usd: row.get(12)?,
|
||||
total_cost_usd: row.get(13)?,
|
||||
is_streaming: row.get::<_, i64>(14)? != 0,
|
||||
latency_ms: row.get::<_, i64>(15)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(16)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(17)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(18)? as u16,
|
||||
error_message: row.get(19)?,
|
||||
created_at: row.get(20)?,
|
||||
request_model: row.get(5)?,
|
||||
cost_multiplier: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "1".to_string()),
|
||||
input_tokens: row.get::<_, i64>(7)? as u32,
|
||||
output_tokens: row.get::<_, i64>(8)? as u32,
|
||||
cache_read_tokens: row.get::<_, i64>(9)? as u32,
|
||||
cache_creation_tokens: row.get::<_, i64>(10)? as u32,
|
||||
input_cost_usd: row.get(11)?,
|
||||
output_cost_usd: row.get(12)?,
|
||||
cache_read_cost_usd: row.get(13)?,
|
||||
cache_creation_cost_usd: row.get(14)?,
|
||||
total_cost_usd: row.get(15)?,
|
||||
is_streaming: row.get::<_, i64>(16)? != 0,
|
||||
latency_ms: row.get::<_, i64>(17)? as u64,
|
||||
first_token_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
|
||||
duration_ms: row.get::<_, Option<i64>>(19)?.map(|v| v as u64),
|
||||
status_code: row.get::<_, i64>(20)? as u16,
|
||||
error_message: row.get(21)?,
|
||||
created_at: row.get(22)?,
|
||||
})
|
||||
},
|
||||
);
|
||||
@@ -691,21 +702,26 @@ impl Database {
|
||||
)?;
|
||||
|
||||
let million = rust_decimal::Decimal::from(1_000_000u64);
|
||||
let input_cost = rust_decimal::Decimal::from(log.input_tokens as u64) * pricing.input
|
||||
/ million
|
||||
* multiplier;
|
||||
let output_cost = rust_decimal::Decimal::from(log.output_tokens as u64) * pricing.output
|
||||
/ million
|
||||
* multiplier;
|
||||
|
||||
// 与 CostCalculator::calculate 保持一致的计算逻辑:
|
||||
// 1. input_cost 需要扣除 cache_read_tokens(避免缓存部分被重复计费)
|
||||
// 2. 各项成本是基础成本(不含倍率)
|
||||
// 3. 倍率只作用于最终总价
|
||||
let billable_input_tokens =
|
||||
(log.input_tokens as u64).saturating_sub(log.cache_read_tokens as u64);
|
||||
let input_cost =
|
||||
rust_decimal::Decimal::from(billable_input_tokens) * pricing.input / million;
|
||||
let output_cost =
|
||||
rust_decimal::Decimal::from(log.output_tokens as u64) * pricing.output / million;
|
||||
let cache_read_cost = rust_decimal::Decimal::from(log.cache_read_tokens as u64)
|
||||
* pricing.cache_read
|
||||
/ million
|
||||
* multiplier;
|
||||
/ million;
|
||||
let cache_creation_cost = rust_decimal::Decimal::from(log.cache_creation_tokens as u64)
|
||||
* pricing.cache_creation
|
||||
/ million
|
||||
* multiplier;
|
||||
let total_cost = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
/ million;
|
||||
// 总成本 = 基础成本之和 × 倍率
|
||||
let base_total = input_cost + output_cost + cache_read_cost + cache_creation_cost;
|
||||
let total_cost = base_total * multiplier;
|
||||
|
||||
log.input_cost_usd = format!("{input_cost:.6}");
|
||||
log.output_cost_usd = format!("{output_cost:.6}");
|
||||
|
||||
78
src-tauri/tests/proxy_commands.rs
Normal file
78
src-tauri/tests/proxy_commands.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use cc_switch_lib::{
|
||||
get_default_cost_multiplier_test_hook, get_pricing_model_source_test_hook,
|
||||
set_default_cost_multiplier_test_hook, set_pricing_model_source_test_hook, AppError,
|
||||
};
|
||||
|
||||
#[path = "support.rs"]
|
||||
mod support;
|
||||
use support::{create_test_state, ensure_test_home, reset_test_fs, test_mutex};
|
||||
|
||||
// 测试使用 Mutex 进行串行化,跨 await 持锁是预期行为
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
#[tokio::test]
|
||||
async fn default_cost_multiplier_commands_round_trip() {
|
||||
let _guard = test_mutex().lock().expect("acquire test mutex");
|
||||
reset_test_fs();
|
||||
let _home = ensure_test_home();
|
||||
|
||||
let state = create_test_state().expect("create test state");
|
||||
|
||||
let default = get_default_cost_multiplier_test_hook(&state, "claude")
|
||||
.await
|
||||
.expect("read default multiplier");
|
||||
assert_eq!(default, "1");
|
||||
|
||||
set_default_cost_multiplier_test_hook(&state, "claude", "1.5")
|
||||
.await
|
||||
.expect("set multiplier");
|
||||
let updated = get_default_cost_multiplier_test_hook(&state, "claude")
|
||||
.await
|
||||
.expect("read updated multiplier");
|
||||
assert_eq!(updated, "1.5");
|
||||
|
||||
let err = set_default_cost_multiplier_test_hook(&state, "claude", "not-a-number")
|
||||
.await
|
||||
.expect_err("invalid multiplier should error");
|
||||
// 错误已改为 Localized 类型(支持 i18n)
|
||||
match err {
|
||||
AppError::Localized { key, .. } => {
|
||||
assert_eq!(key, "error.invalidMultiplier");
|
||||
}
|
||||
other => panic!("expected localized error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// 测试使用 Mutex 进行串行化,跨 await 持锁是预期行为
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
#[tokio::test]
|
||||
async fn pricing_model_source_commands_round_trip() {
|
||||
let _guard = test_mutex().lock().expect("acquire test mutex");
|
||||
reset_test_fs();
|
||||
let _home = ensure_test_home();
|
||||
|
||||
let state = create_test_state().expect("create test state");
|
||||
|
||||
let default = get_pricing_model_source_test_hook(&state, "claude")
|
||||
.await
|
||||
.expect("read default pricing model source");
|
||||
assert_eq!(default, "response");
|
||||
|
||||
set_pricing_model_source_test_hook(&state, "claude", "request")
|
||||
.await
|
||||
.expect("set pricing model source");
|
||||
let updated = get_pricing_model_source_test_hook(&state, "claude")
|
||||
.await
|
||||
.expect("read updated pricing model source");
|
||||
assert_eq!(updated, "request");
|
||||
|
||||
let err = set_pricing_model_source_test_hook(&state, "claude", "invalid")
|
||||
.await
|
||||
.expect_err("invalid pricing model source should error");
|
||||
// 错误已改为 Localized 类型(支持 i18n)
|
||||
match err {
|
||||
AppError::Localized { key, .. } => {
|
||||
assert_eq!(key, "error.invalidPricingMode");
|
||||
}
|
||||
other => panic!("expected localized error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
ChevronRight,
|
||||
FlaskConical,
|
||||
Globe,
|
||||
Coins,
|
||||
Eye,
|
||||
EyeOff,
|
||||
X,
|
||||
@@ -13,14 +14,31 @@ import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { ProviderTestConfig, ProviderProxyConfig } from "@/types";
|
||||
|
||||
export type PricingModelSourceOption = "inherit" | "request" | "response";
|
||||
|
||||
interface ProviderPricingConfig {
|
||||
enabled: boolean;
|
||||
costMultiplier?: string;
|
||||
pricingModelSource: PricingModelSourceOption;
|
||||
}
|
||||
|
||||
interface ProviderAdvancedConfigProps {
|
||||
testConfig: ProviderTestConfig;
|
||||
proxyConfig: ProviderProxyConfig;
|
||||
pricingConfig: ProviderPricingConfig;
|
||||
onTestConfigChange: (config: ProviderTestConfig) => void;
|
||||
onProxyConfigChange: (config: ProviderProxyConfig) => void;
|
||||
onPricingConfigChange: (config: ProviderPricingConfig) => void;
|
||||
}
|
||||
|
||||
/** 从 ProviderProxyConfig 构建完整 URL */
|
||||
@@ -71,14 +89,19 @@ function parseProxyUrl(url: string): Partial<ProviderProxyConfig> {
|
||||
export function ProviderAdvancedConfig({
|
||||
testConfig,
|
||||
proxyConfig,
|
||||
pricingConfig,
|
||||
onTestConfigChange,
|
||||
onProxyConfigChange,
|
||||
onPricingConfigChange,
|
||||
}: ProviderAdvancedConfigProps) {
|
||||
const { t } = useTranslation();
|
||||
const [isTestConfigOpen, setIsTestConfigOpen] = useState(testConfig.enabled);
|
||||
const [isProxyConfigOpen, setIsProxyConfigOpen] = useState(
|
||||
proxyConfig.enabled,
|
||||
);
|
||||
const [isPricingConfigOpen, setIsPricingConfigOpen] = useState(
|
||||
pricingConfig.enabled,
|
||||
);
|
||||
const [showPassword, setShowPassword] = useState(false);
|
||||
|
||||
// 代理 URL 输入状态(仅在初始化时从 proxyConfig 构建)
|
||||
@@ -97,6 +120,11 @@ export function ProviderAdvancedConfig({
|
||||
setIsProxyConfigOpen(proxyConfig.enabled);
|
||||
}, [proxyConfig.enabled]);
|
||||
|
||||
// 同步外部 pricingConfig.enabled 变化到展开状态
|
||||
useEffect(() => {
|
||||
setIsPricingConfigOpen(pricingConfig.enabled);
|
||||
}, [pricingConfig.enabled]);
|
||||
|
||||
// 仅在外部 proxyConfig 变化且非用户输入时同步(如:重置表单、加载数据)
|
||||
useEffect(() => {
|
||||
if (!isUserTyping) {
|
||||
@@ -450,6 +478,143 @@ export function ProviderAdvancedConfig({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 计费配置 */}
|
||||
<div className="rounded-lg border border-border/50 bg-muted/20">
|
||||
<button
|
||||
type="button"
|
||||
className="flex w-full items-center justify-between p-4 hover:bg-muted/30 transition-colors"
|
||||
onClick={() => setIsPricingConfigOpen(!isPricingConfigOpen)}
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<Coins className="h-4 w-4 text-muted-foreground" />
|
||||
<span className="font-medium">
|
||||
{t("providerAdvanced.pricingConfig", {
|
||||
defaultValue: "计费配置",
|
||||
})}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<div
|
||||
className="flex items-center gap-2"
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
<Label
|
||||
htmlFor="pricing-config-enabled"
|
||||
className="text-sm text-muted-foreground"
|
||||
>
|
||||
{t("providerAdvanced.useCustomPricing", {
|
||||
defaultValue: "使用单独配置",
|
||||
})}
|
||||
</Label>
|
||||
<Switch
|
||||
id="pricing-config-enabled"
|
||||
checked={pricingConfig.enabled}
|
||||
onCheckedChange={(checked) => {
|
||||
onPricingConfigChange({ ...pricingConfig, enabled: checked });
|
||||
if (checked) setIsPricingConfigOpen(true);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{isPricingConfigOpen ? (
|
||||
<ChevronDown className="h-4 w-4 text-muted-foreground" />
|
||||
) : (
|
||||
<ChevronRight className="h-4 w-4 text-muted-foreground" />
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
<div
|
||||
className={cn(
|
||||
"overflow-hidden transition-all duration-200",
|
||||
isPricingConfigOpen
|
||||
? "max-h-[500px] opacity-100"
|
||||
: "max-h-0 opacity-0",
|
||||
)}
|
||||
>
|
||||
<div className="border-t border-border/50 p-4 space-y-4">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{t("providerAdvanced.pricingConfigDesc", {
|
||||
defaultValue:
|
||||
"为此供应商配置单独的计费参数,不启用时使用全局默认配置。",
|
||||
})}
|
||||
</p>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="cost-multiplier">
|
||||
{t("providerAdvanced.costMultiplier", {
|
||||
defaultValue: "成本倍率",
|
||||
})}
|
||||
</Label>
|
||||
<Input
|
||||
id="cost-multiplier"
|
||||
type="number"
|
||||
step="0.01"
|
||||
inputMode="decimal"
|
||||
value={pricingConfig.costMultiplier || ""}
|
||||
onChange={(e) =>
|
||||
onPricingConfigChange({
|
||||
...pricingConfig,
|
||||
costMultiplier: e.target.value || undefined,
|
||||
})
|
||||
}
|
||||
placeholder={t("providerAdvanced.costMultiplierPlaceholder", {
|
||||
defaultValue: "留空使用全局默认(1)",
|
||||
})}
|
||||
disabled={!pricingConfig.enabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("providerAdvanced.costMultiplierHint", {
|
||||
defaultValue: "实际成本 = 基础成本 × 倍率,支持小数如 1.5",
|
||||
})}
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="pricing-model-source">
|
||||
{t("providerAdvanced.pricingModelSourceLabel", {
|
||||
defaultValue: "计费模式",
|
||||
})}
|
||||
</Label>
|
||||
<Select
|
||||
value={pricingConfig.pricingModelSource}
|
||||
onValueChange={(value) =>
|
||||
onPricingConfigChange({
|
||||
...pricingConfig,
|
||||
pricingModelSource: value as PricingModelSourceOption,
|
||||
})
|
||||
}
|
||||
disabled={!pricingConfig.enabled}
|
||||
>
|
||||
<SelectTrigger id="pricing-model-source">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="inherit">
|
||||
{t("providerAdvanced.pricingModelSourceInherit", {
|
||||
defaultValue: "继承全局默认",
|
||||
})}
|
||||
</SelectItem>
|
||||
<SelectItem value="request">
|
||||
{t("providerAdvanced.pricingModelSourceRequest", {
|
||||
defaultValue: "请求模型",
|
||||
})}
|
||||
</SelectItem>
|
||||
<SelectItem value="response">
|
||||
{t("providerAdvanced.pricingModelSourceResponse", {
|
||||
defaultValue: "返回模型",
|
||||
})}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("providerAdvanced.pricingModelSourceHint", {
|
||||
defaultValue: "选择按请求模型还是返回模型进行定价匹配",
|
||||
})}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -46,7 +46,10 @@ import { BasicFormFields } from "./BasicFormFields";
|
||||
import { ClaudeFormFields } from "./ClaudeFormFields";
|
||||
import { CodexFormFields } from "./CodexFormFields";
|
||||
import { GeminiFormFields } from "./GeminiFormFields";
|
||||
import { ProviderAdvancedConfig } from "./ProviderAdvancedConfig";
|
||||
import {
|
||||
ProviderAdvancedConfig,
|
||||
type PricingModelSourceOption,
|
||||
} from "./ProviderAdvancedConfig";
|
||||
import {
|
||||
useProviderCategory,
|
||||
useApiKeyState,
|
||||
@@ -121,6 +124,9 @@ interface ProviderFormProps {
|
||||
showButtons?: boolean;
|
||||
}
|
||||
|
||||
const normalizePricingSource = (value?: string): PricingModelSourceOption =>
|
||||
value === "request" || value === "response" ? value : "inherit";
|
||||
|
||||
export function ProviderForm({
|
||||
appId,
|
||||
providerId,
|
||||
@@ -168,6 +174,19 @@ export function ProviderForm({
|
||||
const [proxyConfig, setProxyConfig] = useState<ProviderProxyConfig>(
|
||||
() => initialData?.meta?.proxyConfig ?? { enabled: false },
|
||||
);
|
||||
const [pricingConfig, setPricingConfig] = useState<{
|
||||
enabled: boolean;
|
||||
costMultiplier?: string;
|
||||
pricingModelSource: PricingModelSourceOption;
|
||||
}>(() => ({
|
||||
enabled:
|
||||
initialData?.meta?.costMultiplier !== undefined ||
|
||||
initialData?.meta?.pricingModelSource !== undefined,
|
||||
costMultiplier: initialData?.meta?.costMultiplier,
|
||||
pricingModelSource: normalizePricingSource(
|
||||
initialData?.meta?.pricingModelSource,
|
||||
),
|
||||
}));
|
||||
|
||||
// 使用 category hook
|
||||
const { category } = useProviderCategory({
|
||||
@@ -188,6 +207,15 @@ export function ProviderForm({
|
||||
setEndpointAutoSelect(initialData?.meta?.endpointAutoSelect ?? true);
|
||||
setTestConfig(initialData?.meta?.testConfig ?? { enabled: false });
|
||||
setProxyConfig(initialData?.meta?.proxyConfig ?? { enabled: false });
|
||||
setPricingConfig({
|
||||
enabled:
|
||||
initialData?.meta?.costMultiplier !== undefined ||
|
||||
initialData?.meta?.pricingModelSource !== undefined,
|
||||
costMultiplier: initialData?.meta?.costMultiplier,
|
||||
pricingModelSource: normalizePricingSource(
|
||||
initialData?.meta?.pricingModelSource,
|
||||
),
|
||||
});
|
||||
}, [appId, initialData]);
|
||||
|
||||
const defaultValues: ProviderFormData = useMemo(
|
||||
@@ -940,6 +968,13 @@ export function ProviderForm({
|
||||
// 添加高级配置
|
||||
testConfig: testConfig.enabled ? testConfig : undefined,
|
||||
proxyConfig: proxyConfig.enabled ? proxyConfig : undefined,
|
||||
costMultiplier: pricingConfig.enabled
|
||||
? pricingConfig.costMultiplier
|
||||
: undefined,
|
||||
pricingModelSource:
|
||||
pricingConfig.enabled && pricingConfig.pricingModelSource !== "inherit"
|
||||
? pricingConfig.pricingModelSource
|
||||
: undefined,
|
||||
};
|
||||
|
||||
onSubmit(payload);
|
||||
@@ -1464,8 +1499,10 @@ export function ProviderForm({
|
||||
<ProviderAdvancedConfig
|
||||
testConfig={testConfig}
|
||||
proxyConfig={proxyConfig}
|
||||
pricingConfig={pricingConfig}
|
||||
onTestConfigChange={setTestConfig}
|
||||
onProxyConfigChange={setProxyConfig}
|
||||
onPricingConfigChange={setPricingConfig}
|
||||
/>
|
||||
|
||||
{showButtons && (
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useState } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
@@ -19,10 +18,31 @@ import {
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useModelPricing, useDeleteModelPricing } from "@/lib/query/usage";
|
||||
import { PricingEditModal } from "./PricingEditModal";
|
||||
import type { ModelPricing } from "@/types/usage";
|
||||
import { Plus, Pencil, Trash2, ChevronDown, ChevronRight } from "lucide-react";
|
||||
import { Plus, Pencil, Trash2, Loader2 } from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
import { proxyApi } from "@/lib/api/proxy";
|
||||
|
||||
const PRICING_APPS = ["claude", "codex", "gemini"] as const;
|
||||
type PricingApp = (typeof PRICING_APPS)[number];
|
||||
type PricingModelSource = "request" | "response";
|
||||
|
||||
interface AppConfig {
|
||||
multiplier: string;
|
||||
source: PricingModelSource;
|
||||
}
|
||||
|
||||
type AppConfigState = Record<PricingApp, AppConfig>;
|
||||
|
||||
export function PricingConfigPanel() {
|
||||
const { t } = useTranslation();
|
||||
@@ -31,13 +51,137 @@ export function PricingConfigPanel() {
|
||||
const [editingModel, setEditingModel] = useState<ModelPricing | null>(null);
|
||||
const [isAddingNew, setIsAddingNew] = useState(false);
|
||||
const [deleteConfirm, setDeleteConfirm] = useState<string | null>(null);
|
||||
const [isExpanded, setIsExpanded] = useState(false);
|
||||
|
||||
// 三个应用的配置状态
|
||||
const [appConfigs, setAppConfigs] = useState<AppConfigState>({
|
||||
claude: { multiplier: "1", source: "response" },
|
||||
codex: { multiplier: "1", source: "response" },
|
||||
gemini: { multiplier: "1", source: "response" },
|
||||
});
|
||||
const [originalConfigs, setOriginalConfigs] = useState<AppConfigState | null>(
|
||||
null,
|
||||
);
|
||||
const [isConfigLoading, setIsConfigLoading] = useState(true);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
||||
// 检查是否有改动
|
||||
const isDirty =
|
||||
originalConfigs !== null &&
|
||||
PRICING_APPS.some(
|
||||
(app) =>
|
||||
appConfigs[app].multiplier !== originalConfigs[app].multiplier ||
|
||||
appConfigs[app].source !== originalConfigs[app].source,
|
||||
);
|
||||
|
||||
// 加载所有应用的配置
|
||||
useEffect(() => {
|
||||
let isMounted = true;
|
||||
|
||||
const loadAllConfigs = async () => {
|
||||
setIsConfigLoading(true);
|
||||
try {
|
||||
const results = await Promise.all(
|
||||
PRICING_APPS.map(async (app) => {
|
||||
const [multiplier, source] = await Promise.all([
|
||||
proxyApi.getDefaultCostMultiplier(app),
|
||||
proxyApi.getPricingModelSource(app),
|
||||
]);
|
||||
return {
|
||||
app,
|
||||
multiplier,
|
||||
source: (source === "request"
|
||||
? "request"
|
||||
: "response") as PricingModelSource,
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
if (!isMounted) return;
|
||||
|
||||
const newState: AppConfigState = {
|
||||
claude: { multiplier: "1", source: "response" },
|
||||
codex: { multiplier: "1", source: "response" },
|
||||
gemini: { multiplier: "1", source: "response" },
|
||||
};
|
||||
for (const result of results) {
|
||||
newState[result.app] = {
|
||||
multiplier: result.multiplier,
|
||||
source: result.source,
|
||||
};
|
||||
}
|
||||
setAppConfigs(newState);
|
||||
setOriginalConfigs(newState);
|
||||
} catch (error) {
|
||||
const message =
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: typeof error === "string"
|
||||
? error
|
||||
: "Unknown error";
|
||||
toast.error(
|
||||
t("settings.globalProxy.pricingLoadFailed", { error: message }),
|
||||
);
|
||||
} finally {
|
||||
if (isMounted) setIsConfigLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadAllConfigs();
|
||||
return () => {
|
||||
isMounted = false;
|
||||
};
|
||||
}, [t]);
|
||||
|
||||
// 保存所有配置
|
||||
const handleSaveAll = async () => {
|
||||
// 验证所有倍率
|
||||
for (const app of PRICING_APPS) {
|
||||
const trimmed = appConfigs[app].multiplier.trim();
|
||||
if (!trimmed) {
|
||||
toast.error(
|
||||
`${t(`apps.${app}`)}: ${t("settings.globalProxy.defaultCostMultiplierRequired")}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (!/^-?\d+(?:\.\d+)?$/.test(trimmed)) {
|
||||
toast.error(
|
||||
`${t(`apps.${app}`)}: ${t("settings.globalProxy.defaultCostMultiplierInvalid")}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
setIsSaving(true);
|
||||
try {
|
||||
await Promise.all(
|
||||
PRICING_APPS.flatMap((app) => [
|
||||
proxyApi.setDefaultCostMultiplier(
|
||||
app,
|
||||
appConfigs[app].multiplier.trim(),
|
||||
),
|
||||
proxyApi.setPricingModelSource(app, appConfigs[app].source),
|
||||
]),
|
||||
);
|
||||
toast.success(t("settings.globalProxy.pricingSaved"));
|
||||
setOriginalConfigs({ ...appConfigs });
|
||||
} catch (error) {
|
||||
const message =
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: typeof error === "string"
|
||||
? error
|
||||
: "Unknown error";
|
||||
toast.error(
|
||||
t("settings.globalProxy.pricingSaveFailed", { error: message }),
|
||||
);
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = (modelId: string) => {
|
||||
deleteMutation.mutate(modelId, {
|
||||
onSuccess: () => {
|
||||
setDeleteConfirm(null);
|
||||
},
|
||||
onSuccess: () => setDeleteConfirm(null),
|
||||
});
|
||||
};
|
||||
|
||||
@@ -55,151 +199,242 @@ export function PricingConfigPanel() {
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Card className="border rounded-lg">
|
||||
<CardHeader
|
||||
className="cursor-pointer"
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
<CardTitle className="text-base">
|
||||
{t("usage.modelPricing")}
|
||||
</CardTitle>
|
||||
</div>
|
||||
</CardHeader>
|
||||
</Card>
|
||||
<div className="flex items-center justify-center p-4">
|
||||
<Loader2 className="h-5 w-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<Card className="border rounded-lg">
|
||||
<CardHeader
|
||||
className="cursor-pointer"
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
{isExpanded ? (
|
||||
<ChevronDown className="h-4 w-4" />
|
||||
) : (
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
)}
|
||||
<CardTitle className="text-base">
|
||||
{t("usage.modelPricing")}
|
||||
</CardTitle>
|
||||
</div>
|
||||
</CardHeader>
|
||||
{isExpanded && (
|
||||
<CardContent>
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>
|
||||
{t("usage.loadPricingError")}: {String(error)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>
|
||||
{t("usage.loadPricingError")}: {String(error)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h4 className="text-sm font-medium text-muted-foreground">
|
||||
{t("usage.modelPricingDesc")} {t("usage.perMillion")}
|
||||
</h4>
|
||||
<Button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleAddNew();
|
||||
}}
|
||||
size="sm"
|
||||
>
|
||||
<Plus className="mr-1 h-4 w-4" />
|
||||
{t("common.add")}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="space-y-6">
|
||||
{/* 全局计费默认配置 - 紧凑表格布局 */}
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h4 className="text-sm font-medium">
|
||||
{t("settings.globalProxy.pricingDefaultsTitle")}
|
||||
</h4>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("settings.globalProxy.pricingDefaultsDescription")}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
onClick={handleSaveAll}
|
||||
disabled={isConfigLoading || isSaving || !isDirty}
|
||||
size="sm"
|
||||
>
|
||||
{isSaving ? (
|
||||
<>
|
||||
<Loader2 className="mr-1.5 h-3.5 w-3.5 animate-spin" />
|
||||
{t("common.saving")}
|
||||
</>
|
||||
) : (
|
||||
t("common.save")
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
{!pricing || pricing.length === 0 ? (
|
||||
<Alert>
|
||||
<AlertDescription>{t("usage.noPricingData")}</AlertDescription>
|
||||
</Alert>
|
||||
{isConfigLoading ? (
|
||||
<div className="flex items-center justify-center py-4">
|
||||
<Loader2 className="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
) : (
|
||||
<div className="rounded-md bg-card/60 shadow-sm">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>{t("usage.model")}</TableHead>
|
||||
<TableHead>{t("usage.displayName")}</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.inputCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.outputCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.cacheReadCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.cacheWriteCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("common.actions")}
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{pricing.map((model) => (
|
||||
<TableRow key={model.modelId}>
|
||||
<TableCell className="font-mono text-sm">
|
||||
{model.modelId}
|
||||
</TableCell>
|
||||
<TableCell>{model.displayName}</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.inputCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.outputCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.cacheReadCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.cacheCreationCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => {
|
||||
setIsAddingNew(false);
|
||||
setEditingModel(model);
|
||||
}}
|
||||
title={t("common.edit")}
|
||||
>
|
||||
<Pencil className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setDeleteConfirm(model.modelId)}
|
||||
title={t("common.delete")}
|
||||
className="text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
<div className="rounded-md border border-border/50 overflow-hidden">
|
||||
<table className="w-full text-sm">
|
||||
<thead>
|
||||
<tr className="border-b border-border/50 bg-muted/30">
|
||||
<th className="px-3 py-2 text-left font-medium text-muted-foreground w-24">
|
||||
{t("settings.globalProxy.pricingAppLabel")}
|
||||
</th>
|
||||
<th className="px-3 py-2 text-left font-medium text-muted-foreground">
|
||||
{t("settings.globalProxy.defaultCostMultiplierLabel")}
|
||||
</th>
|
||||
<th className="px-3 py-2 text-left font-medium text-muted-foreground">
|
||||
{t("settings.globalProxy.pricingModelSourceLabel")}
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{PRICING_APPS.map((app, idx) => (
|
||||
<tr
|
||||
key={app}
|
||||
className={
|
||||
idx < PRICING_APPS.length - 1
|
||||
? "border-b border-border/30"
|
||||
: ""
|
||||
}
|
||||
>
|
||||
<td className="px-3 py-1.5 font-medium">
|
||||
{t(`apps.${app}`)}
|
||||
</td>
|
||||
<td className="px-3 py-1.5">
|
||||
<Input
|
||||
type="number"
|
||||
step="0.01"
|
||||
inputMode="decimal"
|
||||
value={appConfigs[app].multiplier}
|
||||
onChange={(e) =>
|
||||
setAppConfigs((prev) => ({
|
||||
...prev,
|
||||
[app]: { ...prev[app], multiplier: e.target.value },
|
||||
}))
|
||||
}
|
||||
disabled={isSaving}
|
||||
placeholder="1"
|
||||
className="h-7 w-24"
|
||||
/>
|
||||
</td>
|
||||
<td className="px-3 py-1.5">
|
||||
<Select
|
||||
value={appConfigs[app].source}
|
||||
onValueChange={(value) =>
|
||||
setAppConfigs((prev) => ({
|
||||
...prev,
|
||||
[app]: {
|
||||
...prev[app],
|
||||
source: value as PricingModelSource,
|
||||
},
|
||||
}))
|
||||
}
|
||||
disabled={isSaving}
|
||||
>
|
||||
<SelectTrigger className="h-7 w-28">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="response">
|
||||
{t(
|
||||
"settings.globalProxy.pricingModelSourceResponse",
|
||||
)}
|
||||
</SelectItem>
|
||||
<SelectItem value="request">
|
||||
{t(
|
||||
"settings.globalProxy.pricingModelSourceRequest",
|
||||
)}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 分隔线 */}
|
||||
<div className="border-t border-border/50" />
|
||||
|
||||
{/* 模型定价配置 */}
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h4 className="text-sm font-medium text-muted-foreground">
|
||||
{t("usage.modelPricingDesc")} {t("usage.perMillion")}
|
||||
</h4>
|
||||
<Button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleAddNew();
|
||||
}}
|
||||
size="sm"
|
||||
>
|
||||
<Plus className="mr-1 h-4 w-4" />
|
||||
{t("common.add")}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
{!pricing || pricing.length === 0 ? (
|
||||
<Alert>
|
||||
<AlertDescription>{t("usage.noPricingData")}</AlertDescription>
|
||||
</Alert>
|
||||
) : (
|
||||
<div className="rounded-md bg-card/60 shadow-sm">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>{t("usage.model")}</TableHead>
|
||||
<TableHead>{t("usage.displayName")}</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.inputCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.outputCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.cacheReadCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("usage.cacheWriteCost")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right">
|
||||
{t("common.actions")}
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{pricing.map((model) => (
|
||||
<TableRow key={model.modelId}>
|
||||
<TableCell className="font-mono text-sm">
|
||||
{model.modelId}
|
||||
</TableCell>
|
||||
<TableCell>{model.displayName}</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.inputCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.outputCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.cacheReadCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-sm">
|
||||
${model.cacheCreationCostPerMillion}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
<div className="flex justify-end gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => {
|
||||
setIsAddingNew(false);
|
||||
setEditingModel(model);
|
||||
}}
|
||||
title={t("common.edit")}
|
||||
>
|
||||
<Pencil className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={() => setDeleteConfirm(model.modelId)}
|
||||
title={t("common.delete")}
|
||||
className="text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{editingModel && (
|
||||
<PricingEditModal
|
||||
open={!!editingModel}
|
||||
|
||||
@@ -184,6 +184,9 @@ export function RequestDetailPanel({
|
||||
<div>
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.inputCost", "输入成本")}
|
||||
<span className="ml-1 text-xs">
|
||||
({t("usage.baseCost", "基础")})
|
||||
</span>
|
||||
</dt>
|
||||
<dd className="font-mono">
|
||||
${parseFloat(request.inputCostUsd).toFixed(6)}
|
||||
@@ -192,6 +195,9 @@ export function RequestDetailPanel({
|
||||
<div>
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.outputCost", "输出成本")}
|
||||
<span className="ml-1 text-xs">
|
||||
({t("usage.baseCost", "基础")})
|
||||
</span>
|
||||
</dt>
|
||||
<dd className="font-mono">
|
||||
${parseFloat(request.outputCostUsd).toFixed(6)}
|
||||
@@ -200,6 +206,9 @@ export function RequestDetailPanel({
|
||||
<div>
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.cacheReadCost", "缓存读取成本")}
|
||||
<span className="ml-1 text-xs">
|
||||
({t("usage.baseCost", "基础")})
|
||||
</span>
|
||||
</dt>
|
||||
<dd className="font-mono">
|
||||
${parseFloat(request.cacheReadCostUsd).toFixed(6)}
|
||||
@@ -208,14 +217,35 @@ export function RequestDetailPanel({
|
||||
<div>
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.cacheCreationCost", "缓存写入成本")}
|
||||
<span className="ml-1 text-xs">
|
||||
({t("usage.baseCost", "基础")})
|
||||
</span>
|
||||
</dt>
|
||||
<dd className="font-mono">
|
||||
${parseFloat(request.cacheCreationCostUsd).toFixed(6)}
|
||||
</dd>
|
||||
</div>
|
||||
<div className="col-span-2 border-t pt-3">
|
||||
{/* 显示成本倍率(如果不等于1) */}
|
||||
{request.costMultiplier &&
|
||||
parseFloat(request.costMultiplier) !== 1 && (
|
||||
<div className="col-span-2 border-t pt-3">
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.costMultiplier", "成本倍率")}
|
||||
</dt>
|
||||
<dd className="font-mono">×{request.costMultiplier}</dd>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
className={`col-span-2 ${request.costMultiplier && parseFloat(request.costMultiplier) !== 1 ? "" : "border-t"} pt-3`}
|
||||
>
|
||||
<dt className="text-muted-foreground">
|
||||
{t("usage.totalCost", "总成本")}
|
||||
{request.costMultiplier &&
|
||||
parseFloat(request.costMultiplier) !== 1 && (
|
||||
<span className="ml-1 text-xs">
|
||||
({t("usage.withMultiplier", "含倍率")})
|
||||
</span>
|
||||
)}
|
||||
</dt>
|
||||
<dd className="text-lg font-semibold text-primary">
|
||||
${parseFloat(request.totalCostUsd).toFixed(6)}
|
||||
|
||||
@@ -250,7 +250,7 @@ export function RequestLogTable() {
|
||||
<TableHead className="whitespace-nowrap">
|
||||
{t("usage.provider")}
|
||||
</TableHead>
|
||||
<TableHead className="min-w-[280px] whitespace-nowrap">
|
||||
<TableHead className="min-w-[200px] whitespace-nowrap">
|
||||
{t("usage.billingModel")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right whitespace-nowrap">
|
||||
@@ -265,6 +265,9 @@ export function RequestLogTable() {
|
||||
<TableHead className="text-right min-w-[90px] whitespace-nowrap">
|
||||
{t("usage.cacheCreationTokens")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right whitespace-nowrap">
|
||||
{t("usage.multiplier")}
|
||||
</TableHead>
|
||||
<TableHead className="text-right whitespace-nowrap">
|
||||
{t("usage.totalCost")}
|
||||
</TableHead>
|
||||
@@ -280,7 +283,7 @@ export function RequestLogTable() {
|
||||
{logs.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={10}
|
||||
colSpan={11}
|
||||
className="text-center text-muted-foreground"
|
||||
>
|
||||
{t("usage.noData")}
|
||||
@@ -297,11 +300,25 @@ export function RequestLogTable() {
|
||||
<TableCell>
|
||||
{log.providerName || t("usage.unknownProvider")}
|
||||
</TableCell>
|
||||
<TableCell
|
||||
className="font-mono text-sm max-w-[280px] truncate"
|
||||
title={log.model}
|
||||
>
|
||||
{log.model}
|
||||
<TableCell className="font-mono text-xs max-w-[200px]">
|
||||
<div
|
||||
className="truncate"
|
||||
title={
|
||||
log.requestModel && log.requestModel !== log.model
|
||||
? `${t("usage.requestModel")}: ${log.requestModel}\n${t("usage.responseModel")}: ${log.model}`
|
||||
: log.model
|
||||
}
|
||||
>
|
||||
{log.model}
|
||||
</div>
|
||||
{log.requestModel && log.requestModel !== log.model && (
|
||||
<div
|
||||
className="truncate text-muted-foreground text-[10px]"
|
||||
title={log.requestModel}
|
||||
>
|
||||
← {log.requestModel}
|
||||
</div>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
{log.inputTokens.toLocaleString()}
|
||||
@@ -315,6 +332,15 @@ export function RequestLogTable() {
|
||||
<TableCell className="text-right">
|
||||
{log.cacheCreationTokens.toLocaleString()}
|
||||
</TableCell>
|
||||
<TableCell className="text-right font-mono text-xs">
|
||||
{parseFloat(log.costMultiplier) !== 1 ? (
|
||||
<span className="text-orange-600">
|
||||
×{log.costMultiplier}
|
||||
</span>
|
||||
) : (
|
||||
<span className="text-muted-foreground">×1</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-right">
|
||||
${parseFloat(log.totalCostUsd).toFixed(6)}
|
||||
</TableCell>
|
||||
|
||||
@@ -376,7 +376,21 @@
|
||||
"saved": "Proxy settings saved",
|
||||
"saveFailed": "Save failed: {{error}}",
|
||||
"testSuccess": "Connected! Latency {{latency}}ms",
|
||||
"testFailed": "Connection failed: {{error}}"
|
||||
"testFailed": "Connection failed: {{error}}",
|
||||
"pricingDefaultsTitle": "Pricing Defaults",
|
||||
"pricingDefaultsDescription": "Set the default multiplier and pricing model source per app.",
|
||||
"pricingAppLabel": "App",
|
||||
"defaultCostMultiplierLabel": "Default Multiplier",
|
||||
"defaultCostMultiplierHint": "Multiplier for cost calculation, decimals supported.",
|
||||
"pricingModelSourceLabel": "Pricing Model Source",
|
||||
"pricingModelSourceRequest": "Request model",
|
||||
"pricingModelSourceResponse": "Response model",
|
||||
"pricingSave": "Save Pricing Defaults",
|
||||
"pricingSaved": "Pricing defaults saved",
|
||||
"pricingSaveFailed": "Failed to save pricing defaults: {{error}}",
|
||||
"pricingLoadFailed": "Failed to load pricing defaults: {{error}}",
|
||||
"defaultCostMultiplierRequired": "Default multiplier is required",
|
||||
"defaultCostMultiplierInvalid": "Invalid multiplier format"
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
@@ -512,7 +526,18 @@
|
||||
"useCustomProxy": "Use separate proxy",
|
||||
"proxyConfigDesc": "Configure separate network proxy for this provider. Uses system proxy or global settings when disabled.",
|
||||
"proxyUsername": "Username (optional)",
|
||||
"proxyPassword": "Password (optional)"
|
||||
"proxyPassword": "Password (optional)",
|
||||
"pricingConfig": "Pricing Config",
|
||||
"useCustomPricing": "Use separate config",
|
||||
"pricingConfigDesc": "Configure separate pricing parameters for this provider. Uses global defaults when disabled.",
|
||||
"costMultiplier": "Cost Multiplier",
|
||||
"costMultiplierPlaceholder": "Leave empty to use global default (1)",
|
||||
"costMultiplierHint": "Actual cost = Base cost × Multiplier, supports decimals like 1.5",
|
||||
"pricingModelSourceLabel": "Pricing Mode",
|
||||
"pricingModelSourceInherit": "Inherit global default",
|
||||
"pricingModelSourceRequest": "Request model",
|
||||
"pricingModelSourceResponse": "Response model",
|
||||
"pricingModelSourceHint": "Choose whether to match pricing by request model or response model"
|
||||
},
|
||||
"codexConfig": {
|
||||
"authJson": "auth.json (JSON) *",
|
||||
@@ -615,6 +640,9 @@
|
||||
"cacheCreationTokens": "Cache Creation",
|
||||
"timingInfo": "Duration/TTFT",
|
||||
"status": "Status",
|
||||
"multiplier": "Multiplier",
|
||||
"requestModel": "Request Model",
|
||||
"responseModel": "Response Model",
|
||||
"noData": "No data",
|
||||
"unknownProvider": "Unknown Provider",
|
||||
"stream": "Stream",
|
||||
@@ -657,7 +685,19 @@
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "Creation",
|
||||
"cacheRead": "Hit"
|
||||
"cacheRead": "Hit",
|
||||
"baseCost": "Base",
|
||||
"costMultiplier": "Cost Multiplier",
|
||||
"withMultiplier": "with multiplier",
|
||||
"requestDetail": "Request Detail",
|
||||
"requestNotFound": "Request not found",
|
||||
"basicInfo": "Basic Info",
|
||||
"tokenUsage": "Token Usage",
|
||||
"cacheCreationCost": "Cache Creation Cost",
|
||||
"costBreakdown": "Cost Breakdown",
|
||||
"performance": "Performance",
|
||||
"latency": "Latency",
|
||||
"errorMessage": "Error Message"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "Configure Usage Query",
|
||||
|
||||
@@ -376,7 +376,21 @@
|
||||
"saved": "プロキシ設定を保存しました",
|
||||
"saveFailed": "保存に失敗しました: {{error}}",
|
||||
"testSuccess": "接続成功!遅延 {{latency}}ms",
|
||||
"testFailed": "接続に失敗しました: {{error}}"
|
||||
"testFailed": "接続に失敗しました: {{error}}",
|
||||
"pricingDefaultsTitle": "課金のデフォルト設定",
|
||||
"pricingDefaultsDescription": "アプリごとのデフォルト倍率と課金モードを設定します。",
|
||||
"pricingAppLabel": "アプリ",
|
||||
"defaultCostMultiplierLabel": "デフォルト倍率",
|
||||
"defaultCostMultiplierHint": "コスト計算用の倍率(小数対応)。",
|
||||
"pricingModelSourceLabel": "課金モード",
|
||||
"pricingModelSourceRequest": "リクエストモデル",
|
||||
"pricingModelSourceResponse": "レスポンスモデル",
|
||||
"pricingSave": "課金設定を保存",
|
||||
"pricingSaved": "課金設定を保存しました",
|
||||
"pricingSaveFailed": "課金設定の保存に失敗しました: {{error}}",
|
||||
"pricingLoadFailed": "課金設定の読み込みに失敗しました: {{error}}",
|
||||
"defaultCostMultiplierRequired": "デフォルト倍率は必須です",
|
||||
"defaultCostMultiplierInvalid": "デフォルト倍率の形式が正しくありません"
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
@@ -512,7 +526,18 @@
|
||||
"useCustomProxy": "個別プロキシを使用",
|
||||
"proxyConfigDesc": "このプロバイダーに個別のネットワークプロキシを設定します。無効の場合はシステムプロキシまたはグローバル設定を使用します。",
|
||||
"proxyUsername": "ユーザー名(任意)",
|
||||
"proxyPassword": "パスワード(任意)"
|
||||
"proxyPassword": "パスワード(任意)",
|
||||
"pricingConfig": "課金設定",
|
||||
"useCustomPricing": "個別設定を使用",
|
||||
"pricingConfigDesc": "このプロバイダーに個別の課金パラメータを設定します。無効の場合はグローバル設定を使用します。",
|
||||
"costMultiplier": "コスト倍率",
|
||||
"costMultiplierPlaceholder": "空白の場合はグローバル設定を使用(1)",
|
||||
"costMultiplierHint": "実際のコスト = 基本コスト × 倍率、1.5 などの小数をサポート",
|
||||
"pricingModelSourceLabel": "課金モード",
|
||||
"pricingModelSourceInherit": "グローバル設定を継承",
|
||||
"pricingModelSourceRequest": "リクエストモデル",
|
||||
"pricingModelSourceResponse": "レスポンスモデル",
|
||||
"pricingModelSourceHint": "リクエストモデルまたはレスポンスモデルで価格を照合するかを選択"
|
||||
},
|
||||
"codexConfig": {
|
||||
"authJson": "auth.json (JSON) *",
|
||||
@@ -615,6 +640,9 @@
|
||||
"cacheCreationTokens": "キャッシュ作成",
|
||||
"timingInfo": "応答時間/TTFT",
|
||||
"status": "ステータス",
|
||||
"multiplier": "倍率",
|
||||
"requestModel": "リクエストモデル",
|
||||
"responseModel": "レスポンスモデル",
|
||||
"noData": "データなし",
|
||||
"unknownProvider": "不明なプロバイダー",
|
||||
"stream": "ストリーム",
|
||||
@@ -657,7 +685,19 @@
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "作成",
|
||||
"cacheRead": "ヒット"
|
||||
"cacheRead": "ヒット",
|
||||
"baseCost": "基本",
|
||||
"costMultiplier": "コスト倍率",
|
||||
"withMultiplier": "倍率込み",
|
||||
"requestDetail": "リクエスト詳細",
|
||||
"requestNotFound": "リクエストが見つかりません",
|
||||
"basicInfo": "基本情報",
|
||||
"tokenUsage": "Token 使用量",
|
||||
"cacheCreationCost": "キャッシュ作成コスト",
|
||||
"costBreakdown": "コスト明細",
|
||||
"performance": "パフォーマンス",
|
||||
"latency": "レイテンシー",
|
||||
"errorMessage": "エラーメッセージ"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "利用状況を設定",
|
||||
|
||||
@@ -376,7 +376,21 @@
|
||||
"saved": "代理设置已保存",
|
||||
"saveFailed": "保存失败:{{error}}",
|
||||
"testSuccess": "连接成功!延迟 {{latency}}ms",
|
||||
"testFailed": "连接失败:{{error}}"
|
||||
"testFailed": "连接失败:{{error}}",
|
||||
"pricingDefaultsTitle": "计费默认配置",
|
||||
"pricingDefaultsDescription": "设置各应用的默认倍率与计费模式来源。",
|
||||
"pricingAppLabel": "应用",
|
||||
"defaultCostMultiplierLabel": "默认倍率",
|
||||
"defaultCostMultiplierHint": "用于成本计算的倍率,支持小数。",
|
||||
"pricingModelSourceLabel": "计费模式",
|
||||
"pricingModelSourceRequest": "请求模型",
|
||||
"pricingModelSourceResponse": "返回模型",
|
||||
"pricingSave": "保存计费配置",
|
||||
"pricingSaved": "计费配置已保存",
|
||||
"pricingSaveFailed": "保存计费配置失败:{{error}}",
|
||||
"pricingLoadFailed": "加载计费配置失败:{{error}}",
|
||||
"defaultCostMultiplierRequired": "默认倍率不能为空",
|
||||
"defaultCostMultiplierInvalid": "默认倍率格式不正确"
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
@@ -512,7 +526,18 @@
|
||||
"useCustomProxy": "使用单独代理",
|
||||
"proxyConfigDesc": "为此供应商配置单独的网络代理,不启用时使用系统代理或全局设置。",
|
||||
"proxyUsername": "用户名(可选)",
|
||||
"proxyPassword": "密码(可选)"
|
||||
"proxyPassword": "密码(可选)",
|
||||
"pricingConfig": "计费配置",
|
||||
"useCustomPricing": "使用单独配置",
|
||||
"pricingConfigDesc": "为此供应商配置单独的计费参数,不启用时使用全局默认配置。",
|
||||
"costMultiplier": "成本倍率",
|
||||
"costMultiplierPlaceholder": "留空使用全局默认(1)",
|
||||
"costMultiplierHint": "实际成本 = 基础成本 × 倍率,支持小数如 1.5",
|
||||
"pricingModelSourceLabel": "计费模式",
|
||||
"pricingModelSourceInherit": "继承全局默认",
|
||||
"pricingModelSourceRequest": "请求模型",
|
||||
"pricingModelSourceResponse": "返回模型",
|
||||
"pricingModelSourceHint": "选择按请求模型还是返回模型进行定价匹配"
|
||||
},
|
||||
"codexConfig": {
|
||||
"authJson": "auth.json (JSON) *",
|
||||
@@ -615,6 +640,9 @@
|
||||
"cacheCreationTokens": "缓存创建",
|
||||
"timingInfo": "用时/首字",
|
||||
"status": "状态",
|
||||
"multiplier": "倍率",
|
||||
"requestModel": "请求模型",
|
||||
"responseModel": "返回模型",
|
||||
"noData": "暂无数据",
|
||||
"unknownProvider": "未知供应商",
|
||||
"stream": "流",
|
||||
@@ -657,7 +685,19 @@
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "创建",
|
||||
"cacheRead": "命中"
|
||||
"cacheRead": "命中",
|
||||
"baseCost": "基础",
|
||||
"costMultiplier": "成本倍率",
|
||||
"withMultiplier": "含倍率",
|
||||
"requestDetail": "请求详情",
|
||||
"requestNotFound": "请求未找到",
|
||||
"basicInfo": "基本信息",
|
||||
"tokenUsage": "Token 使用量",
|
||||
"cacheCreationCost": "缓存写入成本",
|
||||
"costBreakdown": "成本明细",
|
||||
"performance": "性能信息",
|
||||
"latency": "延迟",
|
||||
"errorMessage": "错误信息"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "配置用量查询",
|
||||
|
||||
@@ -92,4 +92,29 @@ export const proxyApi = {
|
||||
async updateProxyConfigForApp(config: AppProxyConfig): Promise<void> {
|
||||
return invoke("update_proxy_config_for_app", { config });
|
||||
},
|
||||
|
||||
// ========== 计费默认配置 API ==========
|
||||
|
||||
// 获取默认成本倍率
|
||||
async getDefaultCostMultiplier(appType: string): Promise<string> {
|
||||
return invoke("get_default_cost_multiplier", { appType });
|
||||
},
|
||||
|
||||
// 设置默认成本倍率
|
||||
async setDefaultCostMultiplier(
|
||||
appType: string,
|
||||
value: string,
|
||||
): Promise<void> {
|
||||
return invoke("set_default_cost_multiplier", { appType, value });
|
||||
},
|
||||
|
||||
// 获取计费模式来源
|
||||
async getPricingModelSource(appType: string): Promise<string> {
|
||||
return invoke("get_pricing_model_source", { appType });
|
||||
},
|
||||
|
||||
// 设置计费模式来源
|
||||
async setPricingModelSource(appType: string, value: string): Promise<void> {
|
||||
return invoke("set_pricing_model_source", { appType, value });
|
||||
},
|
||||
};
|
||||
|
||||
@@ -135,6 +135,10 @@ export interface ProviderMeta {
|
||||
testConfig?: ProviderTestConfig;
|
||||
// 供应商单独的代理配置
|
||||
proxyConfig?: ProviderProxyConfig;
|
||||
// 供应商成本倍率
|
||||
costMultiplier?: string;
|
||||
// 供应商计费模式来源
|
||||
pricingModelSource?: string;
|
||||
}
|
||||
|
||||
// Skill 同步方式
|
||||
|
||||
@@ -13,6 +13,8 @@ export interface RequestLog {
|
||||
providerName?: string;
|
||||
appType: string;
|
||||
model: string;
|
||||
requestModel?: string;
|
||||
costMultiplier: string;
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
cacheReadTokens: number;
|
||||
|
||||
83
tests/components/GlobalProxySettings.test.tsx
Normal file
83
tests/components/GlobalProxySettings.test.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { GlobalProxySettings } from "@/components/settings/GlobalProxySettings";
|
||||
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({ t: (key: string) => key }),
|
||||
}));
|
||||
|
||||
const mutateAsyncMock = vi.fn();
|
||||
const testMutateAsyncMock = vi.fn();
|
||||
const scanMutateAsyncMock = vi.fn();
|
||||
|
||||
vi.mock("@/hooks/useGlobalProxy", () => ({
|
||||
useGlobalProxyUrl: () => ({ data: "http://127.0.0.1:7890", isLoading: false }),
|
||||
useSetGlobalProxyUrl: () => ({
|
||||
mutateAsync: mutateAsyncMock,
|
||||
isPending: false,
|
||||
}),
|
||||
useTestProxy: () => ({
|
||||
mutateAsync: testMutateAsyncMock,
|
||||
isPending: false,
|
||||
}),
|
||||
useScanProxies: () => ({
|
||||
mutateAsync: scanMutateAsyncMock,
|
||||
isPending: false,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("GlobalProxySettings", () => {
|
||||
beforeEach(() => {
|
||||
mutateAsyncMock.mockReset();
|
||||
testMutateAsyncMock.mockReset();
|
||||
scanMutateAsyncMock.mockReset();
|
||||
});
|
||||
|
||||
it("renders proxy URL input with saved value", async () => {
|
||||
render(<GlobalProxySettings />);
|
||||
|
||||
const urlInput = screen.getByPlaceholderText(
|
||||
"http://127.0.0.1:7890 / socks5://127.0.0.1:1080",
|
||||
);
|
||||
// URL 对象会在末尾添加斜杠
|
||||
await waitFor(() =>
|
||||
expect(urlInput).toHaveValue("http://127.0.0.1:7890/"),
|
||||
);
|
||||
});
|
||||
|
||||
it("saves proxy URL when save button is clicked", async () => {
|
||||
render(<GlobalProxySettings />);
|
||||
|
||||
const urlInput = screen.getByPlaceholderText(
|
||||
"http://127.0.0.1:7890 / socks5://127.0.0.1:1080",
|
||||
);
|
||||
|
||||
fireEvent.change(urlInput, { target: { value: "http://localhost:8080" } });
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: "common.save" });
|
||||
fireEvent.click(saveButton);
|
||||
|
||||
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled());
|
||||
// 没有用户名时,URL 不经过 URL 对象解析,所以没有尾部斜杠
|
||||
expect(mutateAsyncMock).toHaveBeenCalledWith("http://localhost:8080");
|
||||
});
|
||||
|
||||
it("clears proxy URL when clear button is clicked", async () => {
|
||||
render(<GlobalProxySettings />);
|
||||
|
||||
const urlInput = screen.getByPlaceholderText(
|
||||
"http://127.0.0.1:7890 / socks5://127.0.0.1:1080",
|
||||
);
|
||||
|
||||
// Wait for initial value to load
|
||||
await waitFor(() =>
|
||||
expect(urlInput).toHaveValue("http://127.0.0.1:7890/"),
|
||||
);
|
||||
|
||||
// Click clear button
|
||||
const clearButton = screen.getByTitle("settings.globalProxy.clear");
|
||||
fireEvent.click(clearButton);
|
||||
|
||||
expect(urlInput).toHaveValue("");
|
||||
});
|
||||
});
|
||||
26
tests/setupGlobals.ts
Normal file
26
tests/setupGlobals.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
const storage = new Map<string, string>();
|
||||
|
||||
if (
|
||||
typeof globalThis.localStorage === "undefined" ||
|
||||
typeof globalThis.localStorage?.getItem !== "function"
|
||||
) {
|
||||
Object.defineProperty(globalThis, "localStorage", {
|
||||
value: {
|
||||
getItem: (key: string) => storage.get(key) ?? null,
|
||||
setItem: (key: string, value: string) => {
|
||||
storage.set(key, String(value));
|
||||
},
|
||||
removeItem: (key: string) => {
|
||||
storage.delete(key);
|
||||
},
|
||||
clear: () => {
|
||||
storage.clear();
|
||||
},
|
||||
key: (index: number) => Array.from(storage.keys())[index] ?? null,
|
||||
get length() {
|
||||
return storage.size;
|
||||
},
|
||||
},
|
||||
configurable: true,
|
||||
});
|
||||
}
|
||||
@@ -11,7 +11,7 @@ export default defineConfig({
|
||||
},
|
||||
test: {
|
||||
environment: "jsdom",
|
||||
setupFiles: ["./tests/setupTests.ts"],
|
||||
setupFiles: ["./tests/setupGlobals.ts", "./tests/setupTests.ts"],
|
||||
globals: true,
|
||||
coverage: {
|
||||
reporter: ["text", "lcov"],
|
||||
|
||||
Reference in New Issue
Block a user