mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-05-08 20:21:41 +08:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 766ca23eca | |||
| 48a8aabfb8 | |||
| cf79b09597 | |||
| 6004084644 | |||
| 52199f39c1 | |||
| c69d36d457 | |||
| a2aa969096 | |||
| 02fd924119 | |||
| fa4e7bcd82 | |||
| 43ccbdad47 | |||
| b46fb6e6da | |||
| 0efe0594e9 | |||
| a80b2b98b7 | |||
| f4f0590fd6 | |||
| bfa5c3c526 | |||
| 9f212115c4 | |||
| da75f22a12 | |||
| 3a692c84fb | |||
| fe08f69cac | |||
| 4a1ee98784 | |||
| 2901ead814 | |||
| e99d599ad8 |
@@ -184,17 +184,16 @@ pub async fn reset_circuit_breaker(
|
||||
.await?;
|
||||
|
||||
// 3. 检查是否应该切回优先级更高的供应商(从 proxy_config 表读取)
|
||||
let auto_failover_enabled = match db.get_proxy_config_for_app(&app_type).await {
|
||||
Ok(config) => config.auto_failover_enabled,
|
||||
// 只有当该应用已被代理接管(enabled=true)且开启了自动故障转移时才执行
|
||||
let (app_enabled, auto_failover_enabled) = match db.get_proxy_config_for_app(&app_type).await {
|
||||
Ok(config) => (config.enabled, config.auto_failover_enabled),
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"[{app_type}] Failed to read proxy_config for auto_failover_enabled: {e}, defaulting to disabled"
|
||||
);
|
||||
false
|
||||
log::error!("[{app_type}] Failed to read proxy_config: {e}, defaulting to disabled");
|
||||
(false, false)
|
||||
}
|
||||
};
|
||||
|
||||
if auto_failover_enabled && state.proxy_service.is_running().await {
|
||||
if app_enabled && auto_failover_enabled && state.proxy_service.is_running().await {
|
||||
// 获取当前供应商 ID
|
||||
let current_id = db
|
||||
.get_current_provider(&app_type)
|
||||
|
||||
@@ -19,9 +19,10 @@ pub fn get_usage_summary(
|
||||
#[tauri::command]
|
||||
pub fn get_usage_trends(
|
||||
state: State<'_, AppState>,
|
||||
days: u32,
|
||||
start_date: Option<i64>,
|
||||
end_date: Option<i64>,
|
||||
) -> Result<Vec<DailyStats>, AppError> {
|
||||
state.db.get_daily_trends(days)
|
||||
state.db.get_daily_trends(start_date, end_date)
|
||||
}
|
||||
|
||||
/// 获取 Provider 统计
|
||||
|
||||
@@ -765,9 +765,9 @@ impl Database {
|
||||
/// 注意: model_id 使用短横线格式(如 claude-haiku-4-5),与 API 返回的模型名称标准化后一致
|
||||
fn seed_model_pricing(conn: &Connection) -> Result<(), AppError> {
|
||||
let pricing_data = [
|
||||
// Claude 4.5 系列
|
||||
// Claude 4.5 系列 (Latest Models)
|
||||
(
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-5-20251101",
|
||||
"Claude Opus 4.5",
|
||||
"5",
|
||||
"25",
|
||||
@@ -775,7 +775,7 @@ impl Database {
|
||||
"6.25",
|
||||
),
|
||||
(
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"Claude Sonnet 4.5",
|
||||
"3",
|
||||
"15",
|
||||
@@ -783,16 +783,24 @@ impl Database {
|
||||
"3.75",
|
||||
),
|
||||
(
|
||||
"claude-haiku-4-5",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"Claude Haiku 4.5",
|
||||
"1",
|
||||
"5",
|
||||
"0.10",
|
||||
"1.25",
|
||||
),
|
||||
// Claude 4.1 系列
|
||||
// Claude 4 系列 (Legacy Models)
|
||||
(
|
||||
"claude-opus-4-1",
|
||||
"claude-opus-4-20250514",
|
||||
"Claude Opus 4",
|
||||
"15",
|
||||
"75",
|
||||
"1.50",
|
||||
"18.75",
|
||||
),
|
||||
(
|
||||
"claude-opus-4-1-20250805",
|
||||
"Claude Opus 4.1",
|
||||
"15",
|
||||
"75",
|
||||
@@ -800,17 +808,8 @@ impl Database {
|
||||
"18.75",
|
||||
),
|
||||
(
|
||||
"claude-sonnet-4-1",
|
||||
"Claude Sonnet 4.1",
|
||||
"3",
|
||||
"15",
|
||||
"0.30",
|
||||
"3.75",
|
||||
),
|
||||
// Claude 3.7 系列
|
||||
(
|
||||
"claude-sonnet-3-7",
|
||||
"Claude Sonnet 3.7",
|
||||
"claude-sonnet-4-20250514",
|
||||
"Claude Sonnet 4",
|
||||
"3",
|
||||
"15",
|
||||
"0.30",
|
||||
@@ -818,38 +817,167 @@ impl Database {
|
||||
),
|
||||
// Claude 3.5 系列
|
||||
(
|
||||
"claude-sonnet-3-5",
|
||||
"Claude Sonnet 3.5",
|
||||
"3",
|
||||
"15",
|
||||
"0.30",
|
||||
"3.75",
|
||||
),
|
||||
(
|
||||
"claude-haiku-3-5",
|
||||
"Claude Haiku 3.5",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"Claude 3.5 Haiku",
|
||||
"0.80",
|
||||
"4",
|
||||
"0.08",
|
||||
"1",
|
||||
),
|
||||
// GPT-5 系列(model_id 使用短横线格式)
|
||||
(
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"Claude 3.5 Sonnet",
|
||||
"3",
|
||||
"15",
|
||||
"0.30",
|
||||
"3.75",
|
||||
),
|
||||
// GPT-5.2 系列
|
||||
("gpt-5.2", "GPT-5.2", "1.75", "14", "0.175", "0"),
|
||||
("gpt-5.2-low", "GPT-5.2", "1.75", "14", "0.175", "0"),
|
||||
("gpt-5.2-medium", "GPT-5.2", "1.75", "14", "0.175", "0"),
|
||||
("gpt-5.2-high", "GPT-5.2", "1.75", "14", "0.175", "0"),
|
||||
("gpt-5.2-xhigh", "GPT-5.2", "1.75", "14", "0.175", "0"),
|
||||
("gpt-5.2-codex", "GPT-5.2 Codex", "1.75", "14", "0.175", "0"),
|
||||
(
|
||||
"gpt-5.2-codex-low",
|
||||
"GPT-5.2 Codex",
|
||||
"1.75",
|
||||
"14",
|
||||
"0.175",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.2-codex-medium",
|
||||
"GPT-5.2 Codex",
|
||||
"1.75",
|
||||
"14",
|
||||
"0.175",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.2-codex-high",
|
||||
"GPT-5.2 Codex",
|
||||
"1.75",
|
||||
"14",
|
||||
"0.175",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.2-codex-xhigh",
|
||||
"GPT-5.2 Codex",
|
||||
"1.75",
|
||||
"14",
|
||||
"0.175",
|
||||
"0",
|
||||
),
|
||||
// GPT-5.1 系列
|
||||
("gpt-5.1", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5.1-low", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5.1-medium", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5.1-high", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5.1-minimal", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5.1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"),
|
||||
(
|
||||
"gpt-5.1-codex-mini",
|
||||
"GPT-5.1 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.1-codex-max",
|
||||
"GPT-5.1 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.1-codex-max-high",
|
||||
"GPT-5.1 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5.1-codex-max-xhigh",
|
||||
"GPT-5.1 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
// GPT-5 系列
|
||||
("gpt-5", "GPT-5", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-1", "GPT-5.1", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-low", "GPT-5", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-medium", "GPT-5", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-high", "GPT-5", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-minimal", "GPT-5", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-codex", "GPT-5 Codex", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"),
|
||||
("gpt-5-codex-low", "GPT-5 Codex", "1.25", "10", "0.125", "0"),
|
||||
(
|
||||
"gpt-5-codex-medium",
|
||||
"GPT-5 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5-codex-high",
|
||||
"GPT-5 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5-codex-mini",
|
||||
"GPT-5 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5-codex-mini-medium",
|
||||
"GPT-5 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gpt-5-codex-mini-high",
|
||||
"GPT-5 Codex",
|
||||
"1.25",
|
||||
"10",
|
||||
"0.125",
|
||||
"0",
|
||||
),
|
||||
// Gemini 3 系列
|
||||
(
|
||||
"gemini-3-pro-preview",
|
||||
"Gemini 3 Pro Preview",
|
||||
"2",
|
||||
"12",
|
||||
"0",
|
||||
"0.2",
|
||||
"0",
|
||||
),
|
||||
// Gemini 2.5 系列(model_id 使用短横线格式)
|
||||
(
|
||||
"gemini-2-5-pro",
|
||||
"gemini-3-flash-preview",
|
||||
"Gemini 3 Flash Preview",
|
||||
"0.5",
|
||||
"3",
|
||||
"0.05",
|
||||
"0",
|
||||
),
|
||||
// Gemini 2.5 系列
|
||||
(
|
||||
"gemini-2.5-pro",
|
||||
"Gemini 2.5 Pro",
|
||||
"1.25",
|
||||
"10",
|
||||
@@ -857,13 +985,75 @@ impl Database {
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"gemini-2-5-flash",
|
||||
"gemini-2.5-flash",
|
||||
"Gemini 2.5 Flash",
|
||||
"0.3",
|
||||
"2.5",
|
||||
"0.03",
|
||||
"0",
|
||||
),
|
||||
// ====== 国产模型 (CNY/1M tokens) ======
|
||||
// Doubao (字节跳动)
|
||||
(
|
||||
"doubao-seed-code",
|
||||
"Doubao Seed Code",
|
||||
"1.20",
|
||||
"8.00",
|
||||
"0.24",
|
||||
"0",
|
||||
),
|
||||
// DeepSeek 系列
|
||||
(
|
||||
"deepseek-v3.2",
|
||||
"DeepSeek V3.2",
|
||||
"2.00",
|
||||
"3.00",
|
||||
"0.40",
|
||||
"0",
|
||||
),
|
||||
(
|
||||
"deepseek-v3.1",
|
||||
"DeepSeek V3.1",
|
||||
"4.00",
|
||||
"12.00",
|
||||
"0.80",
|
||||
"0",
|
||||
),
|
||||
("deepseek-v3", "DeepSeek V3", "2.00", "8.00", "0.40", "0"),
|
||||
// Kimi (月之暗面)
|
||||
(
|
||||
"kimi-k2-thinking",
|
||||
"Kimi K2 Thinking",
|
||||
"4.00",
|
||||
"16.00",
|
||||
"1.00",
|
||||
"0",
|
||||
),
|
||||
("kimi-k2-0905", "Kimi K2", "4.00", "16.00", "1.00", "0"),
|
||||
(
|
||||
"kimi-k2-turbo",
|
||||
"Kimi K2 Turbo",
|
||||
"8.00",
|
||||
"58.00",
|
||||
"1.00",
|
||||
"0",
|
||||
),
|
||||
// MiniMax 系列
|
||||
("minimax-m2.1", "MiniMax M2.1", "2.10", "8.40", "0.21", "0"),
|
||||
(
|
||||
"minimax-m2.1-lightning",
|
||||
"MiniMax M2.1 Lightning",
|
||||
"2.10",
|
||||
"16.80",
|
||||
"0.21",
|
||||
"0",
|
||||
),
|
||||
("minimax-m2", "MiniMax M2", "2.10", "8.40", "0.21", "0"),
|
||||
// GLM (智谱)
|
||||
("glm-4.7", "GLM-4.7", "2.00", "8.00", "0.40", "0"),
|
||||
("glm-4.6", "GLM-4.6", "2.00", "8.00", "0.40", "0"),
|
||||
// Mimo (小米)
|
||||
("mimo-v2-flash", "Mimo V2 Flash", "0", "0", "0", "0"),
|
||||
];
|
||||
|
||||
for (model_id, display_name, input, output, cache_read, cache_creation) in pricing_data {
|
||||
|
||||
@@ -52,6 +52,10 @@ pub enum AppError {
|
||||
},
|
||||
#[error("数据库错误: {0}")]
|
||||
Database(String),
|
||||
#[error("所有供应商已熔断,无可用渠道")]
|
||||
AllProvidersCircuitOpen,
|
||||
#[error("未配置供应商")]
|
||||
NoProvidersConfigured,
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
|
||||
@@ -386,16 +386,15 @@ impl UniversalProvider {
|
||||
// 生成 Codex 的 config.toml 内容
|
||||
let config_toml = format!(
|
||||
r#"model_provider = "newapi"
|
||||
model = "{}"
|
||||
model_reasoning_effort = "{}"
|
||||
model = "{model}"
|
||||
model_reasoning_effort = "{reasoning_effort}"
|
||||
disable_response_storage = true
|
||||
|
||||
[model_providers.newapi]
|
||||
name = "NewAPI"
|
||||
base_url = "{}"
|
||||
base_url = "{codex_base_url}"
|
||||
wire_api = "responses"
|
||||
requires_openai_auth = true"#,
|
||||
model, reasoning_effort, codex_base_url
|
||||
requires_openai_auth = true"#
|
||||
);
|
||||
|
||||
let settings_config = serde_json::json!({
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
//! 请求体过滤模块
|
||||
//!
|
||||
//! 过滤不应透传到上游的私有参数,防止内部信息泄露。
|
||||
//!
|
||||
//! ## 过滤规则
|
||||
//! - 以 `_` 开头的字段被视为私有参数,会被递归过滤
|
||||
//! - 支持白名单机制,允许透传特定的 `_` 前缀字段
|
||||
//! - 支持嵌套对象和数组的深度过滤
|
||||
//!
|
||||
//! ## 使用场景
|
||||
//! - `_internal_id`: 内部追踪 ID
|
||||
//! - `_debug_mode`: 调试标记
|
||||
//! - `_session_token`: 会话令牌
|
||||
//! - `_client_version`: 客户端版本
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Strips every private field (any key beginning with `_`) from a JSON body.
///
/// Convenience wrapper that forwards to
/// [`filter_private_params_with_whitelist`] with an empty whitelist, so no
/// underscore-prefixed key is exempt. Only compiled in test builds.
///
/// # Arguments
/// * `body` - the original request body
///
/// # Returns
/// The filtered request body.
///
/// # Example
/// ```ignore
/// let input = json!({
///     "model": "claude-3",
///     "_internal_id": "abc123",
///     "messages": [{"role": "user", "content": "hello", "_token": "secret"}]
/// });
/// let output = filter_private_params(input);
/// // `output` contains neither `_internal_id` nor `_token`
/// ```
#[cfg(test)]
pub fn filter_private_params(body: Value) -> Value {
    // An explicitly typed empty slice: nothing is allowed through.
    let no_exceptions: &[String] = &[];
    filter_private_params_with_whitelist(body, no_exceptions)
}
|
||||
|
||||
/// 过滤私有参数(支持白名单)
|
||||
///
|
||||
/// 递归遍历 JSON 结构,移除所有以下划线开头的字段,
|
||||
/// 但保留白名单中指定的字段。
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `body` - 原始请求体
|
||||
/// * `whitelist` - 白名单字段列表(不过滤这些字段)
|
||||
///
|
||||
/// # Returns
|
||||
/// 过滤后的请求体
|
||||
///
|
||||
/// # Example
|
||||
/// ```ignore
|
||||
/// let input = json!({
|
||||
/// "model": "claude-3",
|
||||
/// "_metadata": {"key": "value"}, // 白名单中,保留
|
||||
/// "_internal_id": "abc123" // 不在白名单中,过滤
|
||||
/// });
|
||||
/// let output = filter_private_params_with_whitelist(input, &["_metadata"]);
|
||||
/// // output 包含 _metadata,不包含 _internal_id
|
||||
/// ```
|
||||
pub fn filter_private_params_with_whitelist(body: Value, whitelist: &[String]) -> Value {
|
||||
let whitelist_set: HashSet<&str> = whitelist.iter().map(|s| s.as_str()).collect();
|
||||
filter_recursive_with_whitelist(body, &mut Vec::new(), &whitelist_set)
|
||||
}
|
||||
|
||||
/// Recursive filter without a whitelist: every `_`-prefixed key is dropped.
///
/// Test-only shim that calls [`filter_recursive_with_whitelist`] with an
/// empty whitelist set.
#[cfg(test)]
fn filter_recursive(value: Value, removed_keys: &mut Vec<String>) -> Value {
    let nothing_allowed: HashSet<&str> = HashSet::new();
    filter_recursive_with_whitelist(value, removed_keys, &nothing_allowed)
}
|
||||
|
||||
/// 递归过滤实现(支持白名单)
|
||||
fn filter_recursive_with_whitelist(
|
||||
value: Value,
|
||||
removed_keys: &mut Vec<String>,
|
||||
whitelist: &HashSet<&str>,
|
||||
) -> Value {
|
||||
match value {
|
||||
Value::Object(map) => {
|
||||
let filtered: serde_json::Map<String, Value> = map
|
||||
.into_iter()
|
||||
.filter_map(|(key, val)| {
|
||||
// 以 _ 开头且不在白名单中的字段被过滤
|
||||
if key.starts_with('_') && !whitelist.contains(key.as_str()) {
|
||||
removed_keys.push(key);
|
||||
None
|
||||
} else {
|
||||
Some((
|
||||
key,
|
||||
filter_recursive_with_whitelist(val, removed_keys, whitelist),
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 仅在有过滤时记录日志(避免每次请求都打印)
|
||||
if !removed_keys.is_empty() {
|
||||
log::debug!("[BodyFilter] 过滤私有参数: {removed_keys:?}");
|
||||
removed_keys.clear();
|
||||
}
|
||||
|
||||
Value::Object(filtered)
|
||||
}
|
||||
Value::Array(arr) => Value::Array(
|
||||
arr.into_iter()
|
||||
.map(|v| filter_recursive_with_whitelist(v, removed_keys, whitelist))
|
||||
.collect(),
|
||||
),
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an owned whitelist from string literals.
    fn whitelist_of(keys: &[&str]) -> Vec<String> {
        keys.iter().map(|key| key.to_string()).collect()
    }

    /// Private keys at the root are dropped; ordinary keys survive.
    #[test]
    fn test_filter_top_level_private_params() {
        let output = filter_private_params(json!({
            "model": "claude-3",
            "_internal_id": "abc123",
            "_debug": true,
            "max_tokens": 1024
        }));

        for kept in ["model", "max_tokens"] {
            assert!(output.get(kept).is_some());
        }
        for dropped in ["_internal_id", "_debug"] {
            assert!(output.get(dropped).is_none());
        }
    }

    /// Private keys inside array elements and nested objects are dropped.
    #[test]
    fn test_filter_nested_private_params() {
        let output = filter_private_params(json!({
            "model": "claude-3",
            "messages": [
                {
                    "role": "user",
                    "content": "hello",
                    "_session_token": "secret"
                }
            ],
            "metadata": {
                "user_id": "user-1",
                "_tracking_id": "track-1"
            }
        }));

        // Top-level fields are kept.
        for kept in ["model", "messages", "metadata"] {
            assert!(output.get(kept).is_some());
        }

        // Private parameters inside the `messages` array are filtered.
        let first_message = &output.get("messages").unwrap().as_array().unwrap()[0];
        assert!(first_message.get("role").is_some());
        assert!(first_message.get("content").is_some());
        assert!(first_message.get("_session_token").is_none());

        // Private parameters inside the `metadata` object are filtered.
        let metadata = output.get("metadata").unwrap();
        assert!(metadata.get("user_id").is_some());
        assert!(metadata.get("_tracking_id").is_none());
    }

    /// Filtering reaches objects several levels down.
    #[test]
    fn test_filter_deeply_nested() {
        let output = filter_private_params(json!({
            "level1": {
                "level2": {
                    "level3": {
                        "keep": "value",
                        "_remove": "secret"
                    }
                }
            }
        }));

        let level3 = output.pointer("/level1/level2/level3").unwrap();

        assert!(level3.get("keep").is_some());
        assert!(level3.get("_remove").is_none());
    }

    /// Every object in an array is filtered.
    #[test]
    fn test_filter_array_of_objects() {
        let output = filter_private_params(json!({
            "items": [
                {"id": 1, "_secret": "a"},
                {"id": 2, "_secret": "b"},
                {"id": 3, "_secret": "c"}
            ]
        }));

        let items = output.get("items").unwrap().as_array().unwrap();
        items.iter().for_each(|item| {
            assert!(item.get("id").is_some());
            assert!(item.get("_secret").is_none());
        });
    }

    /// A body without private parameters comes back unchanged.
    #[test]
    fn test_no_private_params() {
        let input = json!({
            "model": "claude-3",
            "messages": [{"role": "user", "content": "hello"}]
        });
        let expected = input.clone();

        let output = filter_private_params(input);

        assert_eq!(expected, output);
    }

    /// An empty object stays an empty object.
    #[test]
    fn test_empty_object() {
        assert_eq!(filter_private_params(json!({})), json!({}));
    }

    /// Primitive values must not be modified.
    #[test]
    fn test_primitive_values() {
        for primitive in [json!(42), json!("string"), json!(true), json!(null)] {
            assert_eq!(filter_private_params(primitive.clone()), primitive);
        }
    }

    /// Whitelisted private keys are kept; other private keys are dropped.
    #[test]
    fn test_whitelist_preserves_private_params() {
        let input = json!({
            "model": "claude-3",
            "_metadata": {"key": "value"},
            "_internal_id": "abc123",
            "_stream_options": {"include_usage": true}
        });
        let whitelist = whitelist_of(&["_metadata", "_stream_options"]);

        let output = filter_private_params_with_whitelist(input, &whitelist);

        // Whitelisted fields are kept.
        assert!(output.get("_metadata").is_some());
        assert!(output.get("_stream_options").is_some());
        // Private fields outside the whitelist are filtered.
        assert!(output.get("_internal_id").is_none());
        // Ordinary fields are kept.
        assert!(output.get("model").is_some());
    }

    /// The whitelist also applies to keys inside nested objects.
    #[test]
    fn test_whitelist_nested() {
        let input = json!({
            "data": {
                "_allowed": "keep",
                "_forbidden": "remove",
                "normal": "value"
            }
        });
        let whitelist = whitelist_of(&["_allowed"]);

        let output = filter_private_params_with_whitelist(input, &whitelist);
        let data = output.get("data").unwrap();

        assert!(data.get("_allowed").is_some());
        assert!(data.get("_forbidden").is_none());
        assert!(data.get("normal").is_some());
    }

    /// An empty whitelist behaves exactly like the default filter.
    #[test]
    fn test_empty_whitelist_same_as_default() {
        let input = json!({
            "model": "claude-3",
            "_internal_id": "abc123"
        });

        let via_default = filter_private_params(input.clone());
        let via_empty_whitelist = filter_private_params_with_whitelist(input, &[]);

        assert_eq!(via_default, via_empty_whitelist);
    }
}
|
||||
@@ -23,6 +23,12 @@ pub enum ProxyError {
|
||||
#[error("无可用的Provider")]
|
||||
NoAvailableProvider,
|
||||
|
||||
#[error("所有供应商已熔断,无可用渠道")]
|
||||
AllProvidersCircuitOpen,
|
||||
|
||||
#[error("未配置供应商")]
|
||||
NoProvidersConfigured,
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[error("Provider不健康: {0}")]
|
||||
ProviderUnhealthy(String),
|
||||
@@ -111,6 +117,12 @@ impl IntoResponse for ProxyError {
|
||||
ProxyError::NoAvailableProvider => {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
|
||||
}
|
||||
ProxyError::AllProvidersCircuitOpen => {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
|
||||
}
|
||||
ProxyError::NoProvidersConfigured => {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
|
||||
}
|
||||
ProxyError::ProviderUnhealthy(_) => {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ pub fn map_proxy_error_to_status(error: &ProxyError) -> u16 {
|
||||
// 无可用 Provider:503 Service Unavailable
|
||||
ProxyError::NoAvailableProvider => 503,
|
||||
|
||||
// 所有供应商已熔断:503 Service Unavailable
|
||||
ProxyError::AllProvidersCircuitOpen => 503,
|
||||
|
||||
// 未配置供应商:503 Service Unavailable
|
||||
ProxyError::NoProvidersConfigured => 503,
|
||||
|
||||
// 重试耗尽:503 Service Unavailable
|
||||
ProxyError::MaxRetriesExceeded => 503,
|
||||
|
||||
@@ -57,6 +63,8 @@ pub fn get_error_message(error: &ProxyError) -> String {
|
||||
ProxyError::Timeout(msg) => format!("请求超时: {msg}"),
|
||||
ProxyError::ForwardFailed(msg) => format!("转发失败: {msg}"),
|
||||
ProxyError::NoAvailableProvider => "无可用 Provider".to_string(),
|
||||
ProxyError::AllProvidersCircuitOpen => "所有供应商已熔断,无可用渠道".to_string(),
|
||||
ProxyError::NoProvidersConfigured => "未配置供应商".to_string(),
|
||||
ProxyError::MaxRetriesExceeded => "所有 Provider 都失败,重试耗尽".to_string(),
|
||||
ProxyError::ProviderUnhealthy(msg) => format!("Provider 不健康: {msg}"),
|
||||
ProxyError::DatabaseError(msg) => format!("数据库错误: {msg}"),
|
||||
|
||||
@@ -81,6 +81,21 @@ impl FailoverSwitchManager {
|
||||
provider_id: &str,
|
||||
provider_name: &str,
|
||||
) -> Result<bool, AppError> {
|
||||
// 检查该应用是否已被代理接管(enabled=true)
|
||||
// 只有被接管的应用才允许执行故障转移切换
|
||||
let app_enabled = match self.db.get_proxy_config_for_app(app_type).await {
|
||||
Ok(config) => config.enabled,
|
||||
Err(e) => {
|
||||
log::warn!("[Failover] 无法读取 {app_type} 配置: {e},跳过切换");
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
if !app_enabled {
|
||||
log::info!("[Failover] {app_type} 未被代理接管(enabled=false),跳过切换");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
log::info!("[Failover] 开始切换供应商: {app_type} -> {provider_name} ({provider_id})");
|
||||
|
||||
// 1. 更新数据库 is_current
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
//! 请求转发器
|
||||
//!
|
||||
//! 负责将请求转发到上游Provider,支持重试和故障转移
|
||||
//! 负责将请求转发到上游Provider,支持故障转移
|
||||
|
||||
use super::{
|
||||
body_filter::filter_private_params_with_whitelist,
|
||||
error::*,
|
||||
failover_switch::FailoverSwitchManager,
|
||||
provider_router::ProviderRouter,
|
||||
@@ -17,6 +18,71 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Header blacklist: client headers that are not passed through to the upstream.
///
/// Modelled on the Claude Code Hub design. The following categories are filtered:
/// 1. Authentication headers (overwritten by the proxy)
/// 2. Connection headers (managed by the HTTP client)
/// 3. Proxy-forwarding headers
/// 4. CDN / cloud-vendor specific headers
/// 5. Request-tracing headers
/// 6. Browser-specific headers (may be detected by the upstream)
///
/// Note: client-IP headers (`x-forwarded-for`, `x-real-ip`) are passed through
/// by default; they are listed here only because they are handled separately.
const HEADER_BLACKLIST: &[&str] = &[
    // Authentication (overwritten)
    "authorization",
    "x-api-key",
    // Connection
    "host",
    "content-length",
    "connection",
    "transfer-encoding",
    // Encoding (overwritten with `identity`)
    "accept-encoding",
    // Proxy forwarding (`x-forwarded-for` and `x-real-ip` are retained)
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-proto",
    "forwarded",
    // CDN / cloud-vendor specific
    "cf-connecting-ip",
    "cf-ipcountry",
    "cf-ray",
    "cf-visitor",
    "true-client-ip",
    "fastly-client-ip",
    "x-azure-clientip",
    "x-azure-fdid",
    "x-azure-ref",
    "akamai-origin-hop",
    "x-akamai-config-log-detail",
    // Request tracing
    "x-request-id",
    "x-correlation-id",
    "x-trace-id",
    "x-amzn-trace-id",
    "x-b3-traceid",
    "x-b3-spanid",
    "x-b3-parentspanid",
    "x-b3-sampled",
    "traceparent",
    "tracestate",
    // Browser-specific (the upstream may flag these as non-CLI requests)
    "sec-fetch-mode",
    "sec-fetch-site",
    "sec-fetch-dest",
    "sec-ch-ua",
    "sec-ch-ua-mobile",
    "sec-ch-ua-platform",
    "accept-language",
    // `anthropic-beta` is handled separately to avoid duplicates
    "anthropic-beta",
    // Client IP is handled separately (passed through by default)
    "x-forwarded-for",
    "x-real-ip",
];
|
||||
|
||||
pub struct ForwardResult {
|
||||
pub response: Response,
|
||||
pub provider: Provider,
|
||||
@@ -31,8 +97,6 @@ pub struct RequestForwarder {
|
||||
client: Client,
|
||||
/// 共享的 ProviderRouter(持有熔断器状态)
|
||||
router: Arc<ProviderRouter>,
|
||||
/// 单个 Provider 内的最大重试次数
|
||||
max_retries: u8,
|
||||
status: Arc<RwLock<ProxyStatus>>,
|
||||
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
|
||||
/// 故障转移切换管理器
|
||||
@@ -48,7 +112,6 @@ impl RequestForwarder {
|
||||
pub fn new(
|
||||
router: Arc<ProviderRouter>,
|
||||
non_streaming_timeout: u64,
|
||||
max_retries: u8,
|
||||
status: Arc<RwLock<ProxyStatus>>,
|
||||
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
|
||||
failover_manager: Arc<FailoverSwitchManager>,
|
||||
@@ -77,7 +140,6 @@ impl RequestForwarder {
|
||||
Self {
|
||||
client,
|
||||
router,
|
||||
max_retries,
|
||||
status,
|
||||
current_providers,
|
||||
failover_manager,
|
||||
@@ -86,59 +148,6 @@ impl RequestForwarder {
|
||||
}
|
||||
}
|
||||
|
||||
/// 对单个 Provider 执行请求(带重试)
|
||||
///
|
||||
/// 在同一个 Provider 上最多重试 max_retries 次,使用指数退避
|
||||
async fn forward_with_provider_retry(
|
||||
&self,
|
||||
provider: &Provider,
|
||||
endpoint: &str,
|
||||
body: &Value,
|
||||
headers: &axum::http::HeaderMap,
|
||||
adapter: &dyn ProviderAdapter,
|
||||
) -> Result<Response, ProxyError> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
if attempt > 0 {
|
||||
// 指数退避:100ms, 200ms, 400ms, ...
|
||||
let delay_ms = 100 * 2u64.pow(attempt as u32 - 1);
|
||||
log::info!(
|
||||
"[{}] 重试第 {}/{} 次(等待 {}ms)",
|
||||
adapter.name(),
|
||||
attempt,
|
||||
self.max_retries,
|
||||
delay_ms
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
|
||||
match self
|
||||
.forward(provider, endpoint, body, headers, adapter)
|
||||
.await
|
||||
{
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
// 只有“同一 Provider 内可重试”的错误才继续重试
|
||||
if !self.should_retry_same_provider(&e) {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
log::debug!(
|
||||
"[{}] Provider {} 第 {} 次请求失败: {}",
|
||||
adapter.name(),
|
||||
provider.name,
|
||||
attempt + 1,
|
||||
e
|
||||
);
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded))
|
||||
}
|
||||
|
||||
/// 转发请求(带故障转移)
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -224,9 +233,9 @@ impl RequestForwarder {
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// 转发请求(带单 Provider 内重试)
|
||||
// 转发请求(每个 Provider 只尝试一次,重试由客户端控制)
|
||||
match self
|
||||
.forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref())
|
||||
.forward(provider, endpoint, &body, &headers, adapter.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
@@ -477,6 +486,28 @@ impl RequestForwarder {
|
||||
mapped_body
|
||||
};
|
||||
|
||||
// 过滤私有参数(以 `_` 开头的字段),防止内部信息泄露到上游
|
||||
// 默认使用空白名单,过滤所有 _ 前缀字段
|
||||
let filtered_body = filter_private_params_with_whitelist(request_body, &[]);
|
||||
|
||||
// ========== 请求体日志(截断显示) ==========
|
||||
let body_str = serde_json::to_string_pretty(&filtered_body)
|
||||
.unwrap_or_else(|_| filtered_body.to_string());
|
||||
let body_preview = if body_str.len() > 2000 {
|
||||
format!(
|
||||
"{}...\n[截断,总长度: {} 字符]",
|
||||
&body_str[..2000],
|
||||
body_str.len()
|
||||
)
|
||||
} else {
|
||||
body_str
|
||||
};
|
||||
log::info!(
|
||||
"[{}] ====== 最终请求体 ======\n{}",
|
||||
adapter.name(),
|
||||
body_preview
|
||||
);
|
||||
|
||||
log::info!(
|
||||
"[{}] 转发请求: {} -> {}",
|
||||
adapter.name(),
|
||||
@@ -487,28 +518,73 @@ impl RequestForwarder {
|
||||
// 构建请求
|
||||
let mut request = self.client.post(&url);
|
||||
|
||||
// 只透传必要的 Headers(白名单模式)
|
||||
let allowed_headers = [
|
||||
"accept",
|
||||
"user-agent",
|
||||
"x-request-id",
|
||||
"x-stainless-arch",
|
||||
"x-stainless-lang",
|
||||
"x-stainless-os",
|
||||
"x-stainless-package-version",
|
||||
"x-stainless-runtime",
|
||||
"x-stainless-runtime-version",
|
||||
];
|
||||
// ========== 详细 Headers 日志 ==========
|
||||
log::info!("[{}] ====== 客户端原始 Headers ======", adapter.name());
|
||||
for (key, value) in headers {
|
||||
log::info!(
|
||||
"[{}] {}: {:?}",
|
||||
adapter.name(),
|
||||
key.as_str(),
|
||||
value.to_str().unwrap_or("<binary>")
|
||||
);
|
||||
}
|
||||
|
||||
// 过滤黑名单 Headers,保护隐私并避免冲突
|
||||
let mut filtered_headers: Vec<String> = Vec::new();
|
||||
let mut passed_headers: Vec<(String, String)> = Vec::new();
|
||||
|
||||
for (key, value) in headers {
|
||||
let key_str = key.as_str().to_lowercase();
|
||||
if allowed_headers.contains(&key_str.as_str()) {
|
||||
request = request.header(key, value);
|
||||
if HEADER_BLACKLIST.contains(&key_str.as_str()) {
|
||||
filtered_headers.push(key_str);
|
||||
continue;
|
||||
}
|
||||
let value_str = value.to_str().unwrap_or("<binary>").to_string();
|
||||
passed_headers.push((key.as_str().to_string(), value_str.clone()));
|
||||
request = request.header(key, value);
|
||||
}
|
||||
|
||||
if !filtered_headers.is_empty() {
|
||||
log::info!(
|
||||
"[{}] ====== 被过滤的 Headers ({}) ======",
|
||||
adapter.name(),
|
||||
filtered_headers.len()
|
||||
);
|
||||
for h in &filtered_headers {
|
||||
log::info!("[{}] - {}", adapter.name(), h);
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 Content-Type 是 json
|
||||
request = request.header("Content-Type", "application/json");
|
||||
// 处理 anthropic-beta Header(透传)
|
||||
// 参考 Claude Code Hub 的实现,直接透传客户端的 beta 标记
|
||||
if let Some(beta) = headers.get("anthropic-beta") {
|
||||
if let Ok(beta_str) = beta.to_str() {
|
||||
request = request.header("anthropic-beta", beta_str);
|
||||
passed_headers.push(("anthropic-beta".to_string(), beta_str.to_string()));
|
||||
log::info!("[{}] 透传 anthropic-beta: {}", adapter.name(), beta_str);
|
||||
}
|
||||
}
|
||||
|
||||
// 客户端 IP 透传(默认开启)
|
||||
if let Some(xff) = headers.get("x-forwarded-for") {
|
||||
if let Ok(xff_str) = xff.to_str() {
|
||||
request = request.header("x-forwarded-for", xff_str);
|
||||
passed_headers.push(("x-forwarded-for".to_string(), xff_str.to_string()));
|
||||
log::debug!("[{}] 透传 x-forwarded-for: {}", adapter.name(), xff_str);
|
||||
}
|
||||
}
|
||||
if let Some(real_ip) = headers.get("x-real-ip") {
|
||||
if let Ok(real_ip_str) = real_ip.to_str() {
|
||||
request = request.header("x-real-ip", real_ip_str);
|
||||
passed_headers.push(("x-real-ip".to_string(), real_ip_str.to_string()));
|
||||
log::debug!("[{}] 透传 x-real-ip: {}", adapter.name(), real_ip_str);
|
||||
}
|
||||
}
|
||||
|
||||
// 禁用压缩,避免 gzip 流式响应解析错误
|
||||
// 参考 CCH: undici 在连接提前关闭时会对不完整的 gzip 流抛出错误
|
||||
request = request.header("accept-encoding", "identity");
|
||||
passed_headers.push(("accept-encoding".to_string(), "identity".to_string()));
|
||||
|
||||
// 使用适配器添加认证头
|
||||
if let Some(auth) = adapter.extract_auth(provider) {
|
||||
@@ -519,6 +595,15 @@ impl RequestForwarder {
|
||||
auth.masked_key()
|
||||
);
|
||||
request = adapter.add_auth_headers(request, &auth);
|
||||
// 记录认证头(脱敏)
|
||||
passed_headers.push((
|
||||
"authorization".to_string(),
|
||||
format!("Bearer {}...", &auth.api_key[..8.min(auth.api_key.len())]),
|
||||
));
|
||||
passed_headers.push((
|
||||
"x-api-key".to_string(),
|
||||
format!("{}...", &auth.api_key[..8.min(auth.api_key.len())]),
|
||||
));
|
||||
} else {
|
||||
log::error!(
|
||||
"[{}] 未找到 API Key!Provider: {}",
|
||||
@@ -527,9 +612,34 @@ impl RequestForwarder {
|
||||
);
|
||||
}
|
||||
|
||||
// anthropic-version 透传:优先使用客户端的版本号
|
||||
// 参考 Claude Code Hub:透传客户端值而非固定版本
|
||||
if let Some(version) = headers.get("anthropic-version") {
|
||||
if let Ok(version_str) = version.to_str() {
|
||||
// 覆盖适配器设置的默认版本
|
||||
request = request.header("anthropic-version", version_str);
|
||||
passed_headers.push(("anthropic-version".to_string(), version_str.to_string()));
|
||||
log::info!(
|
||||
"[{}] 透传 anthropic-version: {}",
|
||||
adapter.name(),
|
||||
version_str
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 最终发送的 Headers 日志 ==========
|
||||
log::info!(
|
||||
"[{}] ====== 最终发送的 Headers ({}) ======",
|
||||
adapter.name(),
|
||||
passed_headers.len()
|
||||
);
|
||||
for (k, v) in &passed_headers {
|
||||
log::info!("[{}] {}: {}", adapter.name(), k, v);
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
log::info!("[{}] 发送请求到: {}", adapter.name(), url);
|
||||
let response = request.json(&request_body).send().await.map_err(|e| {
|
||||
let response = request.json(&filtered_body).send().await.map_err(|e| {
|
||||
log::error!("[{}] 请求失败: {}", adapter.name(), e);
|
||||
if e.is_timeout() {
|
||||
ProxyError::Timeout(format!("请求超时: {e}"))
|
||||
@@ -563,25 +673,6 @@ impl RequestForwarder {
|
||||
}
|
||||
}
|
||||
|
||||
/// 分类ProxyError
|
||||
///
|
||||
/// 决定哪些错误应该触发故障转移到下一个 Provider
|
||||
///
|
||||
/// 设计原则:既然用户配置了多个供应商,就应该让所有供应商都尝试一遍。
|
||||
/// 只有明确是客户端中断的情况才不重试。
|
||||
fn should_retry_same_provider(&self, error: &ProxyError) -> bool {
|
||||
match error {
|
||||
// 网络类错误:短暂抖动时同一 Provider 内重试有意义
|
||||
ProxyError::Timeout(_) => true,
|
||||
ProxyError::ForwardFailed(_) => true,
|
||||
// 上游 HTTP 错误:只对“可能瞬态”的状态码做同 Provider 重试(其余交给 failover)
|
||||
ProxyError::UpstreamError { status, .. } => {
|
||||
*status == 408 || *status == 429 || *status >= 500
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
|
||||
match error {
|
||||
// 网络和上游错误:都应该尝试下一个供应商
|
||||
@@ -597,7 +688,6 @@ impl RequestForwarder {
|
||||
ProxyError::TransformError(_) => ErrorCategory::Retryable,
|
||||
ProxyError::AuthError(_) => ErrorCategory::Retryable,
|
||||
ProxyError::StreamIdleTimeout(_) => ErrorCategory::Retryable,
|
||||
ProxyError::MaxRetriesExceeded => ErrorCategory::Retryable,
|
||||
// 无可用供应商:所有供应商都试过了,无法重试
|
||||
ProxyError::NoAvailableProvider => ErrorCategory::NonRetryable,
|
||||
// 其他错误(数据库/内部错误等):不是换供应商能解决的问题
|
||||
|
||||
@@ -58,10 +58,10 @@ fn openai_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Codex Responses API 流式响应模型提取(优先使用 usage.model)
|
||||
fn codex_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
/// Codex 智能流式响应模型提取(自动检测格式)
|
||||
fn codex_auto_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
// 首先尝试从解析的 usage 中获取模型
|
||||
if let Some(usage) = TokenUsage::from_codex_stream_events(events) {
|
||||
if let Some(usage) = TokenUsage::from_codex_stream_events_auto(events) {
|
||||
if let Some(model) = usage.model {
|
||||
return model;
|
||||
}
|
||||
@@ -76,6 +76,10 @@ fn codex_model_extractor(events: &[Value], request_model: &str) -> String {
|
||||
None
|
||||
}
|
||||
})
|
||||
.or_else(|| {
|
||||
// 再回退:从 OpenAI 格式事件中提取
|
||||
events.iter().find_map(|e| e.get("model")?.as_str())
|
||||
})
|
||||
.unwrap_or(request_model)
|
||||
.to_string()
|
||||
}
|
||||
@@ -111,11 +115,11 @@ pub const OPENAI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
|
||||
app_type_str: "codex",
|
||||
};
|
||||
|
||||
/// Codex Responses API 解析配置(用于 /v1/responses)
|
||||
/// Codex 智能解析配置(自动检测 OpenAI 或 Codex 格式)
|
||||
pub const CODEX_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
|
||||
stream_parser: TokenUsage::from_codex_stream_events,
|
||||
response_parser: TokenUsage::from_codex_response,
|
||||
model_extractor: codex_model_extractor,
|
||||
stream_parser: TokenUsage::from_codex_stream_events_auto,
|
||||
response_parser: TokenUsage::from_codex_response_auto,
|
||||
model_extractor: codex_auto_model_extractor,
|
||||
app_type_str: "codex",
|
||||
};
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@
|
||||
use crate::app_config::AppType;
|
||||
use crate::provider::Provider;
|
||||
use crate::proxy::{
|
||||
forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig, ProxyError,
|
||||
extract_session_id, forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig,
|
||||
ProxyError,
|
||||
};
|
||||
use axum::http::HeaderMap;
|
||||
use std::time::Instant;
|
||||
|
||||
/// 流式超时配置
|
||||
@@ -26,6 +28,7 @@ pub struct StreamingTimeoutConfig {
|
||||
/// - 选中的 Provider 列表(用于故障转移)
|
||||
/// - 请求模型名称
|
||||
/// - 日志标签
|
||||
/// - Session ID(用于日志关联)
|
||||
pub struct RequestContext {
|
||||
/// 请求开始时间
|
||||
pub start_time: Instant,
|
||||
@@ -35,7 +38,7 @@ pub struct RequestContext {
|
||||
pub provider: Provider,
|
||||
/// 完整的 Provider 列表(用于故障转移)
|
||||
providers: Vec<Provider>,
|
||||
/// 请求开始时的“当前供应商”(用于判断是否需要同步 UI/托盘)
|
||||
/// 请求开始时的"当前供应商"(用于判断是否需要同步 UI/托盘)
|
||||
///
|
||||
/// 这里使用本地 settings 的设备级 current provider。
|
||||
/// 代理模式下如果实际使用的 provider 与此不一致,会触发切换以确保 UI 始终准确。
|
||||
@@ -49,6 +52,8 @@ pub struct RequestContext {
|
||||
/// 应用类型(预留,目前通过 app_type_str 使用)
|
||||
#[allow(dead_code)]
|
||||
pub app_type: AppType,
|
||||
/// Session ID(从客户端请求提取或新生成)
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
impl RequestContext {
|
||||
@@ -57,6 +62,7 @@ impl RequestContext {
|
||||
/// # Arguments
|
||||
/// * `state` - 代理服务器状态
|
||||
/// * `body` - 请求体 JSON
|
||||
/// * `headers` - 请求头(用于提取 Session ID)
|
||||
/// * `app_type` - 应用类型
|
||||
/// * `tag` - 日志标签
|
||||
/// * `app_type_str` - 应用类型字符串
|
||||
@@ -66,6 +72,7 @@ impl RequestContext {
|
||||
pub async fn new(
|
||||
state: &ProxyState,
|
||||
body: &serde_json::Value,
|
||||
headers: &HeaderMap,
|
||||
app_type: AppType,
|
||||
tag: &'static str,
|
||||
app_type_str: &'static str,
|
||||
@@ -89,13 +96,31 @@ impl RequestContext {
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// 提取 Session ID
|
||||
let session_result = extract_session_id(headers, body, app_type_str);
|
||||
let session_id = session_result.session_id.clone();
|
||||
|
||||
log::debug!(
|
||||
"[{}] Session ID: {} (from {:?}, client_provided: {})",
|
||||
tag,
|
||||
session_id,
|
||||
session_result.source,
|
||||
session_result.client_provided
|
||||
);
|
||||
|
||||
// 使用共享的 ProviderRouter 选择 Provider(熔断器状态跨请求保持)
|
||||
// 注意:只在这里调用一次,结果传递给 forwarder,避免重复消耗 HalfOpen 名额
|
||||
let providers = state
|
||||
.provider_router
|
||||
.select_providers(app_type_str)
|
||||
.await
|
||||
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
|
||||
.map_err(|e| match e {
|
||||
crate::error::AppError::AllProvidersCircuitOpen => {
|
||||
ProxyError::AllProvidersCircuitOpen
|
||||
}
|
||||
crate::error::AppError::NoProvidersConfigured => ProxyError::NoProvidersConfigured,
|
||||
_ => ProxyError::DatabaseError(e.to_string()),
|
||||
})?;
|
||||
|
||||
let provider = providers
|
||||
.first()
|
||||
@@ -103,11 +128,12 @@ impl RequestContext {
|
||||
.ok_or(ProxyError::NoAvailableProvider)?;
|
||||
|
||||
log::info!(
|
||||
"[{}] Provider: {}, model: {}, failover chain: {} providers",
|
||||
"[{}] Provider: {}, model: {}, failover chain: {} providers, session: {}",
|
||||
tag,
|
||||
provider.name,
|
||||
request_model,
|
||||
providers.len()
|
||||
providers.len(),
|
||||
session_id
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
@@ -120,6 +146,7 @@ impl RequestContext {
|
||||
tag,
|
||||
app_type_str,
|
||||
app_type,
|
||||
session_id,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,18 +175,38 @@ impl RequestContext {
|
||||
/// 创建 RequestForwarder
|
||||
///
|
||||
/// 使用共享的 ProviderRouter,确保熔断器状态跨请求保持
|
||||
///
|
||||
/// 配置生效规则:
|
||||
/// - 故障转移开启:超时配置正常生效(0 表示禁用超时)
|
||||
/// - 故障转移关闭:超时配置不生效(全部传入 0)
|
||||
pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder {
|
||||
let (non_streaming_timeout, first_byte_timeout, idle_timeout) =
|
||||
if self.app_config.auto_failover_enabled {
|
||||
// 故障转移开启:使用配置的值(0 = 禁用超时)
|
||||
(
|
||||
self.app_config.non_streaming_timeout as u64,
|
||||
self.app_config.streaming_first_byte_timeout as u64,
|
||||
self.app_config.streaming_idle_timeout as u64,
|
||||
)
|
||||
} else {
|
||||
// 故障转移关闭:不启用超时配置
|
||||
log::info!(
|
||||
"[{}] Failover disabled, timeout configs are bypassed",
|
||||
self.tag
|
||||
);
|
||||
(0, 0, 0)
|
||||
};
|
||||
|
||||
RequestForwarder::new(
|
||||
state.provider_router.clone(),
|
||||
self.app_config.non_streaming_timeout as u64,
|
||||
self.app_config.max_retries as u8,
|
||||
non_streaming_timeout,
|
||||
state.status.clone(),
|
||||
state.current_providers.clone(),
|
||||
state.failover_manager.clone(),
|
||||
state.app_handle.clone(),
|
||||
self.current_provider_id.clone(),
|
||||
self.app_config.streaming_first_byte_timeout as u64,
|
||||
self.app_config.streaming_idle_timeout as u64,
|
||||
first_byte_timeout,
|
||||
idle_timeout,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -177,11 +224,24 @@ impl RequestContext {
|
||||
}
|
||||
|
||||
/// 获取流式超时配置
|
||||
///
|
||||
/// 配置生效规则:
|
||||
/// - 故障转移开启:返回配置的值(0 表示禁用超时检查)
|
||||
/// - 故障转移关闭:返回 0(禁用超时检查)
|
||||
#[inline]
|
||||
pub fn streaming_timeout_config(&self) -> StreamingTimeoutConfig {
|
||||
StreamingTimeoutConfig {
|
||||
first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64,
|
||||
idle_timeout: self.app_config.streaming_idle_timeout as u64,
|
||||
if self.app_config.auto_failover_enabled {
|
||||
// 故障转移开启:使用配置的值(0 = 禁用超时)
|
||||
StreamingTimeoutConfig {
|
||||
first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64,
|
||||
idle_timeout: self.app_config.streaming_idle_timeout as u64,
|
||||
}
|
||||
} else {
|
||||
// 故障转移关闭:禁用流式超时检查
|
||||
StreamingTimeoutConfig {
|
||||
first_byte_timeout: 0,
|
||||
idle_timeout: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,8 @@ pub async fn handle_messages(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(body): Json<Value>,
|
||||
) -> Result<axum::response::Response, ProxyError> {
|
||||
let mut ctx = RequestContext::new(&state, &body, AppType::Claude, "Claude", "claude").await?;
|
||||
let mut ctx =
|
||||
RequestContext::new(&state, &body, &headers, AppType::Claude, "Claude", "claude").await?;
|
||||
|
||||
let is_stream = body
|
||||
.get("stream")
|
||||
@@ -305,7 +306,8 @@ pub async fn handle_chat_completions(
|
||||
) -> Result<axum::response::Response, ProxyError> {
|
||||
log::info!("[Codex] ====== /v1/chat/completions 请求开始 ======");
|
||||
|
||||
let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?;
|
||||
let mut ctx =
|
||||
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
|
||||
|
||||
let is_stream = body
|
||||
.get("stream")
|
||||
@@ -353,7 +355,8 @@ pub async fn handle_responses(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(body): Json<Value>,
|
||||
) -> Result<axum::response::Response, ProxyError> {
|
||||
let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?;
|
||||
let mut ctx =
|
||||
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
|
||||
|
||||
let is_stream = body
|
||||
.get("stream")
|
||||
@@ -401,7 +404,7 @@ pub async fn handle_gemini(
|
||||
Json(body): Json<Value>,
|
||||
) -> Result<axum::response::Response, ProxyError> {
|
||||
// Gemini 的模型名称在 URI 中
|
||||
let mut ctx = RequestContext::new(&state, &body, AppType::Gemini, "Gemini", "gemini")
|
||||
let mut ctx = RequestContext::new(&state, &body, &headers, AppType::Gemini, "Gemini", "gemini")
|
||||
.await?
|
||||
.with_model_from_uri(&uri);
|
||||
|
||||
@@ -465,7 +468,7 @@ fn log_forward_error(
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
if let Err(e) = logger.log_error_with_context(
|
||||
request_id.clone(),
|
||||
request_id,
|
||||
ctx.provider.id.clone(),
|
||||
ctx.app_type_str.to_string(),
|
||||
ctx.request_model.clone(),
|
||||
@@ -473,7 +476,7 @@ fn log_forward_error(
|
||||
error_message,
|
||||
ctx.latency_ms(),
|
||||
is_streaming,
|
||||
Some(request_id),
|
||||
Some(ctx.session_id.clone()),
|
||||
None,
|
||||
) {
|
||||
log::warn!("记录失败请求日志失败: {e}");
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//!
|
||||
//! 提供本地HTTP代理服务,支持多Provider故障转移和请求透传
|
||||
|
||||
pub mod body_filter;
|
||||
pub mod circuit_breaker;
|
||||
pub mod error;
|
||||
pub mod error_mapper;
|
||||
@@ -33,7 +34,9 @@ pub use provider_router::ProviderRouter;
|
||||
#[allow(unused_imports)]
|
||||
pub use response_handler::{NonStreamHandler, ResponseType, StreamHandler};
|
||||
#[allow(unused_imports)]
|
||||
pub use session::{ClientFormat, ProxySession};
|
||||
pub use session::{
|
||||
extract_session_id, ClientFormat, ProxySession, SessionIdResult, SessionIdSource,
|
||||
};
|
||||
#[allow(unused_imports)]
|
||||
pub use types::{ProxyConfig, ProxyServerInfo, ProxyStatus};
|
||||
|
||||
|
||||
@@ -34,6 +34,8 @@ impl ProviderRouter {
|
||||
/// - 故障转移开启时:完全按照故障转移队列顺序返回,忽略当前供应商设置
|
||||
pub async fn select_providers(&self, app_type: &str) -> Result<Vec<Provider>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
let mut total_providers = 0usize;
|
||||
let mut circuit_open_count = 0usize;
|
||||
|
||||
// 检查该应用的自动故障转移开关是否开启(从 proxy_config 表读取)
|
||||
let auto_failover_enabled = match self.db.get_proxy_config_for_app(app_type).await {
|
||||
@@ -53,18 +55,26 @@ impl ProviderRouter {
|
||||
if auto_failover_enabled {
|
||||
// 故障转移开启:使用 in_failover_queue 标记的供应商,按 sort_index 排序
|
||||
let failover_providers = self.db.get_failover_providers(app_type)?;
|
||||
total_providers = failover_providers.len();
|
||||
log::debug!("[{app_type}] Found {total_providers} failover queue provider(s)");
|
||||
log::info!(
|
||||
"[{}] Failover enabled, using queue order ({} items)",
|
||||
app_type,
|
||||
failover_providers.len()
|
||||
"[{app_type}] Failover enabled, using queue order ({total_providers} items)"
|
||||
);
|
||||
|
||||
for provider in failover_providers {
|
||||
// 检查熔断器状态
|
||||
let circuit_key = format!("{}:{}", app_type, provider.id);
|
||||
let breaker = self.get_or_create_circuit_breaker(&circuit_key).await;
|
||||
let state = breaker.get_state().await;
|
||||
|
||||
if breaker.is_available().await {
|
||||
log::debug!(
|
||||
"[{}] Queue provider available: {} ({}) (state: {:?})",
|
||||
app_type,
|
||||
provider.name,
|
||||
provider.id,
|
||||
state
|
||||
);
|
||||
log::info!(
|
||||
"[{}] Queue provider available: {} ({}) at sort_index {:?}",
|
||||
app_type,
|
||||
@@ -74,10 +84,12 @@ impl ProviderRouter {
|
||||
);
|
||||
result.push(provider);
|
||||
} else {
|
||||
circuit_open_count += 1;
|
||||
log::debug!(
|
||||
"[{}] Queue provider {} circuit breaker open, skipping",
|
||||
"[{}] Queue provider {} circuit breaker open (state: {:?}), skipping",
|
||||
app_type,
|
||||
provider.name
|
||||
provider.name,
|
||||
state
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -94,15 +106,27 @@ impl ProviderRouter {
|
||||
current.name,
|
||||
current.id
|
||||
);
|
||||
total_providers = 1;
|
||||
result.push(current);
|
||||
} else {
|
||||
log::debug!(
|
||||
"[{app_type}] Current provider id {current_id} not found in database"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
log::debug!("[{app_type}] No current provider configured");
|
||||
}
|
||||
}
|
||||
|
||||
if result.is_empty() {
|
||||
return Err(AppError::Config(format!(
|
||||
"No available provider for {app_type} (all circuit breakers open or no providers configured)"
|
||||
)));
|
||||
// 区分两种情况:全部熔断 vs 未配置供应商
|
||||
if total_providers > 0 && circuit_open_count == total_providers {
|
||||
log::warn!("[{app_type}] 所有 {total_providers} 个供应商均已熔断,无可用渠道");
|
||||
return Err(AppError::AllProvidersCircuitOpen);
|
||||
} else {
|
||||
log::warn!("[{app_type}] 未配置供应商或故障转移队列为空");
|
||||
return Err(AppError::NoProvidersConfigured);
|
||||
}
|
||||
}
|
||||
|
||||
log::info!(
|
||||
|
||||
@@ -112,6 +112,19 @@ pub async fn handle_non_streaming(
|
||||
|
||||
spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false);
|
||||
} else {
|
||||
let model = json_value
|
||||
.get("model")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or(&ctx.request_model)
|
||||
.to_string();
|
||||
spawn_log_usage(
|
||||
state,
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
log::debug!(
|
||||
"[{}] 未能解析 usage 信息,跳过记录",
|
||||
parser_config.app_type_str
|
||||
@@ -123,6 +136,14 @@ pub async fn handle_non_streaming(
|
||||
ctx.tag,
|
||||
body_bytes.len()
|
||||
);
|
||||
spawn_log_usage(
|
||||
state,
|
||||
ctx,
|
||||
TokenUsage::default(),
|
||||
&ctx.request_model,
|
||||
status.as_u16(),
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
log::info!("[{}] ====== 请求结束 ======", ctx.tag);
|
||||
@@ -243,6 +264,7 @@ fn create_usage_collector(
|
||||
let start_time = ctx.start_time;
|
||||
let stream_parser = parser_config.stream_parser;
|
||||
let model_extractor = parser_config.model_extractor;
|
||||
let session_id = ctx.session_id.clone();
|
||||
|
||||
SseUsageCollector::new(start_time, move |events, first_token_ms| {
|
||||
if let Some(usage) = stream_parser(&events) {
|
||||
@@ -251,6 +273,7 @@ fn create_usage_collector(
|
||||
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -263,10 +286,32 @@ fn create_usage_collector(
|
||||
first_token_ms,
|
||||
true, // is_streaming
|
||||
status_code,
|
||||
Some(session_id),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
} else {
|
||||
let model = model_extractor(&events, &request_model);
|
||||
let latency_ms = start_time.elapsed().as_millis() as u64;
|
||||
let state = state.clone();
|
||||
let provider_id = provider_id.clone();
|
||||
let session_id = session_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
&state,
|
||||
&provider_id,
|
||||
app_type_str,
|
||||
&model,
|
||||
TokenUsage::default(),
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
true, // is_streaming
|
||||
status_code,
|
||||
Some(session_id),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录");
|
||||
}
|
||||
})
|
||||
@@ -286,6 +331,7 @@ fn spawn_log_usage(
|
||||
let app_type_str = ctx.app_type_str.to_string();
|
||||
let model = model.to_string();
|
||||
let latency_ms = ctx.latency_ms();
|
||||
let session_id = ctx.session_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
log_usage_internal(
|
||||
@@ -298,6 +344,7 @@ fn spawn_log_usage(
|
||||
None,
|
||||
is_streaming,
|
||||
status_code,
|
||||
Some(session_id),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
@@ -315,6 +362,7 @@ async fn log_usage_internal(
|
||||
first_token_ms: Option<u64>,
|
||||
is_streaming: bool,
|
||||
status_code: u16,
|
||||
session_id: Option<String>,
|
||||
) {
|
||||
use super::usage::logger::UsageLogger;
|
||||
|
||||
@@ -338,6 +386,15 @@ async fn log_usage_internal(
|
||||
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
log::debug!(
|
||||
"[{app_type}] 记录请求日志: id={request_id}, provider={provider_id}, model={model}, streaming={is_streaming}, status={status_code}, latency_ms={latency_ms}, first_token_ms={first_token_ms:?}, session={}, input={}, output={}, cache_read={}, cache_creation={}",
|
||||
session_id.as_deref().unwrap_or("none"),
|
||||
usage.input_tokens,
|
||||
usage.output_tokens,
|
||||
usage.cache_read_tokens,
|
||||
usage.cache_creation_tokens
|
||||
);
|
||||
|
||||
if let Err(e) = logger.log_with_calculation(
|
||||
request_id,
|
||||
provider_id.to_string(),
|
||||
@@ -348,7 +405,7 @@ async fn log_usage_internal(
|
||||
latency_ms,
|
||||
first_token_ms,
|
||||
status_code,
|
||||
None,
|
||||
session_id,
|
||||
None, // provider_type
|
||||
is_streaming,
|
||||
) {
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
//! Proxy Session - 请求会话管理
|
||||
//!
|
||||
//! 为每个代理请求创建会话上下文,在整个请求生命周期中跟踪状态和元数据。
|
||||
//!
|
||||
//! ## Session ID 提取
|
||||
//!
|
||||
//! 支持从客户端请求中提取 Session ID,用于关联同一对话的多个请求:
|
||||
//! - Claude: 从 `metadata.user_id` (格式: `user_xxx_session_yyy`) 或 `metadata.session_id` 提取
|
||||
//! - Codex: 从 `previous_response_id` 或 headers 中的 `session_id` 提取
|
||||
//! - 其他: 生成新的 UUID
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -176,6 +184,179 @@ impl ProxySession {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session ID 提取器
|
||||
// ============================================================================
|
||||
|
||||
/// Where a session ID came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionIdSource {
    /// Parsed out of `metadata.user_id` (Claude).
    MetadataUserId,
    /// Taken from `metadata.session_id`.
    MetadataSessionId,
    /// Taken from a request header (Codex).
    Header,
    /// Taken from `previous_response_id` (Codex).
    PreviousResponseId,
    /// Freshly generated because the request carried no usable ID.
    Generated,
}
|
||||
|
||||
/// Session ID 提取结果
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionIdResult {
|
||||
/// 提取或生成的 Session ID
|
||||
pub session_id: String,
|
||||
/// Session ID 来源
|
||||
pub source: SessionIdSource,
|
||||
/// 是否为客户端提供的 ID(非新生成)
|
||||
pub client_provided: bool,
|
||||
}
|
||||
|
||||
/// 从请求中提取或生成 Session ID
|
||||
///
|
||||
/// 轻量化实现,仅提取 session_id 用于日志记录,不做复杂的 Session 管理。
|
||||
///
|
||||
/// ## 提取优先级
|
||||
///
|
||||
/// ### Claude 请求
|
||||
/// 1. `metadata.user_id` (格式: `user_xxx_session_yyy`) → 提取 `yyy` 部分
|
||||
/// 2. `metadata.session_id` → 直接使用
|
||||
/// 3. 生成新 UUID
|
||||
///
|
||||
/// ### Codex 请求
|
||||
/// 1. Headers: `session_id` 或 `x-session-id`
|
||||
/// 2. `metadata.session_id`
|
||||
/// 3. `previous_response_id` (对话延续)
|
||||
/// 4. 生成新 UUID
|
||||
///
|
||||
/// ## 示例
|
||||
///
|
||||
/// ```ignore
|
||||
/// let result = extract_session_id(&headers, &body, "claude");
|
||||
/// println!("Session ID: {} (from {:?})", result.session_id, result.source);
|
||||
/// ```
|
||||
pub fn extract_session_id(
|
||||
headers: &HeaderMap,
|
||||
body: &serde_json::Value,
|
||||
client_format: &str,
|
||||
) -> SessionIdResult {
|
||||
// Codex 请求特殊处理
|
||||
if client_format == "codex" || client_format == "openai" {
|
||||
if let Some(result) = extract_codex_session(headers, body) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Claude 请求:从 metadata 提取
|
||||
if let Some(result) = extract_from_metadata(body) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// 兜底:生成新 Session ID
|
||||
generate_new_session_id()
|
||||
}
|
||||
|
||||
/// 提取 Codex Session ID
|
||||
fn extract_codex_session(headers: &HeaderMap, body: &serde_json::Value) -> Option<SessionIdResult> {
|
||||
// 1. 从 headers 提取
|
||||
for header_name in &["session_id", "x-session-id"] {
|
||||
if let Some(value) = headers.get(*header_name) {
|
||||
if let Ok(session_id) = value.to_str() {
|
||||
// Codex Session ID 通常较长(UUID 格式)
|
||||
if session_id.len() > 20 {
|
||||
return Some(SessionIdResult {
|
||||
session_id: format!("codex_{session_id}"),
|
||||
source: SessionIdSource::Header,
|
||||
client_provided: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 从 body.metadata.session_id 提取
|
||||
if let Some(session_id) = body
|
||||
.get("metadata")
|
||||
.and_then(|m| m.get("session_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
if session_id.len() > 10 {
|
||||
return Some(SessionIdResult {
|
||||
session_id: format!("codex_{session_id}"),
|
||||
source: SessionIdSource::MetadataSessionId,
|
||||
client_provided: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 从 previous_response_id 提取(对话延续)
|
||||
if let Some(prev_id) = body.get("previous_response_id").and_then(|v| v.as_str()) {
|
||||
if prev_id.len() > 10 {
|
||||
return Some(SessionIdResult {
|
||||
session_id: format!("codex_{prev_id}"),
|
||||
source: SessionIdSource::PreviousResponseId,
|
||||
client_provided: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// 从 metadata 提取 Session ID (Claude)
|
||||
fn extract_from_metadata(body: &serde_json::Value) -> Option<SessionIdResult> {
|
||||
let metadata = body.get("metadata")?;
|
||||
|
||||
// 1. 从 metadata.user_id 提取(格式: user_xxx_session_yyy)
|
||||
if let Some(user_id) = metadata.get("user_id").and_then(|v| v.as_str()) {
|
||||
if let Some(session_id) = parse_session_from_user_id(user_id) {
|
||||
return Some(SessionIdResult {
|
||||
session_id,
|
||||
source: SessionIdSource::MetadataUserId,
|
||||
client_provided: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 直接从 metadata.session_id 提取
|
||||
if let Some(session_id) = metadata.get("session_id").and_then(|v| v.as_str()) {
|
||||
if !session_id.is_empty() {
|
||||
return Some(SessionIdResult {
|
||||
session_id: session_id.to_string(),
|
||||
source: SessionIdSource::MetadataSessionId,
|
||||
client_provided: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Parses the session ID out of a `user_id` string.
///
/// Expected format: `user_identifier_session_actual_session_id`. Everything after
/// the *first* occurrence of `"_session_"` is returned as the session ID.
///
/// Returns `None` when the marker is absent or when nothing follows it.
fn parse_session_from_user_id(user_id: &str) -> Option<String> {
    const SESSION_MARKER: &str = "_session_";

    // `split_once` splits at the first occurrence of the marker, so the tail is
    // derived from the marker itself instead of a hand-maintained byte offset.
    user_id
        .split_once(SESSION_MARKER)
        .map(|(_, session_id)| session_id)
        .filter(|session_id| !session_id.is_empty())
        .map(str::to_owned)
}
|
||||
|
||||
/// 生成新的 Session ID
|
||||
fn generate_new_session_id() -> SessionIdResult {
|
||||
SessionIdResult {
|
||||
session_id: Uuid::new_v4().to_string(),
|
||||
source: SessionIdSource::Generated,
|
||||
client_provided: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -295,4 +476,92 @@ mod tests {
|
||||
assert_eq!(ClientFormat::GeminiCli.as_str(), "gemini_cli");
|
||||
assert_eq!(ClientFormat::Unknown.as_str(), "unknown");
|
||||
}
|
||||
|
||||
// ========== Session ID 提取测试 ==========
|
||||
|
||||
#[test]
|
||||
fn test_extract_session_from_claude_metadata_user_id() {
|
||||
let headers = HeaderMap::new();
|
||||
let body = json!({
|
||||
"model": "claude-3-5-sonnet",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"metadata": {
|
||||
"user_id": "user_john_doe_session_abc123def456"
|
||||
}
|
||||
});
|
||||
|
||||
let result = extract_session_id(&headers, &body, "claude");
|
||||
|
||||
assert_eq!(result.session_id, "abc123def456");
|
||||
assert_eq!(result.source, SessionIdSource::MetadataUserId);
|
||||
assert!(result.client_provided);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_session_from_claude_metadata_session_id() {
|
||||
let headers = HeaderMap::new();
|
||||
let body = json!({
|
||||
"model": "claude-3-5-sonnet",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"metadata": {
|
||||
"session_id": "my-session-123"
|
||||
}
|
||||
});
|
||||
|
||||
let result = extract_session_id(&headers, &body, "claude");
|
||||
|
||||
assert_eq!(result.session_id, "my-session-123");
|
||||
assert_eq!(result.source, SessionIdSource::MetadataSessionId);
|
||||
assert!(result.client_provided);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_session_from_codex_previous_response_id() {
|
||||
let headers = HeaderMap::new();
|
||||
let body = json!({
|
||||
"input": "Write a function",
|
||||
"previous_response_id": "resp_abc123def456789"
|
||||
});
|
||||
|
||||
let result = extract_session_id(&headers, &body, "codex");
|
||||
|
||||
assert_eq!(result.session_id, "codex_resp_abc123def456789");
|
||||
assert_eq!(result.source, SessionIdSource::PreviousResponseId);
|
||||
assert!(result.client_provided);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_session_generates_new_when_not_found() {
|
||||
let headers = HeaderMap::new();
|
||||
let body = json!({
|
||||
"model": "claude-3-5-sonnet",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
});
|
||||
|
||||
let result = extract_session_id(&headers, &body, "claude");
|
||||
|
||||
assert!(!result.session_id.is_empty());
|
||||
assert_eq!(result.source, SessionIdSource::Generated);
|
||||
assert!(!result.client_provided);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_session_from_user_id() {
|
||||
assert_eq!(
|
||||
parse_session_from_user_id("user_john_session_abc123"),
|
||||
Some("abc123".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
parse_session_from_user_id("my_app_session_xyz789"),
|
||||
Some("xyz789".to_string())
|
||||
);
|
||||
// 注意: "_session_" 是分隔符,所以下面的字符串会匹配
|
||||
assert_eq!(
|
||||
parse_session_from_user_id("no_session_marker"),
|
||||
Some("marker".to_string())
|
||||
);
|
||||
// 没有 "_session_" 分隔符的情况
|
||||
assert_eq!(parse_session_from_user_id("user_john_abc123"), None);
|
||||
assert_eq!(parse_session_from_user_id("_session_"), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,11 @@ impl CostCalculator {
|
||||
/// - `usage`: Token 使用量
|
||||
/// - `pricing`: 模型定价
|
||||
/// - `cost_multiplier`: 成本倍数 (provider 自定义)
|
||||
///
|
||||
/// # 计算逻辑
|
||||
/// - input_cost: (input_tokens - cache_read_tokens) × 输入价格
|
||||
/// - cache_read_cost: cache_read_tokens × 缓存读取价格
|
||||
/// - 这样避免缓存部分被重复计费
|
||||
pub fn calculate(
|
||||
usage: &TokenUsage,
|
||||
pricing: &ModelPricing,
|
||||
@@ -42,7 +47,10 @@ impl CostCalculator {
|
||||
) -> CostBreakdown {
|
||||
let million = Decimal::from(1_000_000);
|
||||
|
||||
let input_cost = Decimal::from(usage.input_tokens) * pricing.input_cost_per_million
|
||||
// 计算实际需要按输入价格计费的 token 数(减去缓存命中部分)
|
||||
let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens);
|
||||
|
||||
let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million
|
||||
/ million
|
||||
* cost_multiplier;
|
||||
let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million
|
||||
@@ -113,8 +121,8 @@ mod tests {
|
||||
|
||||
let cost = CostCalculator::calculate(&usage, &pricing, multiplier);
|
||||
|
||||
// input: 1000 * 3.0 / 1M = 0.003
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap());
|
||||
// input: (1000 - 200) * 3.0 / 1M = 0.0024 (只计算非缓存部分)
|
||||
assert_eq!(cost.input_cost, Decimal::from_str("0.0024").unwrap());
|
||||
// output: 500 * 15.0 / 1M = 0.0075
|
||||
assert_eq!(cost.output_cost, Decimal::from_str("0.0075").unwrap());
|
||||
// cache_read: 200 * 0.3 / 1M = 0.00006
|
||||
@@ -124,8 +132,8 @@ mod tests {
|
||||
cost.cache_creation_cost,
|
||||
Decimal::from_str("0.000375").unwrap()
|
||||
);
|
||||
// total: 0.003 + 0.0075 + 0.00006 + 0.000375 = 0.010935
|
||||
assert_eq!(cost.total_cost, Decimal::from_str("0.010935").unwrap());
|
||||
// total: 0.0024 + 0.0075 + 0.00006 + 0.000375 = 0.010335
|
||||
assert_eq!(cost.total_cost, Decimal::from_str("0.010335").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -163,13 +163,21 @@ impl TokenUsage {
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let cached_tokens = usage
|
||||
.get("cache_read_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.or_else(|| {
|
||||
usage
|
||||
.get("input_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.and_then(|v| v.as_u64())
|
||||
})
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
Some(Self {
|
||||
input_tokens: input_tokens? as u32,
|
||||
output_tokens: output_tokens? as u32,
|
||||
cache_read_tokens: usage
|
||||
.get("cache_read_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32,
|
||||
cache_read_tokens: cached_tokens,
|
||||
cache_creation_tokens: usage
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
@@ -188,16 +196,27 @@ impl TokenUsage {
|
||||
let input_tokens = usage.get("input_tokens")?.as_u64()? as u32;
|
||||
let output_tokens = usage.get("output_tokens")?.as_u64()? as u32;
|
||||
|
||||
// 获取 cached_tokens (可能在 input_tokens_details 中)
|
||||
// 获取 cached_tokens (可能在 cache_read_input_tokens 或 input_tokens_details 中)
|
||||
let cached_tokens = usage
|
||||
.get("input_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.get("cache_read_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.or_else(|| {
|
||||
usage
|
||||
.get("input_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.and_then(|v| v.as_u64())
|
||||
})
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// 调整 input_tokens: 减去 cached_tokens
|
||||
let adjusted_input = input_tokens.saturating_sub(cached_tokens);
|
||||
|
||||
// 提取响应中的模型名称
|
||||
let model = body
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Some(Self {
|
||||
input_tokens: adjusted_input,
|
||||
output_tokens,
|
||||
@@ -206,7 +225,7 @@ impl TokenUsage {
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32,
|
||||
model: None,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -220,7 +239,7 @@ impl TokenUsage {
|
||||
if event_type == "response.completed" {
|
||||
if let Some(response) = event.get("response") {
|
||||
log::debug!("[Codex] 找到 response.completed 事件,解析 usage");
|
||||
return Self::from_codex_response(response);
|
||||
return Self::from_codex_response_adjusted(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -229,6 +248,51 @@ impl TokenUsage {
|
||||
None
|
||||
}
|
||||
|
||||
/// 智能 Codex 响应解析 - 自动检测 OpenAI 或 Codex 格式
|
||||
///
|
||||
/// Codex 支持两种 API 格式:
|
||||
/// - `/v1/responses`: 使用 input_tokens/output_tokens
|
||||
/// - `/v1/chat/completions`: 使用 prompt_tokens/completion_tokens (OpenAI 格式)
|
||||
///
|
||||
/// 注意:记录原始 input_tokens,费用计算时再减去 cached_tokens
|
||||
pub fn from_codex_response_auto(body: &Value) -> Option<Self> {
|
||||
let usage = body.get("usage")?;
|
||||
|
||||
// 检测格式:OpenAI 使用 prompt_tokens,Codex 使用 input_tokens
|
||||
if usage.get("prompt_tokens").is_some() {
|
||||
log::debug!("[Codex] 检测到 OpenAI 格式 (prompt_tokens)");
|
||||
Self::from_openai_response(body)
|
||||
} else if usage.get("input_tokens").is_some() {
|
||||
log::debug!("[Codex] 检测到 Codex 格式 (input_tokens)");
|
||||
// 使用非调整版本,记录原始 input_tokens
|
||||
Self::from_codex_response(body)
|
||||
} else {
|
||||
log::debug!("[Codex] 无法识别响应格式,usage: {usage:?}");
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// 智能 Codex 流式响应解析 - 自动检测 OpenAI 或 Codex 格式
|
||||
pub fn from_codex_stream_events_auto(events: &[Value]) -> Option<Self> {
|
||||
log::debug!("[Codex] 智能解析流式事件,共 {} 个事件", events.len());
|
||||
|
||||
// 先尝试 Codex Responses API 格式 (response.completed 事件)
|
||||
for event in events {
|
||||
if let Some(event_type) = event.get("type").and_then(|v| v.as_str()) {
|
||||
if event_type == "response.completed" {
|
||||
if let Some(response) = event.get("response") {
|
||||
log::debug!("[Codex] 找到 response.completed 事件");
|
||||
return Self::from_codex_response_auto(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 OpenAI Chat Completions 格式 (最后一个 chunk 包含 usage)
|
||||
log::debug!("[Codex] 尝试 OpenAI 流式格式");
|
||||
Self::from_openai_stream_events(events)
|
||||
}
|
||||
|
||||
/// 从 OpenAI Chat Completions API 响应解析 (prompt_tokens, completion_tokens)
|
||||
pub fn from_openai_response(body: &Value) -> Option<Self> {
|
||||
let usage = body.get("usage")?;
|
||||
@@ -284,9 +348,16 @@ impl TokenUsage {
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let prompt_tokens = usage.get("promptTokenCount")?.as_u64()? as u32;
|
||||
let total_tokens = usage.get("totalTokenCount")?.as_u64()? as u32;
|
||||
|
||||
// 输出 tokens = 总 tokens - 输入 tokens
|
||||
// 这包含了 candidatesTokenCount + thoughtsTokenCount
|
||||
let output_tokens = total_tokens.saturating_sub(prompt_tokens);
|
||||
|
||||
Some(Self {
|
||||
input_tokens: usage.get("promptTokenCount")?.as_u64()? as u32,
|
||||
output_tokens: usage.get("candidatesTokenCount")?.as_u64()? as u32,
|
||||
input_tokens: prompt_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens: usage
|
||||
.get("cachedContentTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
@@ -300,20 +371,25 @@ impl TokenUsage {
|
||||
#[allow(dead_code)]
|
||||
pub fn from_gemini_stream_chunks(chunks: &[Value]) -> Option<Self> {
|
||||
let mut total_input = 0u32;
|
||||
let mut total_output = 0u32;
|
||||
let mut total_tokens = 0u32;
|
||||
let mut total_cache_read = 0u32;
|
||||
let mut model: Option<String> = None;
|
||||
|
||||
for chunk in chunks {
|
||||
if let Some(usage) = chunk.get("usageMetadata") {
|
||||
// 输入 tokens (通常在所有 chunk 中保持不变)
|
||||
total_input = usage
|
||||
.get("promptTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
total_output += usage
|
||||
.get("candidatesTokenCount")
|
||||
|
||||
// 总 tokens (包含输入 + 输出 + 思考)
|
||||
total_tokens = usage
|
||||
.get("totalTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// 缓存读取 tokens
|
||||
total_cache_read = usage
|
||||
.get("cachedContentTokenCount")
|
||||
.and_then(|v| v.as_u64())
|
||||
@@ -328,6 +404,9 @@ impl TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
// 输出 tokens = 总 tokens - 输入 tokens
|
||||
let total_output = total_tokens.saturating_sub(total_input);
|
||||
|
||||
if total_input > 0 || total_output > 0 {
|
||||
Some(Self {
|
||||
input_tokens: total_input,
|
||||
@@ -466,15 +545,18 @@ mod tests {
|
||||
let response = json!({
|
||||
"modelVersion": "gemini-3-pro-high",
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 100,
|
||||
"promptTokenCount": 8383,
|
||||
"candidatesTokenCount": 50,
|
||||
"thoughtsTokenCount": 114,
|
||||
"totalTokenCount": 8547,
|
||||
"cachedContentTokenCount": 20
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_gemini_response(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 100);
|
||||
assert_eq!(usage.output_tokens, 50);
|
||||
assert_eq!(usage.input_tokens, 8383);
|
||||
// output_tokens = totalTokenCount - promptTokenCount = 8547 - 8383 = 164
|
||||
assert_eq!(usage.output_tokens, 164);
|
||||
assert_eq!(usage.cache_read_tokens, 20);
|
||||
assert_eq!(usage.cache_creation_tokens, 0);
|
||||
assert_eq!(usage.model, Some("gemini-3-pro-high".to_string()));
|
||||
@@ -486,19 +568,78 @@ mod tests {
|
||||
let response = json!({
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 100,
|
||||
"candidatesTokenCount": 50,
|
||||
"totalTokenCount": 150,
|
||||
"cachedContentTokenCount": 20
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_gemini_response(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 100);
|
||||
// output_tokens = totalTokenCount - promptTokenCount = 150 - 100 = 50
|
||||
assert_eq!(usage.output_tokens, 50);
|
||||
assert_eq!(usage.cache_read_tokens, 20);
|
||||
assert_eq!(usage.cache_creation_tokens, 0);
|
||||
assert_eq!(usage.model, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemini_response_with_thoughts() {
|
||||
// 测试包含 thoughtsTokenCount 的实际响应
|
||||
// 这是用户报告的真实场景
|
||||
let response = json!({
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "",
|
||||
"thoughtSignature": "EvcECvQE..."
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"finishReason": "STOP"
|
||||
}
|
||||
],
|
||||
"modelVersion": "gemini-3-pro-high",
|
||||
"responseId": "yupTafqLDu-PjMcPhrOx4QQ",
|
||||
"usageMetadata": {
|
||||
"candidatesTokenCount": 50,
|
||||
"promptTokenCount": 8383,
|
||||
"thoughtsTokenCount": 114,
|
||||
"totalTokenCount": 8547
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_gemini_response(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 8383);
|
||||
// output_tokens = totalTokenCount - promptTokenCount
|
||||
// = 8547 - 8383 = 164 (包含 candidatesTokenCount 50 + thoughtsTokenCount 114)
|
||||
assert_eq!(usage.output_tokens, 164);
|
||||
assert_eq!(usage.cache_read_tokens, 0);
|
||||
assert_eq!(usage.cache_creation_tokens, 0);
|
||||
assert_eq!(usage.model, Some("gemini-3-pro-high".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_parsing_cached_tokens_in_details() {
|
||||
let response = json!({
|
||||
"usage": {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 500,
|
||||
"input_tokens_details": {
|
||||
"cached_tokens": 300
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_codex_response(&response).unwrap();
|
||||
// 非调整模式:input_tokens 保持原值,但应记录缓存命中
|
||||
assert_eq!(usage.input_tokens, 1000);
|
||||
assert_eq!(usage.output_tokens, 500);
|
||||
assert_eq!(usage.cache_read_tokens, 300);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_adjusted() {
|
||||
let response = json!({
|
||||
@@ -534,6 +675,22 @@ mod tests {
|
||||
assert_eq!(usage.cache_read_tokens, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_adjusted_cache_read_input_tokens() {
|
||||
let response = json!({
|
||||
"usage": {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 500,
|
||||
"cache_read_input_tokens": 200
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_codex_response_adjusted(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 800);
|
||||
assert_eq!(usage.output_tokens, 500);
|
||||
assert_eq!(usage.cache_read_tokens, 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_adjusted_saturating_sub() {
|
||||
// 测试 cached_tokens > input_tokens 的边界情况
|
||||
@@ -615,4 +772,110 @@ mod tests {
|
||||
assert_eq!(usage.cache_read_tokens, 50);
|
||||
assert_eq!(usage.model, Some("claude-sonnet-4-20250514".to_string()));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 智能 Codex 解析测试
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_auto_openai_format() {
|
||||
// OpenAI 格式 (prompt_tokens/completion_tokens)
|
||||
let response = json!({
|
||||
"model": "gpt-4o",
|
||||
"usage": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 500,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 200
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_codex_response_auto(&response).unwrap();
|
||||
assert_eq!(usage.input_tokens, 1000);
|
||||
assert_eq!(usage.output_tokens, 500);
|
||||
assert_eq!(usage.cache_read_tokens, 200);
|
||||
assert_eq!(usage.model, Some("gpt-4o".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_response_auto_codex_format() {
|
||||
// Codex 格式 (input_tokens/output_tokens)
|
||||
let response = json!({
|
||||
"model": "o3",
|
||||
"usage": {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 500,
|
||||
"input_tokens_details": {
|
||||
"cached_tokens": 300
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let usage = TokenUsage::from_codex_response_auto(&response).unwrap();
|
||||
// 记录原始 input_tokens,不调整
|
||||
assert_eq!(usage.input_tokens, 1000);
|
||||
assert_eq!(usage.output_tokens, 500);
|
||||
assert_eq!(usage.cache_read_tokens, 300);
|
||||
assert_eq!(usage.model, Some("o3".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_stream_events_auto_codex_format() {
|
||||
// Codex Responses API 流式格式 (response.completed 事件)
|
||||
let events = vec![
|
||||
json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": "resp_123"
|
||||
}
|
||||
}),
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"model": "o3",
|
||||
"usage": {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 500,
|
||||
"input_tokens_details": {
|
||||
"cached_tokens": 200
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap();
|
||||
// 记录原始 input_tokens,不调整
|
||||
assert_eq!(usage.input_tokens, 1000);
|
||||
assert_eq!(usage.output_tokens, 500);
|
||||
assert_eq!(usage.cache_read_tokens, 200);
|
||||
assert_eq!(usage.model, Some("o3".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codex_stream_events_auto_openai_format() {
|
||||
// OpenAI Chat Completions 流式格式 (最后一个 chunk 包含 usage)
|
||||
let events = vec![
|
||||
json!({
|
||||
"id": "chatcmpl-123",
|
||||
"model": "gpt-4o",
|
||||
"choices": [{"delta": {"content": "Hello"}}]
|
||||
}),
|
||||
json!({
|
||||
"id": "chatcmpl-123",
|
||||
"model": "gpt-4o",
|
||||
"choices": [{"delta": {}}],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap();
|
||||
assert_eq!(usage.input_tokens, 100);
|
||||
assert_eq!(usage.output_tokens, 50);
|
||||
assert_eq!(usage.model, Some("gpt-4o".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,9 +217,12 @@ impl ProviderService {
|
||||
.flatten()
|
||||
.is_some();
|
||||
let is_proxy_running = futures::executor::block_on(state.proxy_service.is_running());
|
||||
let live_taken_over = state
|
||||
.proxy_service
|
||||
.detect_takeover_in_live_config_for_app(&app_type);
|
||||
|
||||
// Hot-switch only when BOTH: this app is taken over AND proxy server is actually running
|
||||
let should_hot_switch = is_app_taken_over && is_proxy_running;
|
||||
let should_hot_switch = (is_app_taken_over || live_taken_over) && is_proxy_running;
|
||||
|
||||
if should_hot_switch {
|
||||
// Proxy takeover mode: hot-switch only, don't write Live config
|
||||
@@ -736,15 +739,15 @@ impl ProviderService {
|
||||
// 删除生成的子供应商
|
||||
if let Some(p) = provider {
|
||||
if p.apps.claude {
|
||||
let claude_id = format!("universal-claude-{}", id);
|
||||
let claude_id = format!("universal-claude-{id}");
|
||||
let _ = state.db.delete_provider("claude", &claude_id);
|
||||
}
|
||||
if p.apps.codex {
|
||||
let codex_id = format!("universal-codex-{}", id);
|
||||
let codex_id = format!("universal-codex-{id}");
|
||||
let _ = state.db.delete_provider("codex", &codex_id);
|
||||
}
|
||||
if p.apps.gemini {
|
||||
let gemini_id = format!("universal-gemini-{}", id);
|
||||
let gemini_id = format!("universal-gemini-{id}");
|
||||
let _ = state.db.delete_provider("gemini", &gemini_id);
|
||||
}
|
||||
}
|
||||
@@ -757,7 +760,7 @@ impl ProviderService {
|
||||
let provider = state
|
||||
.db
|
||||
.get_universal_provider(id)?
|
||||
.ok_or_else(|| AppError::Message(format!("统一供应商 {} 不存在", id)))?;
|
||||
.ok_or_else(|| AppError::Message(format!("统一供应商 {id} 不存在")))?;
|
||||
|
||||
// 同步到 Claude
|
||||
if let Some(mut claude_provider) = provider.to_claude_provider() {
|
||||
@@ -770,7 +773,7 @@ impl ProviderService {
|
||||
state.db.save_provider("claude", &claude_provider)?;
|
||||
} else {
|
||||
// 如果禁用了 Claude,删除对应的子供应商
|
||||
let claude_id = format!("universal-claude-{}", id);
|
||||
let claude_id = format!("universal-claude-{id}");
|
||||
let _ = state.db.delete_provider("claude", &claude_id);
|
||||
}
|
||||
|
||||
@@ -784,7 +787,7 @@ impl ProviderService {
|
||||
}
|
||||
state.db.save_provider("codex", &codex_provider)?;
|
||||
} else {
|
||||
let codex_id = format!("universal-codex-{}", id);
|
||||
let codex_id = format!("universal-codex-{id}");
|
||||
let _ = state.db.delete_provider("codex", &codex_id);
|
||||
}
|
||||
|
||||
@@ -798,7 +801,7 @@ impl ProviderService {
|
||||
}
|
||||
state.db.save_provider("gemini", &gemini_provider)?;
|
||||
} else {
|
||||
let gemini_id = format!("universal-gemini-{}", id);
|
||||
let gemini_id = format!("universal-gemini-{id}");
|
||||
let _ = state.db.delete_provider("gemini", &gemini_id);
|
||||
}
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ impl ProxyService {
|
||||
self.start().await?;
|
||||
}
|
||||
|
||||
// 2) 已接管则直接返回(幂等)
|
||||
// 2) 已接管则直接返回(幂等);但如果缺少备份或占位符残留,需要重建接管
|
||||
let current_config = self
|
||||
.db
|
||||
.get_proxy_config_for_app(app_type_str)
|
||||
@@ -201,7 +201,22 @@ impl ProxyService {
|
||||
.map_err(|e| format!("获取 {app_type_str} 配置失败: {e}"))?;
|
||||
|
||||
if current_config.enabled {
|
||||
return Ok(());
|
||||
let has_backup = match self.db.get_live_backup(app_type_str).await {
|
||||
Ok(v) => v.is_some(),
|
||||
Err(e) => {
|
||||
log::warn!("读取 {app_type_str} 备份失败(将继续重建接管): {e}");
|
||||
false
|
||||
}
|
||||
};
|
||||
let live_taken_over = self.detect_takeover_in_live_config_for_app(&app);
|
||||
|
||||
if has_backup || live_taken_over {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::warn!(
|
||||
"{app_type_str} 标记为已接管,但缺少备份或占位符,正在重新接管并补齐备份"
|
||||
);
|
||||
}
|
||||
|
||||
// 3) 备份 Live 配置(严格:目标 app 不存在则报错)
|
||||
@@ -1063,7 +1078,7 @@ impl ProxyService {
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool {
|
||||
pub fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool {
|
||||
match app_type {
|
||||
AppType::Claude => match self.read_claude_live() {
|
||||
Ok(config) => Self::is_claude_live_taken_over(&config),
|
||||
@@ -1257,10 +1272,8 @@ impl ProxyService {
|
||||
|
||||
/// 检查是否处于 Live 接管模式
|
||||
pub async fn is_takeover_active(&self) -> Result<bool, String> {
|
||||
self.db
|
||||
.is_live_takeover_active()
|
||||
.await
|
||||
.map_err(|e| format!("检查接管状态失败: {e}"))
|
||||
let status = self.get_takeover_status().await?;
|
||||
Ok(status.claude || status.codex || status.gemini)
|
||||
}
|
||||
|
||||
/// 从异常退出中恢复(启动时调用)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
use crate::database::{lock_conn, Database};
|
||||
use crate::error::AppError;
|
||||
use chrono::{Duration, Local, TimeZone};
|
||||
use chrono::{Local, TimeZone};
|
||||
use rusqlite::{params, Connection, OptionalExtension};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -181,145 +181,114 @@ impl Database {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 获取每日趋势
|
||||
pub fn get_daily_trends(&self, days: u32) -> Result<Vec<DailyStats>, AppError> {
|
||||
/// 获取每日趋势(滑动窗口,<=24h 按小时,>24h 按天,窗口与汇总一致)
|
||||
pub fn get_daily_trends(
|
||||
&self,
|
||||
start_date: Option<i64>,
|
||||
end_date: Option<i64>,
|
||||
) -> Result<Vec<DailyStats>, AppError> {
|
||||
let conn = lock_conn!(self.conn);
|
||||
|
||||
if days <= 1 {
|
||||
let today = Local::now().date_naive();
|
||||
let start_of_today = today.and_hms_opt(0, 0, 0).unwrap();
|
||||
// 使用 earliest() 处理 DST 切换时的歧义时间,fallback 到当前时间减一天
|
||||
let start_ts = Local
|
||||
.from_local_datetime(&start_of_today)
|
||||
.earliest()
|
||||
.unwrap_or_else(|| Local::now() - Duration::days(1))
|
||||
.timestamp();
|
||||
let end_ts = end_date.unwrap_or_else(|| Local::now().timestamp());
|
||||
let mut start_ts = start_date.unwrap_or_else(|| end_ts - 24 * 60 * 60);
|
||||
|
||||
let sql = "SELECT
|
||||
strftime('%Y-%m-%dT%H:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
|
||||
FROM proxy_request_logs
|
||||
WHERE created_at >= ?
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket ASC";
|
||||
|
||||
let mut stmt = conn.prepare(sql)?;
|
||||
let rows = stmt.query_map([start_ts], |row| {
|
||||
Ok(DailyStats {
|
||||
date: row.get(0)?,
|
||||
request_count: row.get::<_, i64>(1)? as u64,
|
||||
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
|
||||
total_tokens: row.get::<_, i64>(3)? as u64,
|
||||
total_input_tokens: row.get::<_, i64>(4)? as u64,
|
||||
total_output_tokens: row.get::<_, i64>(5)? as u64,
|
||||
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
|
||||
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut buckets: HashMap<String, DailyStats> = HashMap::new();
|
||||
for row in rows {
|
||||
let stat = row?;
|
||||
buckets.insert(stat.date.clone(), stat);
|
||||
}
|
||||
|
||||
let mut stats = Vec::new();
|
||||
for hour in 0..24 {
|
||||
let bucket = today
|
||||
.and_hms_opt(hour, 0, 0)
|
||||
.unwrap()
|
||||
.format("%Y-%m-%dT%H:00:00")
|
||||
.to_string();
|
||||
|
||||
if let Some(stat) = buckets.remove(&bucket) {
|
||||
stats.push(stat);
|
||||
} else {
|
||||
stats.push(DailyStats {
|
||||
date: bucket,
|
||||
request_count: 0,
|
||||
total_cost: "0.000000".to_string(),
|
||||
total_tokens: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
total_cache_creation_tokens: 0,
|
||||
total_cache_read_tokens: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(stats)
|
||||
} else {
|
||||
let today = Local::now().date_naive();
|
||||
let start_day = today - Duration::days((days.saturating_sub(1)) as i64);
|
||||
let start_of_window = start_day.and_hms_opt(0, 0, 0).unwrap();
|
||||
// 使用 earliest() 处理 DST 切换时的歧义时间,fallback 到当前时间减 days 天
|
||||
let start_ts = Local
|
||||
.from_local_datetime(&start_of_window)
|
||||
.earliest()
|
||||
.unwrap_or_else(|| Local::now() - Duration::days(days as i64))
|
||||
.timestamp();
|
||||
|
||||
let sql = "SELECT
|
||||
strftime('%Y-%m-%dT00:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
|
||||
FROM proxy_request_logs
|
||||
WHERE created_at >= ?
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket ASC";
|
||||
|
||||
let mut stmt = conn.prepare(sql)?;
|
||||
let rows = stmt.query_map([start_ts], |row| {
|
||||
Ok(DailyStats {
|
||||
date: row.get(0)?,
|
||||
request_count: row.get::<_, i64>(1)? as u64,
|
||||
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
|
||||
total_tokens: row.get::<_, i64>(3)? as u64,
|
||||
total_input_tokens: row.get::<_, i64>(4)? as u64,
|
||||
total_output_tokens: row.get::<_, i64>(5)? as u64,
|
||||
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
|
||||
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut map = HashMap::new();
|
||||
for row in rows {
|
||||
let stat = row?;
|
||||
map.insert(stat.date.clone(), stat);
|
||||
}
|
||||
|
||||
let mut stats = Vec::new();
|
||||
|
||||
for i in 0..days {
|
||||
let day = start_day + Duration::days(i as i64);
|
||||
let key = day.format("%Y-%m-%dT00:00:00").to_string();
|
||||
if let Some(stat) = map.remove(&key) {
|
||||
stats.push(stat);
|
||||
} else {
|
||||
stats.push(DailyStats {
|
||||
date: key,
|
||||
request_count: 0,
|
||||
total_cost: "0.000000".to_string(),
|
||||
total_tokens: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
total_cache_creation_tokens: 0,
|
||||
total_cache_read_tokens: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(stats)
|
||||
if start_ts >= end_ts {
|
||||
start_ts = end_ts - 24 * 60 * 60;
|
||||
}
|
||||
|
||||
let duration = end_ts - start_ts;
|
||||
let bucket_seconds: i64 = if duration <= 24 * 60 * 60 {
|
||||
60 * 60
|
||||
} else {
|
||||
24 * 60 * 60
|
||||
};
|
||||
let mut bucket_count: i64 = if duration <= 0 {
|
||||
1
|
||||
} else {
|
||||
((duration as f64) / bucket_seconds as f64).ceil() as i64
|
||||
};
|
||||
|
||||
// 固定 24 小时窗口为 24 个小时桶,避免浮点误差
|
||||
if bucket_seconds == 60 * 60 {
|
||||
bucket_count = 24;
|
||||
}
|
||||
|
||||
if bucket_count < 1 {
|
||||
bucket_count = 1;
|
||||
}
|
||||
|
||||
let sql = "
|
||||
SELECT
|
||||
CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
|
||||
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
|
||||
FROM proxy_request_logs
|
||||
WHERE created_at >= ?1 AND created_at <= ?2
|
||||
GROUP BY bucket_idx
|
||||
ORDER BY bucket_idx ASC";
|
||||
|
||||
let mut stmt = conn.prepare(sql)?;
|
||||
let rows = stmt.query_map(params![start_ts, end_ts, bucket_seconds], |row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)?,
|
||||
DailyStats {
|
||||
date: String::new(),
|
||||
request_count: row.get::<_, i64>(1)? as u64,
|
||||
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
|
||||
total_tokens: row.get::<_, i64>(3)? as u64,
|
||||
total_input_tokens: row.get::<_, i64>(4)? as u64,
|
||||
total_output_tokens: row.get::<_, i64>(5)? as u64,
|
||||
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
|
||||
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
|
||||
},
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut map: HashMap<i64, DailyStats> = HashMap::new();
|
||||
for row in rows {
|
||||
let (mut bucket_idx, stat) = row?;
|
||||
if bucket_idx < 0 {
|
||||
continue;
|
||||
}
|
||||
if bucket_idx >= bucket_count {
|
||||
bucket_idx = bucket_count - 1;
|
||||
}
|
||||
map.insert(bucket_idx, stat);
|
||||
}
|
||||
|
||||
let mut stats = Vec::with_capacity(bucket_count as usize);
|
||||
for i in 0..bucket_count {
|
||||
let bucket_start_ts = start_ts + i * bucket_seconds;
|
||||
let bucket_start = Local
|
||||
.timestamp_opt(bucket_start_ts, 0)
|
||||
.single()
|
||||
.unwrap_or_else(Local::now);
|
||||
|
||||
let date = bucket_start.format("%Y-%m-%dT%H:%M:%S").to_string();
|
||||
|
||||
if let Some(mut stat) = map.remove(&i) {
|
||||
stat.date = date;
|
||||
stats.push(stat);
|
||||
} else {
|
||||
stats.push(DailyStats {
|
||||
date,
|
||||
request_count: 0,
|
||||
total_cost: "0.000000".to_string(),
|
||||
total_tokens: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
total_cache_creation_tokens: 0,
|
||||
total_cache_read_tokens: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// 获取 Provider 统计
|
||||
@@ -829,89 +798,46 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
/// 标准化模型名称:去除供应商前缀并将点号替换为短横线
|
||||
/// 例如:anthropic/claude-haiku-4.5 → claude-haiku-4-5
|
||||
fn normalize_model_id(model_id: &str) -> String {
|
||||
// 1. 去除供应商前缀(如 anthropic/、openai/)
|
||||
let stripped = if let Some(pos) = model_id.find('/') {
|
||||
&model_id[pos + 1..]
|
||||
} else {
|
||||
model_id
|
||||
};
|
||||
// 2. 将点号替换为短横线(如 claude-haiku-4.5 → claude-haiku-4-5)
|
||||
stripped.replace('.', "-")
|
||||
}
|
||||
|
||||
pub(crate) fn find_model_pricing_row(
|
||||
conn: &Connection,
|
||||
model_id: &str,
|
||||
) -> Result<Option<(String, String, String, String)>, AppError> {
|
||||
// 0. 标准化模型名称(去除前缀 + 点号转短横线)
|
||||
// 例如:anthropic/claude-haiku-4.5 → claude-haiku-4-5
|
||||
let normalized = normalize_model_id(model_id);
|
||||
// 1) 去除供应商前缀(/ 之前)与冒号后缀(: 之后),例如 moonshotai/kimi-k2-0905:exa → kimi-k2-0905
|
||||
let without_prefix = model_id
|
||||
.rsplit_once('/')
|
||||
.map(|(_, rest)| rest)
|
||||
.unwrap_or(model_id);
|
||||
let cleaned = without_prefix
|
||||
.split(':')
|
||||
.next()
|
||||
.map(str::trim)
|
||||
.unwrap_or(without_prefix);
|
||||
|
||||
// 1. 精确匹配(先尝试原始名称,再尝试标准化后的名称)
|
||||
for id in [model_id, normalized.as_str()] {
|
||||
let exact = conn
|
||||
.query_row(
|
||||
"SELECT input_cost_per_million, output_cost_per_million,
|
||||
cache_read_cost_per_million, cache_creation_cost_per_million
|
||||
FROM model_pricing
|
||||
WHERE model_id = ?1",
|
||||
[id],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
|
||||
// 2) 精确匹配清洗后的名称
|
||||
let exact = conn
|
||||
.query_row(
|
||||
"SELECT input_cost_per_million, output_cost_per_million,
|
||||
cache_read_cost_per_million, cache_creation_cost_per_million
|
||||
FROM model_pricing
|
||||
WHERE model_id = ?1",
|
||||
[cleaned],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
|
||||
|
||||
if exact.is_some() {
|
||||
if id != model_id {
|
||||
log::info!("模型 {model_id} 标准化后精确匹配到: {id}");
|
||||
}
|
||||
return Ok(exact);
|
||||
}
|
||||
if exact.is_none() {
|
||||
log::warn!("模型 {model_id}(清洗后: {cleaned})未找到定价信息,成本将记录为 0");
|
||||
}
|
||||
|
||||
// 2. 逐步删除后缀匹配(claude-haiku-4-5-20250929 → claude-haiku-4-5 → claude-haiku-4 → claude-haiku)
|
||||
// 使用标准化后的名称进行后缀匹配
|
||||
let mut current = normalized;
|
||||
while let Some(pos) = current.rfind('-') {
|
||||
current = current[..pos].to_string();
|
||||
|
||||
let result = conn
|
||||
.query_row(
|
||||
"SELECT input_cost_per_million, output_cost_per_million,
|
||||
cache_read_cost_per_million, cache_creation_cost_per_million
|
||||
FROM model_pricing
|
||||
WHERE model_id = ?1",
|
||||
[&current],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
|
||||
|
||||
if result.is_some() {
|
||||
log::info!("模型 {model_id} 通过删除后缀匹配到: {current}");
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
|
||||
log::warn!("模型 {model_id} 未找到定价信息,成本将记录为 0");
|
||||
Ok(None)
|
||||
Ok(exact)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -991,54 +917,39 @@ mod tests {
|
||||
let db = Database::memory()?;
|
||||
let conn = lock_conn!(db.conn);
|
||||
|
||||
// 测试精确匹配
|
||||
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5")?;
|
||||
assert!(result.is_some(), "应该能精确匹配 claude-sonnet-4-5");
|
||||
// 准备额外定价数据,覆盖前缀/后缀清洗场景
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO model_pricing (
|
||||
model_id, display_name, input_cost_per_million, output_cost_per_million,
|
||||
cache_read_cost_per_million, cache_creation_cost_per_million
|
||||
) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
params![
|
||||
"claude-haiku-4.5",
|
||||
"Claude Haiku 4.5",
|
||||
"1.0",
|
||||
"2.0",
|
||||
"0.0",
|
||||
"0.0"
|
||||
],
|
||||
)?;
|
||||
|
||||
// 测试带供应商前缀的模型名称(anthropic/claude-haiku-4.5 → claude-haiku-4-5)
|
||||
let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能匹配带前缀的模型 anthropic/claude-haiku-4.5"
|
||||
);
|
||||
|
||||
// 测试带供应商前缀 + 点号的模型名称
|
||||
let result = find_model_pricing_row(&conn, "anthropic/claude-sonnet-4.5")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能匹配带前缀的模型 anthropic/claude-sonnet-4.5"
|
||||
);
|
||||
|
||||
// 测试逐步删除后缀匹配 - 日期后缀
|
||||
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20241022")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能通过删除后缀匹配 claude-sonnet-4-5-20241022"
|
||||
);
|
||||
|
||||
// 测试逐步删除后缀匹配 - 多个后缀
|
||||
let result = find_model_pricing_row(&conn, "claude-haiku-4-5-20240229-preview")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能通过删除后缀匹配 claude-haiku-4-5-20240229-preview"
|
||||
);
|
||||
|
||||
// 测试 GPT 模型
|
||||
let result = find_model_pricing_row(&conn, "gpt-5-2024-11-20")?;
|
||||
assert!(result.is_some(), "应该能通过删除后缀匹配 gpt-5-2024-11-20");
|
||||
|
||||
// 测试 Gemini 模型
|
||||
let result = find_model_pricing_row(&conn, "gemini-2.5-flash-exp")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能通过删除后缀匹配 gemini-2.5-flash-exp"
|
||||
);
|
||||
|
||||
// 测试 claude-sonnet-4-5 命名格式
|
||||
// 测试精确匹配(seed_model_pricing 已预置 claude-sonnet-4-5-20250929)
|
||||
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20250929")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"应该能通过删除后缀匹配 claude-sonnet-4-5-20250929"
|
||||
"应该能精确匹配 claude-sonnet-4-5-20250929"
|
||||
);
|
||||
|
||||
// 清洗:去除前缀和冒号后缀
|
||||
let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"带前缀的模型 anthropic/claude-haiku-4.5 应能匹配到 claude-haiku-4.5"
|
||||
);
|
||||
let result = find_model_pricing_row(&conn, "moonshotai/kimi-k2-0905:exa")?;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"带前缀+冒号后缀的模型应清洗后匹配到 kimi-k2-0905"
|
||||
);
|
||||
|
||||
// 测试不存在的模型
|
||||
|
||||
@@ -142,12 +142,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="0"
|
||||
max="10"
|
||||
value={formData.maxRetries}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
maxRetries: parseInt(e.target.value) || 3,
|
||||
})
|
||||
}
|
||||
maxRetries: isNaN(val) ? 0 : val,
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -168,12 +169,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="1"
|
||||
max="20"
|
||||
value={formData.circuitFailureThreshold}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
circuitFailureThreshold: parseInt(e.target.value) || 5,
|
||||
})
|
||||
}
|
||||
circuitFailureThreshold: isNaN(val) ? 1 : Math.max(1, val),
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -206,12 +208,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="0"
|
||||
max="180"
|
||||
value={formData.streamingFirstByteTimeout}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
streamingFirstByteTimeout: parseInt(e.target.value) || 30,
|
||||
})
|
||||
}
|
||||
streamingFirstByteTimeout: isNaN(val) ? 0 : val,
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -232,12 +235,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="0"
|
||||
max="600"
|
||||
value={formData.streamingIdleTimeout}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
streamingIdleTimeout: parseInt(e.target.value) || 60,
|
||||
})
|
||||
}
|
||||
streamingIdleTimeout: isNaN(val) ? 0 : val,
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -258,12 +262,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="0"
|
||||
max="1800"
|
||||
value={formData.nonStreamingTimeout}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
nonStreamingTimeout: parseInt(e.target.value) || 300,
|
||||
})
|
||||
}
|
||||
nonStreamingTimeout: isNaN(val) ? 0 : val,
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -293,12 +298,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="1"
|
||||
max="10"
|
||||
value={formData.circuitSuccessThreshold}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
circuitSuccessThreshold: parseInt(e.target.value) || 2,
|
||||
})
|
||||
}
|
||||
circuitSuccessThreshold: isNaN(val) ? 1 : Math.max(1, val),
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -319,12 +325,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="10"
|
||||
max="300"
|
||||
value={formData.circuitTimeoutSeconds}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
circuitTimeoutSeconds: parseInt(e.target.value) || 60,
|
||||
})
|
||||
}
|
||||
circuitTimeoutSeconds: isNaN(val) ? 10 : Math.max(10, val),
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -346,13 +353,13 @@ export function AutoFailoverConfigPanel({
|
||||
max="100"
|
||||
step="5"
|
||||
value={Math.round(formData.circuitErrorRateThreshold * 100)}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
circuitErrorRateThreshold:
|
||||
(parseInt(e.target.value) || 50) / 100,
|
||||
})
|
||||
}
|
||||
circuitErrorRateThreshold: isNaN(val) ? 0.5 : val / 100,
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
@@ -373,12 +380,13 @@ export function AutoFailoverConfigPanel({
|
||||
min="5"
|
||||
max="100"
|
||||
value={formData.circuitMinRequests}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const val = parseInt(e.target.value);
|
||||
setFormData({
|
||||
...formData,
|
||||
circuitMinRequests: parseInt(e.target.value) || 10,
|
||||
})
|
||||
}
|
||||
circuitMinRequests: isNaN(val) ? 5 : Math.max(5, val),
|
||||
});
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
|
||||
@@ -62,6 +62,28 @@ export function RequestLogTable() {
|
||||
});
|
||||
};
|
||||
|
||||
// 将 Unix 时间戳转换为本地时间的 datetime-local 格式
|
||||
const timestampToLocalDatetime = (timestamp: number): string => {
|
||||
const date = new Date(timestamp * 1000);
|
||||
const year = date.getFullYear();
|
||||
const month = String(date.getMonth() + 1).padStart(2, "0");
|
||||
const day = String(date.getDate()).padStart(2, "0");
|
||||
const hours = String(date.getHours()).padStart(2, "0");
|
||||
const minutes = String(date.getMinutes()).padStart(2, "0");
|
||||
return `${year}-${month}-${day}T${hours}:${minutes}`;
|
||||
};
|
||||
|
||||
// 将 datetime-local 格式转换为 Unix 时间戳
|
||||
const localDatetimeToTimestamp = (datetime: string): number | undefined => {
|
||||
if (!datetime) return undefined;
|
||||
// 验证格式是否完整 (YYYY-MM-DDTHH:mm)
|
||||
if (datetime.length < 16) return undefined;
|
||||
const timestamp = new Date(datetime).getTime();
|
||||
// 验证是否为有效日期
|
||||
if (isNaN(timestamp)) return undefined;
|
||||
return Math.floor(timestamp / 1000);
|
||||
};
|
||||
|
||||
const dateLocale =
|
||||
i18n.language === "zh"
|
||||
? "zh-CN"
|
||||
@@ -153,19 +175,16 @@ export function RequestLogTable() {
|
||||
className="h-8 w-[200px] bg-background"
|
||||
value={
|
||||
tempFilters.startDate
|
||||
? new Date(tempFilters.startDate * 1000)
|
||||
.toISOString()
|
||||
.slice(0, 16)
|
||||
? timestampToLocalDatetime(tempFilters.startDate)
|
||||
: ""
|
||||
}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const timestamp = localDatetimeToTimestamp(e.target.value);
|
||||
setTempFilters({
|
||||
...tempFilters,
|
||||
startDate: e.target.value
|
||||
? Math.floor(new Date(e.target.value).getTime() / 1000)
|
||||
: undefined,
|
||||
})
|
||||
}
|
||||
startDate: timestamp,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<span>-</span>
|
||||
<Input
|
||||
@@ -173,19 +192,16 @@ export function RequestLogTable() {
|
||||
className="h-8 w-[200px] bg-background"
|
||||
value={
|
||||
tempFilters.endDate
|
||||
? new Date(tempFilters.endDate * 1000)
|
||||
.toISOString()
|
||||
.slice(0, 16)
|
||||
? timestampToLocalDatetime(tempFilters.endDate)
|
||||
: ""
|
||||
}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const timestamp = localDatetimeToTimestamp(e.target.value);
|
||||
setTempFilters({
|
||||
...tempFilters,
|
||||
endDate: e.target.value
|
||||
? Math.floor(new Date(e.target.value).getTime() / 1000)
|
||||
: undefined,
|
||||
})
|
||||
}
|
||||
endDate: timestamp,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -12,13 +12,7 @@ interface UsageSummaryCardsProps {
|
||||
export function UsageSummaryCards({ days }: UsageSummaryCardsProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { startDate, endDate } = useMemo(() => {
|
||||
const end = Math.floor(Date.now() / 1000);
|
||||
const start = end - days * 24 * 60 * 60;
|
||||
return { startDate: start, endDate: end };
|
||||
}, [days]);
|
||||
|
||||
const { data: summary, isLoading } = useUsageSummary(startDate, endDate);
|
||||
const { data: summary, isLoading } = useUsageSummary(days);
|
||||
|
||||
const stats = useMemo(() => {
|
||||
const totalRequests = summary?.totalRequests ?? 0;
|
||||
|
||||
@@ -41,7 +41,12 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
|
||||
return {
|
||||
rawDate: stat.date,
|
||||
label: isToday
|
||||
? pointDate.toLocaleTimeString(dateLocale, { hour: "2-digit" })
|
||||
? pointDate.toLocaleString(dateLocale, {
|
||||
month: "2-digit",
|
||||
day: "2-digit",
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
})
|
||||
: pointDate.toLocaleDateString(dateLocale, {
|
||||
month: "2-digit",
|
||||
day: "2-digit",
|
||||
@@ -49,28 +54,13 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
|
||||
hour: pointDate.getHours(),
|
||||
inputTokens: stat.totalInputTokens,
|
||||
outputTokens: stat.totalOutputTokens,
|
||||
cacheCreationTokens: stat.totalCacheCreationTokens,
|
||||
cacheReadTokens: stat.totalCacheReadTokens,
|
||||
cost: parseFloat(stat.totalCost),
|
||||
};
|
||||
}) || [];
|
||||
|
||||
const hourlyData = (() => {
|
||||
if (!isToday) return chartData;
|
||||
const map = new Map<number, (typeof chartData)[number]>();
|
||||
chartData.forEach((point) => {
|
||||
map.set(point.hour ?? 0, point);
|
||||
});
|
||||
return Array.from({ length: 24 }, (_, hour) => {
|
||||
const bucket = map.get(hour);
|
||||
return {
|
||||
label: `${hour.toString().padStart(2, "0")}:00`,
|
||||
inputTokens: bucket?.inputTokens ?? 0,
|
||||
outputTokens: bucket?.outputTokens ?? 0,
|
||||
cost: bucket?.cost ?? 0,
|
||||
};
|
||||
});
|
||||
})();
|
||||
|
||||
const displayData = isToday ? hourlyData : chartData;
|
||||
const displayData = chartData;
|
||||
|
||||
const CustomTooltip = ({ active, payload, label }: any) => {
|
||||
if (active && payload && payload.length) {
|
||||
@@ -131,6 +121,20 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
|
||||
<stop offset="5%" stopColor="#22c55e" stopOpacity={0.2} />
|
||||
<stop offset="95%" stopColor="#22c55e" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
<linearGradient
|
||||
id="colorCacheCreation"
|
||||
x1="0"
|
||||
y1="0"
|
||||
x2="0"
|
||||
y2="1"
|
||||
>
|
||||
<stop offset="5%" stopColor="#f97316" stopOpacity={0.2} />
|
||||
<stop offset="95%" stopColor="#f97316" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
<linearGradient id="colorCacheRead" x1="0" y1="0" x2="0" y2="1">
|
||||
<stop offset="5%" stopColor="#a855f7" stopOpacity={0.2} />
|
||||
<stop offset="95%" stopColor="#a855f7" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<CartesianGrid
|
||||
strokeDasharray="3 3"
|
||||
@@ -182,6 +186,26 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
|
||||
fill="url(#colorOutput)"
|
||||
strokeWidth={2}
|
||||
/>
|
||||
<Area
|
||||
yAxisId="tokens"
|
||||
type="monotone"
|
||||
dataKey="cacheCreationTokens"
|
||||
name={t("usage.cacheCreationTokens", "缓存创建")}
|
||||
stroke="#f97316"
|
||||
fillOpacity={1}
|
||||
fill="url(#colorCacheCreation)"
|
||||
strokeWidth={2}
|
||||
/>
|
||||
<Area
|
||||
yAxisId="tokens"
|
||||
type="monotone"
|
||||
dataKey="cacheReadTokens"
|
||||
name={t("usage.cacheReadTokens", "缓存命中")}
|
||||
stroke="#a855f7"
|
||||
fillOpacity={1}
|
||||
fill="url(#colorCacheRead)"
|
||||
strokeWidth={2}
|
||||
/>
|
||||
<Area
|
||||
yAxisId="cost"
|
||||
type="monotone"
|
||||
|
||||
@@ -423,7 +423,7 @@
|
||||
"cost": "Cost",
|
||||
"perMillion": "(per million)",
|
||||
"trends": "Usage Trends",
|
||||
"rangeToday": "Today (hourly)",
|
||||
"rangeToday": "Last 24 hours (hourly)",
|
||||
"rangeLast7Days": "Last 7 days",
|
||||
"rangeLast30Days": "Last 30 days",
|
||||
"totalTokens": "Total Tokens",
|
||||
@@ -436,8 +436,8 @@
|
||||
"billingModel": "Billing Model",
|
||||
"inputTokens": "Input",
|
||||
"outputTokens": "Output",
|
||||
"cacheReadTokens": "Cache Read",
|
||||
"cacheCreationTokens": "Cache Write",
|
||||
"cacheReadTokens": "Cache Hit",
|
||||
"cacheCreationTokens": "Cache Creation",
|
||||
"timingInfo": "Duration/TTFT",
|
||||
"status": "Status",
|
||||
"noData": "No data",
|
||||
@@ -453,8 +453,8 @@
|
||||
"displayName": "Display Name",
|
||||
"inputCost": "Input Cost",
|
||||
"outputCost": "Output Cost",
|
||||
"cacheReadCost": "Cache Read",
|
||||
"cacheWriteCost": "Cache Write",
|
||||
"cacheReadCost": "Cache Hit",
|
||||
"cacheWriteCost": "Cache Creation",
|
||||
"deleteConfirmTitle": "Confirm Delete",
|
||||
"deleteConfirmDesc": "Are you sure you want to delete this model pricing? This action cannot be undone.",
|
||||
"queryFailed": "Query failed",
|
||||
@@ -481,8 +481,8 @@
|
||||
"timeRange": "Time Range",
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "Write",
|
||||
"cacheRead": "Read"
|
||||
"cacheWrite": "Creation",
|
||||
"cacheRead": "Hit"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "Configure Usage Query",
|
||||
|
||||
@@ -423,7 +423,7 @@
|
||||
"cost": "コスト",
|
||||
"perMillion": "(100万あたり)",
|
||||
"trends": "利用トレンド",
|
||||
"rangeToday": "今日 (時間別)",
|
||||
"rangeToday": "直近24時間 (時間別)",
|
||||
"rangeLast7Days": "過去7日間",
|
||||
"rangeLast30Days": "過去30日間",
|
||||
"totalTokens": "総トークン数",
|
||||
@@ -436,8 +436,8 @@
|
||||
"billingModel": "課金モデル",
|
||||
"inputTokens": "入力",
|
||||
"outputTokens": "出力",
|
||||
"cacheReadTokens": "キャッシュ読取",
|
||||
"cacheCreationTokens": "キャッシュ書込",
|
||||
"cacheReadTokens": "キャッシュヒット",
|
||||
"cacheCreationTokens": "キャッシュ作成",
|
||||
"timingInfo": "応答時間/TTFT",
|
||||
"status": "ステータス",
|
||||
"noData": "データなし",
|
||||
@@ -453,8 +453,8 @@
|
||||
"displayName": "表示名",
|
||||
"inputCost": "入力コスト",
|
||||
"outputCost": "出力コスト",
|
||||
"cacheReadCost": "キャッシュ読取",
|
||||
"cacheWriteCost": "キャッシュ書込",
|
||||
"cacheReadCost": "キャッシュヒット",
|
||||
"cacheWriteCost": "キャッシュ作成",
|
||||
"deleteConfirmTitle": "削除の確認",
|
||||
"deleteConfirmDesc": "このモデル料金を削除しますか?この操作は元に戻せません。",
|
||||
"queryFailed": "照会に失敗しました",
|
||||
@@ -481,8 +481,8 @@
|
||||
"timeRange": "期間",
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "Write",
|
||||
"cacheRead": "Read"
|
||||
"cacheWrite": "作成",
|
||||
"cacheRead": "ヒット"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "利用状況を設定",
|
||||
|
||||
@@ -423,7 +423,7 @@
|
||||
"cost": "成本",
|
||||
"perMillion": "(每百万)",
|
||||
"trends": "使用趋势",
|
||||
"rangeToday": "今天 (按小时)",
|
||||
"rangeToday": "过去 24 小时 (按小时)",
|
||||
"rangeLast7Days": "过去 7 天",
|
||||
"rangeLast30Days": "过去 30 天",
|
||||
"totalTokens": "总 Token 数",
|
||||
@@ -436,8 +436,8 @@
|
||||
"billingModel": "计费模型",
|
||||
"inputTokens": "输入",
|
||||
"outputTokens": "输出",
|
||||
"cacheReadTokens": "缓存读取",
|
||||
"cacheCreationTokens": "缓存写入",
|
||||
"cacheReadTokens": "缓存命中",
|
||||
"cacheCreationTokens": "缓存创建",
|
||||
"timingInfo": "用时/首字",
|
||||
"status": "状态",
|
||||
"noData": "暂无数据",
|
||||
@@ -453,8 +453,8 @@
|
||||
"displayName": "显示名称",
|
||||
"inputCost": "输入成本",
|
||||
"outputCost": "输出成本",
|
||||
"cacheReadCost": "缓存读取",
|
||||
"cacheWriteCost": "缓存写入",
|
||||
"cacheReadCost": "缓存命中",
|
||||
"cacheWriteCost": "缓存创建",
|
||||
"deleteConfirmTitle": "确认删除",
|
||||
"deleteConfirmDesc": "确定要删除此模型定价配置吗?此操作无法撤销。",
|
||||
"queryFailed": "查询失败",
|
||||
@@ -481,8 +481,8 @@
|
||||
"timeRange": "时间范围",
|
||||
"input": "Input",
|
||||
"output": "Output",
|
||||
"cacheWrite": "Write",
|
||||
"cacheRead": "Read"
|
||||
"cacheWrite": "创建",
|
||||
"cacheRead": "命中"
|
||||
},
|
||||
"usageScript": {
|
||||
"title": "配置用量查询",
|
||||
|
||||
@@ -49,8 +49,11 @@ export const usageApi = {
|
||||
return invoke("get_usage_summary", { startDate, endDate });
|
||||
},
|
||||
|
||||
getUsageTrends: async (days: number): Promise<DailyStats[]> => {
|
||||
return invoke("get_usage_trends", { days });
|
||||
getUsageTrends: async (
|
||||
startDate?: number,
|
||||
endDate?: number,
|
||||
): Promise<DailyStats[]> => {
|
||||
return invoke("get_usage_trends", { startDate, endDate });
|
||||
},
|
||||
|
||||
getProviderStats: async (): Promise<ProviderStats[]> => {
|
||||
|
||||
+27
-6
@@ -5,8 +5,7 @@ import type { LogFilters } from "@/types/usage";
|
||||
// Query keys
|
||||
export const usageKeys = {
|
||||
all: ["usage"] as const,
|
||||
summary: (startDate?: number, endDate?: number) =>
|
||||
[...usageKeys.all, "summary", startDate, endDate] as const,
|
||||
summary: (days: number) => [...usageKeys.all, "summary", days] as const,
|
||||
trends: (days: number) => [...usageKeys.all, "trends", days] as const,
|
||||
providerStats: () => [...usageKeys.all, "provider-stats"] as const,
|
||||
modelStats: () => [...usageKeys.all, "model-stats"] as const,
|
||||
@@ -19,18 +18,34 @@ export const usageKeys = {
|
||||
[...usageKeys.all, "limits", providerId, appType] as const,
|
||||
};
|
||||
|
||||
const getWindow = (days: number) => {
|
||||
const endDate = Math.floor(Date.now() / 1000);
|
||||
const startDate = endDate - days * 24 * 60 * 60;
|
||||
return { startDate, endDate };
|
||||
};
|
||||
|
||||
// Hooks
|
||||
export function useUsageSummary(startDate?: number, endDate?: number) {
|
||||
export function useUsageSummary(days: number) {
|
||||
return useQuery({
|
||||
queryKey: usageKeys.summary(startDate, endDate),
|
||||
queryFn: () => usageApi.getUsageSummary(startDate, endDate),
|
||||
queryKey: usageKeys.summary(days),
|
||||
queryFn: () => {
|
||||
const { startDate, endDate } = getWindow(days);
|
||||
return usageApi.getUsageSummary(startDate, endDate);
|
||||
},
|
||||
refetchInterval: 30000, // 每30秒自动刷新
|
||||
refetchIntervalInBackground: false, // 后台不刷新
|
||||
});
|
||||
}
|
||||
|
||||
export function useUsageTrends(days: number) {
|
||||
return useQuery({
|
||||
queryKey: usageKeys.trends(days),
|
||||
queryFn: () => usageApi.getUsageTrends(days),
|
||||
queryFn: () => {
|
||||
const { startDate, endDate } = getWindow(days);
|
||||
return usageApi.getUsageTrends(startDate, endDate);
|
||||
},
|
||||
refetchInterval: 30000, // 每30秒自动刷新
|
||||
refetchIntervalInBackground: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -38,6 +53,8 @@ export function useProviderStats() {
|
||||
return useQuery({
|
||||
queryKey: usageKeys.providerStats(),
|
||||
queryFn: usageApi.getProviderStats,
|
||||
refetchInterval: 30000, // 每30秒自动刷新
|
||||
refetchIntervalInBackground: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -45,6 +62,8 @@ export function useModelStats() {
|
||||
return useQuery({
|
||||
queryKey: usageKeys.modelStats(),
|
||||
queryFn: usageApi.getModelStats,
|
||||
refetchInterval: 30000, // 每30秒自动刷新
|
||||
refetchIntervalInBackground: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -56,6 +75,8 @@ export function useRequestLogs(
|
||||
return useQuery({
|
||||
queryKey: usageKeys.logs(filters, page, pageSize),
|
||||
queryFn: () => usageApi.getRequestLogs(filters, page, pageSize),
|
||||
refetchInterval: 30000, // 每30秒自动刷新
|
||||
refetchIntervalInBackground: false,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user