Compare commits


22 Commits

Author SHA1 Message Date
YoVinchen 766ca23eca feat(proxy): extract session ID from client requests for logging
- Add SessionIdExtractor to parse session ID from Claude/Codex requests
- Support extraction from metadata.user_id, headers, previous_response_id
- Pass session_id through RequestContext to usage logger
- Enable request correlation by session in proxy_request_logs
2025-12-31 21:31:20 +08:00
YoVinchen 48a8aabfb8 feat(proxy): enable transparent passthrough for headers
- Passthrough anthropic-beta header as-is from client
- Passthrough anthropic-version header from client
- Passthrough client IP headers (x-forwarded-for, x-real-ip) by default
- Filter private params (underscore-prefixed fields) from request body
- No database changes required
2025-12-31 20:59:59 +08:00
YoVinchen cf79b09597 feat(proxy): add body and header filtering for upstream requests 2025-12-31 17:15:00 +08:00
YoVinchen 6004084644 Align usage stats to sliding windows 2025-12-31 17:00:53 +08:00
YoVinchen 52199f39c1 refactor(proxy): distinguish circuit-open from no-provider errors 2025-12-31 14:37:33 +08:00
YoVinchen c69d36d457 refactor(proxy): remove retry logic and add enabled check for failover 2025-12-31 14:06:29 +08:00
YoVinchen a2aa969096 fix(proxy): handle zero value input in failover config fields 2025-12-31 12:54:49 +08:00
YoVinchen 02fd924119 fix(proxy): bypass timeout and retry configs when failover is disabled
When auto_failover_enabled is false, timeout and retry configurations
should not affect normal request flow. This change ensures:

- create_forwarder: passes 0 for all timeout/retry params when failover
  is disabled, effectively bypassing these checks
- streaming_timeout_config: returns 0 for both first_byte_timeout and
  idle_timeout when failover is disabled

This prevents unnecessary timeout errors and retry attempts when users
have explicitly disabled the failover feature.
2025-12-31 12:50:04 +08:00
YoVinchen fa4e7bcd82 refactor(proxy): improve header forwarding with blacklist approach
Change from whitelist to blacklist mode for request header forwarding.
Only skip headers that will be overridden (auth, host, content-length).
This preserves the client's original headers and improves compatibility.
2025-12-31 10:21:57 +08:00
YoVinchen 43ccbdad47 feat(database): add Chinese AI model pricing data
Add pricing for domestic AI models (CNY/1M tokens):
- Doubao-Seed-Code (ByteDance)
- DeepSeek V3/V3.1/V3.2
- Kimi K2/K2-Thinking/K2-Turbo (Moonshot)
- MiniMax M2/M2.1/M2.1-Lightning
- GLM-4.6/4.7 (Zhipu)
- Mimo V2 Flash (Xiaomi)

Also fix test case to use correct model ID and remove invalid currency column.
2025-12-31 02:21:53 +08:00
YoVinchen b46fb6e6da refactor(usage): simplify model pricing lookup by removing suffix fallback
Replace complex suffix-stripping fallback with direct prefix/suffix cleanup.
Model IDs are now cleaned by removing vendor prefix (before /) and colon
suffix (after :), then matched exactly against pricing table.
2025-12-31 02:07:05 +08:00
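The cleanup described above amounts to something like this (a sketch; the function name is illustrative, not the actual code):

```rust
// Strip the vendor prefix (everything up to the last '/') and the
// colon suffix (everything from the first ':' on); the result is then
// matched exactly against the pricing table.
fn normalize_model_id(raw: &str) -> &str {
    let no_prefix = raw.rsplit('/').next().unwrap_or(raw);
    no_prefix.split(':').next().unwrap_or(no_prefix)
}
```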
YoVinchen 0efe0594e9 fix(proxy): improve takeover detection with live config check
- Add live config takeover detection for hot-switch decision
- Rebuild takeover when backup is missing or placeholder remains
- Make detect_takeover_in_live_config_for_app public
- Fix is_takeover_active to use actual takeover status
2025-12-31 01:44:02 +08:00
YoVinchen a80b2b98b7 fix(usage): correct cache token billing and add Codex format auto-detection
- Avoid double-billing cache tokens by subtracting from input before calculation
- Add smart Codex parser that auto-detects OpenAI vs Codex API format
- Extract model name from Codex responses for accurate tracking
2025-12-31 00:39:55 +08:00
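The cache-billing fix can be illustrated with a minimal cost calculation (a sketch assuming the API's input token count already includes cache tokens, as the first bullet implies; function and parameter names are illustrative):

```rust
// Charge cache tokens at their own rates and only the remaining
// "plain" input tokens at the input rate, instead of billing the
// cache tokens twice. Rates are USD per 1M tokens.
fn input_side_cost(
    input_tokens: u64,
    cache_read: u64,
    cache_creation: u64,
    input_rate: f64,
    cache_read_rate: f64,
    cache_creation_rate: f64,
) -> f64 {
    // Assumption: input_tokens already includes the cache tokens.
    let plain = input_tokens.saturating_sub(cache_read + cache_creation);
    (plain as f64 * input_rate
        + cache_read as f64 * cache_read_rate
        + cache_creation as f64 * cache_creation_rate)
        / 1_000_000.0
}
```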
YoVinchen f4f0590fd6 fix(usage): correct Gemini output token calculation
Fix Gemini API output token parsing to use totalTokenCount - promptTokenCount
instead of candidatesTokenCount alone. This ensures thoughtsTokenCount is
included in output statistics.

- Update from_gemini_response to calculate output from total - input
- Update from_gemini_stream_chunks with same logic for consistency
- Fix from_codex_stream_events to use adjusted token calculation
- Add test case for responses with thoughtsTokenCount
- Update existing tests to match new calculation logic
2025-12-30 23:57:14 +08:00
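In code, the adjusted calculation is simply (field names follow Gemini's usageMetadata keys cited above):

```rust
// Output tokens = totalTokenCount - promptTokenCount, so that
// thoughtsTokenCount is counted as output rather than dropped when
// using candidatesTokenCount alone.
fn gemini_output_tokens(total_token_count: u64, prompt_token_count: u64) -> u64 {
    total_token_count.saturating_sub(prompt_token_count)
}
```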
YoVinchen bfa5c3c526 feat(database): update model pricing data
- Update Claude models to full version format (e.g. claude-opus-4-5-20251101)
- Add GPT-5.2 series model pricing (10 models)
- Add GPT-5.1 series model pricing (10 models)
- Add GPT-5 series model pricing (12 models)
- Add Gemini 3 series model pricing (2 models)
- Update Gemini 2.5 series model ID format (use dot separator)
- Unify display names by removing thinking level suffixes
2025-12-30 23:49:13 +08:00
YoVinchen 9f212115c4 feat(proxy): enhance provider router logging
- Add debug logs for failover queue provider count
- Log circuit breaker state for each provider check
- Add logs for missing current provider scenarios
- Log when no current provider is configured
- Use inline format args for better readability

This improves debugging of provider selection and failover behavior.
2025-12-30 18:58:38 +08:00
YoVinchen da75f22a12 style(rust): use inline format args in format! macros
- Replace format!("...", var) with format!("...{var}")
- Update universal provider ID formatting
- Update error message formatting
- Update config.toml generation in Codex provider

Fixes clippy::uninlined_format_args warnings.
2025-12-30 18:58:15 +08:00
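The clippy lint in question prefers the variable name inside the braces; a minimal before/after (the universal-ID prefix here is hypothetical):

```rust
// clippy::uninlined_format_args: move the variable into the braces.
fn universal_id(provider_id: &str) -> String {
    // Before: format!("universal-{}", provider_id)
    format!("universal-{provider_id}")
}
```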
YoVinchen 3a692c84fb fix(proxy): improve usage logging and cache token parsing
- Log requests even when usage parsing fails (with default values)
- Add detailed debug logging for usage metrics
- Support cache_read_input_tokens field in Codex responses
- Fallback to input_tokens_details.cached_tokens if needed
- Add test case for cached_tokens in input_tokens_details
- Ensure all requests are tracked in database for analytics

This fixes missing request logs when API responses lack usage data
and improves cache token detection across different response formats.
2025-12-30 18:58:01 +08:00
YoVinchen fe08f69cac feat(usage): add auto-refresh for usage statistics
- Add 30-second auto-refresh interval for all usage queries
- Disable background refresh to save resources
- Apply to: summary, trends, provider stats, model stats, request logs
- Queries automatically update when tab is active
- Pause refresh when user switches to another tab

This keeps usage data fresh without manual refresh.
2025-12-30 18:56:56 +08:00
YoVinchen 4a1ee98784 fix(usage): fix timezone handling in datetime picker
- Add timestampToLocalDatetime() to convert Unix timestamp to local datetime
- Add localDatetimeToTimestamp() with validation for incomplete input
- Fix issue where typing hours/minutes would jump to previous day
- Validate datetime format completeness before conversion
- Use local timezone instead of UTC for datetime-local input

This resolves the issue where users couldn't fine-tune time selection
and the input would jump unexpectedly when editing hours or minutes.
2025-12-30 18:56:16 +08:00
YoVinchen 2901ead814 feat(usage): add cache metrics to trend chart
- Add cache creation tokens visualization (orange line)
- Add cache hit tokens visualization (purple line)
- Add gradient definitions for new cache metrics
- Include cache data in hourly aggregation
- Display cache metrics alongside input/output tokens

This provides better visibility into cache usage patterns over time.
2025-12-30 18:55:27 +08:00
YoVinchen e99d599ad8 i18n: update cache terminology across all languages
- Change 'Cache Read' to 'Cache Hit' in all languages
- Change 'Cache Write' to 'Cache Creation' in all languages
- Update zh: 缓存读取 → 缓存命中, 缓存写入 → 缓存创建
- Update en: Cache Read → Cache Hit, Cache Write → Cache Creation
- Update ja: キャッシュ読取 → キャッシュヒット, キャッシュ書込 → キャッシュ作成

Affected keys: cacheReadTokens, cacheCreationTokens, cacheReadCost,
cacheWriteCost, cacheRead, cacheWrite
2025-12-30 18:55:02 +08:00
31 changed files with 1888 additions and 583 deletions
+6 -7
@@ -184,17 +184,16 @@ pub async fn reset_circuit_breaker(
.await?;
// 3. Check whether we should switch back to a higher-priority provider (read from the proxy_config table)
let auto_failover_enabled = match db.get_proxy_config_for_app(&app_type).await {
Ok(config) => config.auto_failover_enabled,
// Only run when this app has been taken over by the proxy (enabled=true) and auto failover is on
let (app_enabled, auto_failover_enabled) = match db.get_proxy_config_for_app(&app_type).await {
Ok(config) => (config.enabled, config.auto_failover_enabled),
Err(e) => {
log::error!(
"[{app_type}] Failed to read proxy_config for auto_failover_enabled: {e}, defaulting to disabled"
);
false
log::error!("[{app_type}] Failed to read proxy_config: {e}, defaulting to disabled");
(false, false)
}
};
if auto_failover_enabled && state.proxy_service.is_running().await {
if app_enabled && auto_failover_enabled && state.proxy_service.is_running().await {
// Get the current provider ID
let current_id = db
.get_current_provider(&app_type)
+3 -2
@@ -19,9 +19,10 @@ pub fn get_usage_summary(
#[tauri::command]
pub fn get_usage_trends(
state: State<'_, AppState>,
days: u32,
start_date: Option<i64>,
end_date: Option<i64>,
) -> Result<Vec<DailyStats>, AppError> {
state.db.get_daily_trends(days)
state.db.get_daily_trends(start_date, end_date)
}
/// Get provider statistics
+224 -34
@@ -765,9 +765,9 @@ impl Database {
/// Note: model_id uses the dash-separated format (e.g. claude-haiku-4-5), matching normalized model names returned by the API
fn seed_model_pricing(conn: &Connection) -> Result<(), AppError> {
let pricing_data = [
// Claude 4.5 series
// Claude 4.5 series (Latest Models)
(
"claude-opus-4-5",
"claude-opus-4-5-20251101",
"Claude Opus 4.5",
"5",
"25",
@@ -775,7 +775,7 @@ impl Database {
"6.25",
),
(
"claude-sonnet-4-5",
"claude-sonnet-4-5-20250929",
"Claude Sonnet 4.5",
"3",
"15",
@@ -783,16 +783,24 @@ impl Database {
"3.75",
),
(
"claude-haiku-4-5",
"claude-haiku-4-5-20251001",
"Claude Haiku 4.5",
"1",
"5",
"0.10",
"1.25",
),
// Claude 4.1 series
// Claude 4 series (Legacy Models)
(
"claude-opus-4-1",
"claude-opus-4-20250514",
"Claude Opus 4",
"15",
"75",
"1.50",
"18.75",
),
(
"claude-opus-4-1-20250805",
"Claude Opus 4.1",
"15",
"75",
@@ -800,17 +808,8 @@ impl Database {
"18.75",
),
(
"claude-sonnet-4-1",
"Claude Sonnet 4.1",
"3",
"15",
"0.30",
"3.75",
),
// Claude 3.7 series
(
"claude-sonnet-3-7",
"Claude Sonnet 3.7",
"claude-sonnet-4-20250514",
"Claude Sonnet 4",
"3",
"15",
"0.30",
@@ -818,38 +817,167 @@ impl Database {
),
// Claude 3.5 series
(
"claude-sonnet-3-5",
"Claude Sonnet 3.5",
"3",
"15",
"0.30",
"3.75",
),
(
"claude-haiku-3-5",
"Claude Haiku 3.5",
"claude-3-5-haiku-20241022",
"Claude 3.5 Haiku",
"0.80",
"4",
"0.08",
"1",
),
// GPT-5 series (model_id uses dash-separated format)
(
"claude-3-5-sonnet-20241022",
"Claude 3.5 Sonnet",
"3",
"15",
"0.30",
"3.75",
),
// GPT-5.2 series
("gpt-5.2", "GPT-5.2", "1.75", "14", "0.175", "0"),
("gpt-5.2-low", "GPT-5.2", "1.75", "14", "0.175", "0"),
("gpt-5.2-medium", "GPT-5.2", "1.75", "14", "0.175", "0"),
("gpt-5.2-high", "GPT-5.2", "1.75", "14", "0.175", "0"),
("gpt-5.2-xhigh", "GPT-5.2", "1.75", "14", "0.175", "0"),
("gpt-5.2-codex", "GPT-5.2 Codex", "1.75", "14", "0.175", "0"),
(
"gpt-5.2-codex-low",
"GPT-5.2 Codex",
"1.75",
"14",
"0.175",
"0",
),
(
"gpt-5.2-codex-medium",
"GPT-5.2 Codex",
"1.75",
"14",
"0.175",
"0",
),
(
"gpt-5.2-codex-high",
"GPT-5.2 Codex",
"1.75",
"14",
"0.175",
"0",
),
(
"gpt-5.2-codex-xhigh",
"GPT-5.2 Codex",
"1.75",
"14",
"0.175",
"0",
),
// GPT-5.1 series
("gpt-5.1", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5.1-low", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5.1-medium", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5.1-high", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5.1-minimal", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5.1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"),
(
"gpt-5.1-codex-mini",
"GPT-5.1 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5.1-codex-max",
"GPT-5.1 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5.1-codex-max-high",
"GPT-5.1 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5.1-codex-max-xhigh",
"GPT-5.1 Codex",
"1.25",
"10",
"0.125",
"0",
),
// GPT-5 series
("gpt-5", "GPT-5", "1.25", "10", "0.125", "0"),
("gpt-5-1", "GPT-5.1", "1.25", "10", "0.125", "0"),
("gpt-5-low", "GPT-5", "1.25", "10", "0.125", "0"),
("gpt-5-medium", "GPT-5", "1.25", "10", "0.125", "0"),
("gpt-5-high", "GPT-5", "1.25", "10", "0.125", "0"),
("gpt-5-minimal", "GPT-5", "1.25", "10", "0.125", "0"),
("gpt-5-codex", "GPT-5 Codex", "1.25", "10", "0.125", "0"),
("gpt-5-1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"),
("gpt-5-codex-low", "GPT-5 Codex", "1.25", "10", "0.125", "0"),
(
"gpt-5-codex-medium",
"GPT-5 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5-codex-high",
"GPT-5 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5-codex-mini",
"GPT-5 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5-codex-mini-medium",
"GPT-5 Codex",
"1.25",
"10",
"0.125",
"0",
),
(
"gpt-5-codex-mini-high",
"GPT-5 Codex",
"1.25",
"10",
"0.125",
"0",
),
// Gemini 3 series
(
"gemini-3-pro-preview",
"Gemini 3 Pro Preview",
"2",
"12",
"0",
"0.2",
"0",
),
// Gemini 2.5 series (model_id uses dash-separated format)
(
"gemini-2-5-pro",
"gemini-3-flash-preview",
"Gemini 3 Flash Preview",
"0.5",
"3",
"0.05",
"0",
),
// Gemini 2.5 series
(
"gemini-2.5-pro",
"Gemini 2.5 Pro",
"1.25",
"10",
@@ -857,13 +985,75 @@ impl Database {
"0",
),
(
"gemini-2-5-flash",
"gemini-2.5-flash",
"Gemini 2.5 Flash",
"0.3",
"2.5",
"0.03",
"0",
),
// ====== Domestic Chinese models (CNY/1M tokens) ======
// Doubao (ByteDance)
(
"doubao-seed-code",
"Doubao Seed Code",
"1.20",
"8.00",
"0.24",
"0",
),
// DeepSeek series
(
"deepseek-v3.2",
"DeepSeek V3.2",
"2.00",
"3.00",
"0.40",
"0",
),
(
"deepseek-v3.1",
"DeepSeek V3.1",
"4.00",
"12.00",
"0.80",
"0",
),
("deepseek-v3", "DeepSeek V3", "2.00", "8.00", "0.40", "0"),
// Kimi (Moonshot)
(
"kimi-k2-thinking",
"Kimi K2 Thinking",
"4.00",
"16.00",
"1.00",
"0",
),
("kimi-k2-0905", "Kimi K2", "4.00", "16.00", "1.00", "0"),
(
"kimi-k2-turbo",
"Kimi K2 Turbo",
"8.00",
"58.00",
"1.00",
"0",
),
// MiniMax series
("minimax-m2.1", "MiniMax M2.1", "2.10", "8.40", "0.21", "0"),
(
"minimax-m2.1-lightning",
"MiniMax M2.1 Lightning",
"2.10",
"16.80",
"0.21",
"0",
),
("minimax-m2", "MiniMax M2", "2.10", "8.40", "0.21", "0"),
// GLM (Zhipu)
("glm-4.7", "GLM-4.7", "2.00", "8.00", "0.40", "0"),
("glm-4.6", "GLM-4.6", "2.00", "8.00", "0.40", "0"),
// Mimo (Xiaomi)
("mimo-v2-flash", "Mimo V2 Flash", "0", "0", "0", "0"),
];
for (model_id, display_name, input, output, cache_read, cache_creation) in pricing_data {
+4
@@ -52,6 +52,10 @@ pub enum AppError {
},
#[error("Database error: {0}")]
Database(String),
#[error("All providers circuit-open, no channel available")]
AllProvidersCircuitOpen,
#[error("No providers configured")]
NoProvidersConfigured,
}
impl AppError {
+4 -5
@@ -386,16 +386,15 @@ impl UniversalProvider {
// Generate the config.toml content for Codex
let config_toml = format!(
r#"model_provider = "newapi"
model = "{}"
model_reasoning_effort = "{}"
model = "{model}"
model_reasoning_effort = "{reasoning_effort}"
disable_response_storage = true
[model_providers.newapi]
name = "NewAPI"
base_url = "{}"
base_url = "{codex_base_url}"
wire_api = "responses"
requires_openai_auth = true"#,
model, reasoning_effort, codex_base_url
requires_openai_auth = true"#
);
let settings_config = serde_json::json!({
+303
@@ -0,0 +1,303 @@
//! Request body filtering module
//!
//! Filters private parameters that must not be passed through to the upstream, preventing internal information leaks.
//!
//! ## Filtering rules
//! - Fields starting with `_` are treated as private parameters and removed recursively
//! - A whitelist mechanism lets specific `_`-prefixed fields pass through
//! - Nested objects and arrays are filtered at any depth
//!
//! ## Typical targets
//! - `_internal_id`: internal tracking ID
//! - `_debug_mode`: debug flag
//! - `_session_token`: session token
//! - `_client_version`: client version
use serde_json::Value;
use std::collections::HashSet;
/// Filter private parameters (fields starting with `_`)
///
/// Recursively walks the JSON structure and removes every field whose name starts with an underscore.
///
/// # Arguments
/// * `body` - the original request body
///
/// # Returns
/// The filtered request body
///
/// # Example
/// ```ignore
/// let input = json!({
/// "model": "claude-3",
/// "_internal_id": "abc123",
/// "messages": [{"role": "user", "content": "hello", "_token": "secret"}]
/// });
/// let output = filter_private_params(input);
/// // output no longer contains _internal_id or _token
/// ```
#[cfg(test)]
pub fn filter_private_params(body: Value) -> Value {
filter_private_params_with_whitelist(body, &[])
}
/// Filter private parameters (with whitelist support)
///
/// Recursively walks the JSON structure and removes every field whose name
/// starts with an underscore, except those named in the whitelist.
///
/// # Arguments
/// * `body` - the original request body
/// * `whitelist` - fields exempt from filtering
///
/// # Returns
/// The filtered request body
///
/// # Example
/// ```ignore
/// let input = json!({
/// "model": "claude-3",
/// "_metadata": {"key": "value"}, // whitelisted, kept
/// "_internal_id": "abc123" // not whitelisted, removed
/// });
/// let output = filter_private_params_with_whitelist(input, &["_metadata"]);
/// // output keeps _metadata but drops _internal_id
/// ```
pub fn filter_private_params_with_whitelist(body: Value, whitelist: &[String]) -> Value {
let whitelist_set: HashSet<&str> = whitelist.iter().map(|s| s.as_str()).collect();
filter_recursive_with_whitelist(body, &mut Vec::new(), &whitelist_set)
}
/// Recursive filtering implementation
#[cfg(test)]
fn filter_recursive(value: Value, removed_keys: &mut Vec<String>) -> Value {
filter_recursive_with_whitelist(value, removed_keys, &HashSet::new())
}
/// Recursive filtering implementation (with whitelist support)
fn filter_recursive_with_whitelist(
value: Value,
removed_keys: &mut Vec<String>,
whitelist: &HashSet<&str>,
) -> Value {
match value {
Value::Object(map) => {
let filtered: serde_json::Map<String, Value> = map
.into_iter()
.filter_map(|(key, val)| {
// Fields that start with _ and are not in the whitelist are removed
if key.starts_with('_') && !whitelist.contains(key.as_str()) {
removed_keys.push(key);
None
} else {
Some((
key,
filter_recursive_with_whitelist(val, removed_keys, whitelist),
))
}
})
.collect();
// Only log when something was actually filtered (avoid logging on every request)
if !removed_keys.is_empty() {
log::debug!("[BodyFilter] Filtered private params: {removed_keys:?}");
removed_keys.clear();
}
Value::Object(filtered)
}
Value::Array(arr) => Value::Array(
arr.into_iter()
.map(|v| filter_recursive_with_whitelist(v, removed_keys, whitelist))
.collect(),
),
other => other,
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_filter_top_level_private_params() {
let input = json!({
"model": "claude-3",
"_internal_id": "abc123",
"_debug": true,
"max_tokens": 1024
});
let output = filter_private_params(input);
assert!(output.get("model").is_some());
assert!(output.get("max_tokens").is_some());
assert!(output.get("_internal_id").is_none());
assert!(output.get("_debug").is_none());
}
#[test]
fn test_filter_nested_private_params() {
let input = json!({
"model": "claude-3",
"messages": [
{
"role": "user",
"content": "hello",
"_session_token": "secret"
}
],
"metadata": {
"user_id": "user-1",
"_tracking_id": "track-1"
}
});
let output = filter_private_params(input);
// Top-level fields are kept
assert!(output.get("model").is_some());
assert!(output.get("messages").is_some());
assert!(output.get("metadata").is_some());
// Private params inside the messages array are removed
let messages = output.get("messages").unwrap().as_array().unwrap();
assert!(messages[0].get("role").is_some());
assert!(messages[0].get("content").is_some());
assert!(messages[0].get("_session_token").is_none());
// Private params inside the metadata object are removed
let metadata = output.get("metadata").unwrap();
assert!(metadata.get("user_id").is_some());
assert!(metadata.get("_tracking_id").is_none());
}
#[test]
fn test_filter_deeply_nested() {
let input = json!({
"level1": {
"level2": {
"level3": {
"keep": "value",
"_remove": "secret"
}
}
}
});
let output = filter_private_params(input);
let level3 = output
.get("level1")
.unwrap()
.get("level2")
.unwrap()
.get("level3")
.unwrap();
assert!(level3.get("keep").is_some());
assert!(level3.get("_remove").is_none());
}
#[test]
fn test_filter_array_of_objects() {
let input = json!({
"items": [
{"id": 1, "_secret": "a"},
{"id": 2, "_secret": "b"},
{"id": 3, "_secret": "c"}
]
});
let output = filter_private_params(input);
let items = output.get("items").unwrap().as_array().unwrap();
for item in items {
assert!(item.get("id").is_some());
assert!(item.get("_secret").is_none());
}
}
#[test]
fn test_no_private_params() {
let input = json!({
"model": "claude-3",
"messages": [{"role": "user", "content": "hello"}]
});
let output = filter_private_params(input.clone());
// With no private params, output should equal input
assert_eq!(input, output);
}
#[test]
fn test_empty_object() {
let input = json!({});
let output = filter_private_params(input);
assert_eq!(output, json!({}));
}
#[test]
fn test_primitive_values() {
// Primitive values must pass through unchanged
assert_eq!(filter_private_params(json!(42)), json!(42));
assert_eq!(filter_private_params(json!("string")), json!("string"));
assert_eq!(filter_private_params(json!(true)), json!(true));
assert_eq!(filter_private_params(json!(null)), json!(null));
}
#[test]
fn test_whitelist_preserves_private_params() {
let input = json!({
"model": "claude-3",
"_metadata": {"key": "value"},
"_internal_id": "abc123",
"_stream_options": {"include_usage": true}
});
let whitelist = vec!["_metadata".to_string(), "_stream_options".to_string()];
let output = filter_private_params_with_whitelist(input, &whitelist);
// Whitelisted fields are kept
assert!(output.get("_metadata").is_some());
assert!(output.get("_stream_options").is_some());
// Private fields not in the whitelist are removed
assert!(output.get("_internal_id").is_none());
// Regular fields are kept
assert!(output.get("model").is_some());
}
#[test]
fn test_whitelist_nested() {
let input = json!({
"data": {
"_allowed": "keep",
"_forbidden": "remove",
"normal": "value"
}
});
let whitelist = vec!["_allowed".to_string()];
let output = filter_private_params_with_whitelist(input, &whitelist);
let data = output.get("data").unwrap();
assert!(data.get("_allowed").is_some());
assert!(data.get("_forbidden").is_none());
assert!(data.get("normal").is_some());
}
#[test]
fn test_empty_whitelist_same_as_default() {
let input = json!({
"model": "claude-3",
"_internal_id": "abc123"
});
let output1 = filter_private_params(input.clone());
let output2 = filter_private_params_with_whitelist(input, &[]);
assert_eq!(output1, output2);
}
}
+12
@@ -23,6 +23,12 @@ pub enum ProxyError {
#[error("No available provider")]
NoAvailableProvider,
#[error("All providers circuit-open, no channel available")]
AllProvidersCircuitOpen,
#[error("No providers configured")]
NoProvidersConfigured,
#[allow(dead_code)]
#[error("Provider不健康: {0}")]
ProviderUnhealthy(String),
@@ -111,6 +117,12 @@ impl IntoResponse for ProxyError {
ProxyError::NoAvailableProvider => {
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
}
ProxyError::AllProvidersCircuitOpen => {
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
}
ProxyError::NoProvidersConfigured => {
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
}
ProxyError::ProviderUnhealthy(_) => {
(StatusCode::SERVICE_UNAVAILABLE, self.to_string())
}
+8
@@ -27,6 +27,12 @@ pub fn map_proxy_error_to_status(error: &ProxyError) -> u16 {
// No available provider: 503 Service Unavailable
ProxyError::NoAvailableProvider => 503,
// All providers circuit-open: 503 Service Unavailable
ProxyError::AllProvidersCircuitOpen => 503,
// No providers configured: 503 Service Unavailable
ProxyError::NoProvidersConfigured => 503,
// Retries exhausted: 503 Service Unavailable
ProxyError::MaxRetriesExceeded => 503,
@@ -57,6 +63,8 @@ pub fn get_error_message(error: &ProxyError) -> String {
ProxyError::Timeout(msg) => format!("Request timed out: {msg}"),
ProxyError::ForwardFailed(msg) => format!("Forwarding failed: {msg}"),
ProxyError::NoAvailableProvider => "No available provider".to_string(),
ProxyError::AllProvidersCircuitOpen => "All providers circuit-open, no channel available".to_string(),
ProxyError::NoProvidersConfigured => "No providers configured".to_string(),
ProxyError::MaxRetriesExceeded => "All providers failed, retries exhausted".to_string(),
ProxyError::ProviderUnhealthy(msg) => format!("Provider unhealthy: {msg}"),
ProxyError::DatabaseError(msg) => format!("Database error: {msg}"),
+15
@@ -81,6 +81,21 @@ impl FailoverSwitchManager {
provider_id: &str,
provider_name: &str,
) -> Result<bool, AppError> {
// Check whether this app has been taken over by the proxy (enabled=true).
// Only taken-over apps are allowed to perform a failover switch.
let app_enabled = match self.db.get_proxy_config_for_app(app_type).await {
Ok(config) => config.enabled,
Err(e) => {
log::warn!("[Failover] Failed to read {app_type} config: {e}, skipping switch");
return Ok(false);
}
};
if !app_enabled {
log::info!("[Failover] {app_type} not taken over by the proxy (enabled=false), skipping switch");
return Ok(false);
}
log::info!("[Failover] Switching provider: {app_type} -> {provider_name} ({provider_id})");
// 1. Update is_current in the database
+187 -97
@@ -1,8 +1,9 @@
//! Request forwarder
//!
//! Forwards requests to the upstream provider, with retry and failover support
//! Forwards requests to the upstream provider, with failover support
use super::{
body_filter::filter_private_params_with_whitelist,
error::*,
failover_switch::FailoverSwitchManager,
provider_router::ProviderRouter,
@@ -17,6 +18,71 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Header blacklist - headers that are not passed through to the upstream
///
/// Following the Claude Code Hub design, the following categories are filtered:
/// 1. Authentication (will be overridden)
/// 2. Connection management (handled by the HTTP client)
/// 3. Proxy forwarding
/// 4. CDN / cloud-vendor specific headers
/// 5. Request tracing
/// 6. Browser-specific headers (may be flagged by the upstream)
///
/// Note: client IP headers (x-forwarded-for, x-real-ip) are passed through by default
const HEADER_BLACKLIST: &[&str] = &[
// Authentication (will be overridden)
"authorization",
"x-api-key",
// Connection management
"host",
"content-length",
"connection",
"transfer-encoding",
// Encoding (will be overridden with identity)
"accept-encoding",
// Proxy forwarding (x-forwarded-for and x-real-ip are preserved)
"x-forwarded-host",
"x-forwarded-port",
"x-forwarded-proto",
"forwarded",
// CDN / cloud-vendor specific headers
"cf-connecting-ip",
"cf-ipcountry",
"cf-ray",
"cf-visitor",
"true-client-ip",
"fastly-client-ip",
"x-azure-clientip",
"x-azure-fdid",
"x-azure-ref",
"akamai-origin-hop",
"x-akamai-config-log-detail",
// Request tracing
"x-request-id",
"x-correlation-id",
"x-trace-id",
"x-amzn-trace-id",
"x-b3-traceid",
"x-b3-spanid",
"x-b3-parentspanid",
"x-b3-sampled",
"traceparent",
"tracestate",
// Browser-specific headers (may mark the request as non-CLI upstream)
"sec-fetch-mode",
"sec-fetch-site",
"sec-fetch-dest",
"sec-ch-ua",
"sec-ch-ua-mobile",
"sec-ch-ua-platform",
"accept-language",
// anthropic-beta is handled separately to avoid duplication
"anthropic-beta",
// Client IP headers are handled separately (passed through by default)
"x-forwarded-for",
"x-real-ip",
];
pub struct ForwardResult {
pub response: Response,
pub provider: Provider,
@@ -31,8 +97,6 @@ pub struct RequestForwarder {
client: Client,
/// Shared ProviderRouter (holds circuit breaker state)
router: Arc<ProviderRouter>,
/// Maximum number of retries within a single provider
max_retries: u8,
status: Arc<RwLock<ProxyStatus>>,
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
/// Failover switch manager
@@ -48,7 +112,6 @@ impl RequestForwarder {
pub fn new(
router: Arc<ProviderRouter>,
non_streaming_timeout: u64,
max_retries: u8,
status: Arc<RwLock<ProxyStatus>>,
current_providers: Arc<RwLock<std::collections::HashMap<String, (String, String)>>>,
failover_manager: Arc<FailoverSwitchManager>,
@@ -77,7 +140,6 @@ impl RequestForwarder {
Self {
client,
router,
max_retries,
status,
current_providers,
failover_manager,
@@ -86,59 +148,6 @@ impl RequestForwarder {
}
}
/// Execute a request against a single provider (with retries)
///
/// Retries up to max_retries times on the same provider, using exponential backoff
async fn forward_with_provider_retry(
&self,
provider: &Provider,
endpoint: &str,
body: &Value,
headers: &axum::http::HeaderMap,
adapter: &dyn ProviderAdapter,
) -> Result<Response, ProxyError> {
let mut last_error = None;
for attempt in 0..=self.max_retries {
if attempt > 0 {
// Exponential backoff: 100ms, 200ms, 400ms, ...
let delay_ms = 100 * 2u64.pow(attempt as u32 - 1);
log::info!(
"[{}] Retry {}/{} (waiting {}ms)",
adapter.name(),
attempt,
self.max_retries,
delay_ms
);
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
match self
.forward(provider, endpoint, body, headers, adapter)
.await
{
Ok(response) => return Ok(response),
Err(e) => {
// Only keep retrying for errors that are retryable on the same provider
if !self.should_retry_same_provider(&e) {
return Err(e);
}
log::debug!(
"[{}] Provider {} attempt {} failed: {}",
adapter.name(),
provider.name,
attempt + 1,
e
);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded))
}
/// Forward a request (with failover)
///
/// # Arguments
@@ -224,9 +233,9 @@ impl RequestForwarder {
let start = Instant::now();
// Forward the request (with in-provider retries)
// Forward the request (each provider is tried once; retries are left to the client)
match self
.forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref())
.forward(provider, endpoint, &body, &headers, adapter.as_ref())
.await
{
Ok(response) => {
@@ -477,6 +486,28 @@ impl RequestForwarder {
mapped_body
};
// Filter private parameters (fields starting with `_`) to prevent leaking internal information upstream
// Default to an empty whitelist, i.e. filter every _-prefixed field
let filtered_body = filter_private_params_with_whitelist(request_body, &[]);
// ========== Request body log (truncated) ==========
let body_str = serde_json::to_string_pretty(&filtered_body)
.unwrap_or_else(|_| filtered_body.to_string());
let body_preview = if body_str.len() > 2000 {
format!(
"{}...\n[truncated, total length: {} chars]",
&body_str[..2000],
body_str.len()
)
} else {
body_str
};
log::info!(
"[{}] ====== Final request body ======\n{}",
adapter.name(),
body_preview
);
log::info!(
"[{}] Forwarding request: {} -> {}",
adapter.name(),
@@ -487,28 +518,73 @@ impl RequestForwarder {
// Build the request
let mut request = self.client.post(&url);
// Only pass through required headers (whitelist mode)
let allowed_headers = [
"accept",
"user-agent",
"x-request-id",
"x-stainless-arch",
"x-stainless-lang",
"x-stainless-os",
"x-stainless-package-version",
"x-stainless-runtime",
"x-stainless-runtime-version",
];
// ========== Detailed headers log ==========
log::info!("[{}] ====== Original client headers ======", adapter.name());
for (key, value) in headers {
log::info!(
"[{}] {}: {:?}",
adapter.name(),
key.as_str(),
value.to_str().unwrap_or("<binary>")
);
}
// Filter blacklisted headers to protect privacy and avoid conflicts
let mut filtered_headers: Vec<String> = Vec::new();
let mut passed_headers: Vec<(String, String)> = Vec::new();
for (key, value) in headers {
let key_str = key.as_str().to_lowercase();
if allowed_headers.contains(&key_str.as_str()) {
request = request.header(key, value);
if HEADER_BLACKLIST.contains(&key_str.as_str()) {
filtered_headers.push(key_str);
continue;
}
let value_str = value.to_str().unwrap_or("<binary>").to_string();
passed_headers.push((key.as_str().to_string(), value_str.clone()));
request = request.header(key, value);
}
if !filtered_headers.is_empty() {
log::info!(
"[{}] ====== Filtered headers ({}) ======",
adapter.name(),
filtered_headers.len()
);
for h in &filtered_headers {
log::info!("[{}] - {}", adapter.name(), h);
}
}
// Ensure Content-Type is json
request = request.header("Content-Type", "application/json");
// Handle the anthropic-beta header (pass-through)
// Following the Claude Code Hub implementation, pass the client's beta flags through as-is
if let Some(beta) = headers.get("anthropic-beta") {
if let Ok(beta_str) = beta.to_str() {
request = request.header("anthropic-beta", beta_str);
passed_headers.push(("anthropic-beta".to_string(), beta_str.to_string()));
log::info!("[{}] Passing through anthropic-beta: {}", adapter.name(), beta_str);
}
}
// Client IP pass-through (enabled by default)
if let Some(xff) = headers.get("x-forwarded-for") {
if let Ok(xff_str) = xff.to_str() {
request = request.header("x-forwarded-for", xff_str);
passed_headers.push(("x-forwarded-for".to_string(), xff_str.to_string()));
log::debug!("[{}] Passing through x-forwarded-for: {}", adapter.name(), xff_str);
}
}
if let Some(real_ip) = headers.get("x-real-ip") {
if let Ok(real_ip_str) = real_ip.to_str() {
request = request.header("x-real-ip", real_ip_str);
passed_headers.push(("x-real-ip".to_string(), real_ip_str.to_string()));
log::debug!("[{}] Passing through x-real-ip: {}", adapter.name(), real_ip_str);
}
}
// Disable compression to avoid gzip parse errors on streaming responses
// Per CCH: undici throws on incomplete gzip streams when the connection closes early
request = request.header("accept-encoding", "identity");
passed_headers.push(("accept-encoding".to_string(), "identity".to_string()));
// Add auth headers via the adapter
if let Some(auth) = adapter.extract_auth(provider) {
@@ -519,6 +595,15 @@ impl RequestForwarder {
auth.masked_key()
);
request = adapter.add_auth_headers(request, &auth);
// Record auth headers (masked)
passed_headers.push((
"authorization".to_string(),
format!("Bearer {}...", &auth.api_key[..8.min(auth.api_key.len())]),
));
passed_headers.push((
"x-api-key".to_string(),
format!("{}...", &auth.api_key[..8.min(auth.api_key.len())]),
));
} else {
log::error!(
"[{}] API key not found, provider: {}",
@@ -527,9 +612,34 @@ impl RequestForwarder {
);
}
// anthropic-version pass-through: prefer the client's version number
// Per Claude Code Hub: pass through the client value instead of a pinned version
if let Some(version) = headers.get("anthropic-version") {
if let Ok(version_str) = version.to_str() {
// Override the default version set by the adapter
request = request.header("anthropic-version", version_str);
passed_headers.push(("anthropic-version".to_string(), version_str.to_string()));
log::info!(
"[{}] Passing through anthropic-version: {}",
adapter.name(),
version_str
);
}
}
// ========== Final outgoing headers log ==========
log::info!(
"[{}] ====== Final outgoing headers ({}) ======",
adapter.name(),
passed_headers.len()
);
for (k, v) in &passed_headers {
log::info!("[{}] {}: {}", adapter.name(), k, v);
}
// Send the request
log::info!("[{}] Sending request to: {}", adapter.name(), url);
let response = request.json(&request_body).send().await.map_err(|e| {
let response = request.json(&filtered_body).send().await.map_err(|e| {
log::error!("[{}] Request failed: {}", adapter.name(), e);
if e.is_timeout() {
ProxyError::Timeout(format!("Request timed out: {e}"))
@@ -563,25 +673,6 @@ impl RequestForwarder {
}
}
/// Categorize a ProxyError
///
/// Decides which errors should trigger failover to the next Provider
///
/// Design principle: since the user configured multiple providers, every provider
/// should get a try. Only clear client-side aborts are not retried.
fn should_retry_same_provider(&self, error: &ProxyError) -> bool {
match error {
// Network errors: retrying within the same Provider makes sense for transient blips
ProxyError::Timeout(_) => true,
ProxyError::ForwardFailed(_) => true,
// Upstream HTTP errors: retry the same Provider only for likely-transient status codes (the rest go to failover)
ProxyError::UpstreamError { status, .. } => {
*status == 408 || *status == 429 || *status >= 500
}
_ => false,
}
}
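The retry classification above hinges on a small set of HTTP status codes. A minimal standalone sketch of the same rule (the `is_transient_status` helper is a hypothetical name, not part of the codebase):

```rust
/// Status codes the forwarder treats as possibly transient:
/// 408 (request timeout), 429 (rate limited), and any 5xx.
fn is_transient_status(status: u16) -> bool {
    status == 408 || status == 429 || status >= 500
}
```

Anything outside this set (e.g. 401 or 404) is assumed to be a stable failure for that provider, so the request moves on to failover instead of retrying in place.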
fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory {
match error {
// Network and upstream errors: should try the next provider
@@ -597,7 +688,6 @@ impl RequestForwarder {
ProxyError::TransformError(_) => ErrorCategory::Retryable,
ProxyError::AuthError(_) => ErrorCategory::Retryable,
ProxyError::StreamIdleTimeout(_) => ErrorCategory::Retryable,
ProxyError::MaxRetriesExceeded => ErrorCategory::Retryable,
// No available provider: all providers have been tried, nothing left to retry
ProxyError::NoAvailableProvider => ErrorCategory::NonRetryable,
// Other errors (database/internal, etc.): switching providers won't fix these
@@ -58,10 +58,10 @@ fn openai_model_extractor(events: &[Value], request_model: &str) -> String {
.to_string()
}
/// Codex Responses API streaming model extraction (prefers usage.model)
fn codex_model_extractor(events: &[Value], request_model: &str) -> String {
/// Codex smart streaming model extraction (auto-detects the format)
fn codex_auto_model_extractor(events: &[Value], request_model: &str) -> String {
// First try to get the model from the parsed usage
if let Some(usage) = TokenUsage::from_codex_stream_events(events) {
if let Some(usage) = TokenUsage::from_codex_stream_events_auto(events) {
if let Some(model) = usage.model {
return model;
}
@@ -76,6 +76,10 @@ fn codex_model_extractor(events: &[Value], request_model: &str) -> String {
None
}
})
.or_else(|| {
// Further fallback: extract from OpenAI-format events
events.iter().find_map(|e| e.get("model")?.as_str())
})
.unwrap_or(request_model)
.to_string()
}
@@ -111,11 +115,11 @@ pub const OPENAI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
app_type_str: "codex",
};
/// Codex Responses API parser config (for /v1/responses)
/// Codex smart parser config (auto-detects OpenAI or Codex format)
pub const CODEX_PARSER_CONFIG: UsageParserConfig = UsageParserConfig {
stream_parser: TokenUsage::from_codex_stream_events,
response_parser: TokenUsage::from_codex_response,
model_extractor: codex_model_extractor,
stream_parser: TokenUsage::from_codex_stream_events_auto,
response_parser: TokenUsage::from_codex_response_auto,
model_extractor: codex_auto_model_extractor,
app_type_str: "codex",
};
@@ -5,8 +5,10 @@
use crate::app_config::AppType;
use crate::provider::Provider;
use crate::proxy::{
forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig, ProxyError,
extract_session_id, forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig,
ProxyError,
};
use axum::http::HeaderMap;
use std::time::Instant;
/// Streaming timeout configuration
@@ -26,6 +28,7 @@ pub struct StreamingTimeoutConfig {
/// - Selected Provider list (for failover)
/// - Request model name
/// - Log tag
/// - Session ID (for log correlation)
pub struct RequestContext {
/// Request start time
pub start_time: Instant,
@@ -35,7 +38,7 @@ pub struct RequestContext {
pub provider: Provider,
/// Full Provider list (for failover)
providers: Vec<Provider>,
/// The current provider at request start (used to decide whether to sync the UI/tray)
/// The "current provider" at request start (used to decide whether to sync the UI/tray)
///
/// This uses the device-level current provider from local settings.
/// In proxy mode, if the provider actually used differs from this, a switch is triggered to keep the UI accurate.
@@ -49,6 +52,8 @@ pub struct RequestContext {
/// App type (reserved; currently used via app_type_str)
#[allow(dead_code)]
pub app_type: AppType,
/// Session ID (extracted from the client request or newly generated)
pub session_id: String,
}
impl RequestContext {
@@ -57,6 +62,7 @@ impl RequestContext {
/// # Arguments
/// * `state` - proxy server state
/// * `body` - request body JSON
/// * `headers` - request headers (used to extract the Session ID)
/// * `app_type` - application type
/// * `tag` - log tag
/// * `app_type_str` - application type string
@@ -66,6 +72,7 @@ impl RequestContext {
pub async fn new(
state: &ProxyState,
body: &serde_json::Value,
headers: &HeaderMap,
app_type: AppType,
tag: &'static str,
app_type_str: &'static str,
@@ -89,13 +96,31 @@ impl RequestContext {
.unwrap_or("unknown")
.to_string();
// Extract the Session ID
let session_result = extract_session_id(headers, body, app_type_str);
let session_id = session_result.session_id.clone();
log::debug!(
"[{}] Session ID: {} (from {:?}, client_provided: {})",
tag,
session_id,
session_result.source,
session_result.client_provided
);
// Use the shared ProviderRouter to select a Provider (circuit breaker state persists across requests)
// Note: call only once here and pass the result to the forwarder, to avoid consuming extra HalfOpen slots
let providers = state
.provider_router
.select_providers(app_type_str)
.await
.map_err(|e| ProxyError::DatabaseError(e.to_string()))?;
.map_err(|e| match e {
crate::error::AppError::AllProvidersCircuitOpen => {
ProxyError::AllProvidersCircuitOpen
}
crate::error::AppError::NoProvidersConfigured => ProxyError::NoProvidersConfigured,
_ => ProxyError::DatabaseError(e.to_string()),
})?;
let provider = providers
.first()
@@ -103,11 +128,12 @@ impl RequestContext {
.ok_or(ProxyError::NoAvailableProvider)?;
log::info!(
"[{}] Provider: {}, model: {}, failover chain: {} providers",
"[{}] Provider: {}, model: {}, failover chain: {} providers, session: {}",
tag,
provider.name,
request_model,
providers.len()
providers.len(),
session_id
);
Ok(Self {
@@ -120,6 +146,7 @@ impl RequestContext {
tag,
app_type_str,
app_type,
session_id,
})
}
@@ -148,18 +175,38 @@ impl RequestContext {
/// Create the RequestForwarder
///
/// Uses the shared ProviderRouter so circuit breaker state persists across requests
///
/// Effective configuration rules:
/// - Failover enabled: timeout configs apply normally (0 disables the timeout)
/// - Failover disabled: timeout configs are ignored (all passed as 0)
pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder {
let (non_streaming_timeout, first_byte_timeout, idle_timeout) =
if self.app_config.auto_failover_enabled {
// Failover enabled: use the configured values (0 = timeout disabled)
(
self.app_config.non_streaming_timeout as u64,
self.app_config.streaming_first_byte_timeout as u64,
self.app_config.streaming_idle_timeout as u64,
)
} else {
// Failover disabled: timeout configs not applied
log::info!(
"[{}] Failover disabled, timeout configs are bypassed",
self.tag
);
(0, 0, 0)
};
RequestForwarder::new(
state.provider_router.clone(),
self.app_config.non_streaming_timeout as u64,
self.app_config.max_retries as u8,
non_streaming_timeout,
state.status.clone(),
state.current_providers.clone(),
state.failover_manager.clone(),
state.app_handle.clone(),
self.current_provider_id.clone(),
self.app_config.streaming_first_byte_timeout as u64,
self.app_config.streaming_idle_timeout as u64,
first_byte_timeout,
idle_timeout,
)
}
@@ -177,11 +224,24 @@ impl RequestContext {
}
/// Get the streaming timeout configuration
///
/// Effective configuration rules:
/// - Failover enabled: return the configured values (0 disables the timeout check)
/// - Failover disabled: return 0 (timeout check disabled)
#[inline]
pub fn streaming_timeout_config(&self) -> StreamingTimeoutConfig {
StreamingTimeoutConfig {
first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64,
idle_timeout: self.app_config.streaming_idle_timeout as u64,
if self.app_config.auto_failover_enabled {
// Failover enabled: use the configured values (0 = timeout disabled)
StreamingTimeoutConfig {
first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64,
idle_timeout: self.app_config.streaming_idle_timeout as u64,
}
} else {
// Failover disabled: disable streaming timeout checks
StreamingTimeoutConfig {
first_byte_timeout: 0,
idle_timeout: 0,
}
}
}
}
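The two branches above implement one rule: a disabled failover switch forces every timeout to 0. A simplified sketch of that selection, using an assumed `TimeoutCfg` struct in place of the real `AppProxyConfig`:

```rust
/// Simplified view of the timeout-selection rule: when failover is disabled,
/// all timeouts are forced to 0 (i.e. disabled), regardless of configuration.
struct TimeoutCfg {
    auto_failover_enabled: bool,
    first_byte_timeout: u64,
    idle_timeout: u64,
}

fn effective_timeouts(cfg: &TimeoutCfg) -> (u64, u64) {
    if cfg.auto_failover_enabled {
        (cfg.first_byte_timeout, cfg.idle_timeout)
    } else {
        (0, 0) // failover off: timeout checks are bypassed entirely
    }
}
```

Because 0 already means "no timeout" in the configured path, the disabled path can reuse the same sentinel instead of threading a separate flag through the forwarder.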
@@ -61,7 +61,8 @@ pub async fn handle_messages(
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
let mut ctx = RequestContext::new(&state, &body, AppType::Claude, "Claude", "claude").await?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Claude, "Claude", "claude").await?;
let is_stream = body
.get("stream")
@@ -305,7 +306,8 @@ pub async fn handle_chat_completions(
) -> Result<axum::response::Response, ProxyError> {
log::info!("[Codex] ====== /v1/chat/completions request started ======");
let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
let is_stream = body
.get("stream")
@@ -353,7 +355,8 @@ pub async fn handle_responses(
headers: axum::http::HeaderMap,
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?;
let mut ctx =
RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?;
let is_stream = body
.get("stream")
@@ -401,7 +404,7 @@ pub async fn handle_gemini(
Json(body): Json<Value>,
) -> Result<axum::response::Response, ProxyError> {
// Gemini's model name is in the URI
let mut ctx = RequestContext::new(&state, &body, AppType::Gemini, "Gemini", "gemini")
let mut ctx = RequestContext::new(&state, &body, &headers, AppType::Gemini, "Gemini", "gemini")
.await?
.with_model_from_uri(&uri);
@@ -465,7 +468,7 @@ fn log_forward_error(
let request_id = uuid::Uuid::new_v4().to_string();
if let Err(e) = logger.log_error_with_context(
request_id.clone(),
request_id,
ctx.provider.id.clone(),
ctx.app_type_str.to_string(),
ctx.request_model.clone(),
@@ -473,7 +476,7 @@ fn log_forward_error(
error_message,
ctx.latency_ms(),
is_streaming,
Some(request_id),
Some(ctx.session_id.clone()),
None,
) {
log::warn!("Failed to record failed-request log: {e}");
@@ -2,6 +2,7 @@
//!
//! Provides a local HTTP proxy service with multi-Provider failover and request passthrough
pub mod body_filter;
pub mod circuit_breaker;
pub mod error;
pub mod error_mapper;
@@ -33,7 +34,9 @@ pub use provider_router::ProviderRouter;
#[allow(unused_imports)]
pub use response_handler::{NonStreamHandler, ResponseType, StreamHandler};
#[allow(unused_imports)]
pub use session::{ClientFormat, ProxySession};
pub use session::{
extract_session_id, ClientFormat, ProxySession, SessionIdResult, SessionIdSource,
};
#[allow(unused_imports)]
pub use types::{ProxyConfig, ProxyServerInfo, ProxyStatus};
@@ -34,6 +34,8 @@ impl ProviderRouter {
/// - When failover is enabled: return strictly in failover queue order, ignoring the current provider setting
pub async fn select_providers(&self, app_type: &str) -> Result<Vec<Provider>, AppError> {
let mut result = Vec::new();
let mut total_providers = 0usize;
let mut circuit_open_count = 0usize;
// Check whether auto failover is enabled for this app (read from the proxy_config table)
let auto_failover_enabled = match self.db.get_proxy_config_for_app(app_type).await {
@@ -53,18 +55,26 @@ impl ProviderRouter {
if auto_failover_enabled {
// Failover enabled: use providers flagged in_failover_queue, ordered by sort_index
let failover_providers = self.db.get_failover_providers(app_type)?;
total_providers = failover_providers.len();
log::debug!("[{app_type}] Found {total_providers} failover queue provider(s)");
log::info!(
"[{}] Failover enabled, using queue order ({} items)",
app_type,
failover_providers.len()
"[{app_type}] Failover enabled, using queue order ({total_providers} items)"
);
for provider in failover_providers {
// Check circuit breaker state
let circuit_key = format!("{}:{}", app_type, provider.id);
let breaker = self.get_or_create_circuit_breaker(&circuit_key).await;
let state = breaker.get_state().await;
if breaker.is_available().await {
log::debug!(
"[{}] Queue provider available: {} ({}) (state: {:?})",
app_type,
provider.name,
provider.id,
state
);
log::info!(
"[{}] Queue provider available: {} ({}) at sort_index {:?}",
app_type,
@@ -74,10 +84,12 @@ impl ProviderRouter {
);
result.push(provider);
} else {
circuit_open_count += 1;
log::debug!(
"[{}] Queue provider {} circuit breaker open, skipping",
"[{}] Queue provider {} circuit breaker open (state: {:?}), skipping",
app_type,
provider.name
provider.name,
state
);
}
}
@@ -94,15 +106,27 @@ impl ProviderRouter {
current.name,
current.id
);
total_providers = 1;
result.push(current);
} else {
log::debug!(
"[{app_type}] Current provider id {current_id} not found in database"
);
}
} else {
log::debug!("[{app_type}] No current provider configured");
}
}
if result.is_empty() {
return Err(AppError::Config(format!(
"No available provider for {app_type} (all circuit breakers open or no providers configured)"
)));
// Distinguish two cases: all circuits open vs no providers configured
if total_providers > 0 && circuit_open_count == total_providers {
log::warn!("[{app_type}] All {total_providers} provider(s) have open circuit breakers; no channel available");
return Err(AppError::AllProvidersCircuitOpen);
} else {
log::warn!("[{app_type}] No providers configured or the failover queue is empty");
return Err(AppError::NoProvidersConfigured);
}
}
log::info!(
@@ -112,6 +112,19 @@ pub async fn handle_non_streaming(
spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false);
} else {
let model = json_value
.get("model")
.and_then(|m| m.as_str())
.unwrap_or(&ctx.request_model)
.to_string();
spawn_log_usage(
state,
ctx,
TokenUsage::default(),
&model,
status.as_u16(),
false,
);
log::debug!(
"[{}] Could not parse usage info, skipping record",
parser_config.app_type_str
@@ -123,6 +136,14 @@ pub async fn handle_non_streaming(
ctx.tag,
body_bytes.len()
);
spawn_log_usage(
state,
ctx,
TokenUsage::default(),
&ctx.request_model,
status.as_u16(),
false,
);
}
log::info!("[{}] ====== Request finished ======", ctx.tag);
@@ -243,6 +264,7 @@ fn create_usage_collector(
let start_time = ctx.start_time;
let stream_parser = parser_config.stream_parser;
let model_extractor = parser_config.model_extractor;
let session_id = ctx.session_id.clone();
SseUsageCollector::new(start_time, move |events, first_token_ms| {
if let Some(usage) = stream_parser(&events) {
@@ -251,6 +273,7 @@ fn create_usage_collector(
let state = state.clone();
let provider_id = provider_id.clone();
let session_id = session_id.clone();
tokio::spawn(async move {
log_usage_internal(
@@ -263,10 +286,32 @@ fn create_usage_collector(
first_token_ms,
true, // is_streaming
status_code,
Some(session_id),
)
.await;
});
} else {
let model = model_extractor(&events, &request_model);
let latency_ms = start_time.elapsed().as_millis() as u64;
let state = state.clone();
let provider_id = provider_id.clone();
let session_id = session_id.clone();
tokio::spawn(async move {
log_usage_internal(
&state,
&provider_id,
app_type_str,
&model,
TokenUsage::default(),
latency_ms,
first_token_ms,
true, // is_streaming
status_code,
Some(session_id),
)
.await;
});
log::debug!("[{tag}] Streaming response missing usage stats, skipping consumption record");
}
})
@@ -286,6 +331,7 @@ fn spawn_log_usage(
let app_type_str = ctx.app_type_str.to_string();
let model = model.to_string();
let latency_ms = ctx.latency_ms();
let session_id = ctx.session_id.clone();
tokio::spawn(async move {
log_usage_internal(
@@ -298,6 +344,7 @@ fn spawn_log_usage(
None,
is_streaming,
status_code,
Some(session_id),
)
.await;
});
@@ -315,6 +362,7 @@ async fn log_usage_internal(
first_token_ms: Option<u64>,
is_streaming: bool,
status_code: u16,
session_id: Option<String>,
) {
use super::usage::logger::UsageLogger;
@@ -338,6 +386,15 @@ async fn log_usage_internal(
let request_id = uuid::Uuid::new_v4().to_string();
log::debug!(
"[{app_type}] Recording request log: id={request_id}, provider={provider_id}, model={model}, streaming={is_streaming}, status={status_code}, latency_ms={latency_ms}, first_token_ms={first_token_ms:?}, session={}, input={}, output={}, cache_read={}, cache_creation={}",
session_id.as_deref().unwrap_or("none"),
usage.input_tokens,
usage.output_tokens,
usage.cache_read_tokens,
usage.cache_creation_tokens
);
if let Err(e) = logger.log_with_calculation(
request_id,
provider_id.to_string(),
@@ -348,7 +405,7 @@ async fn log_usage_internal(
latency_ms,
first_token_ms,
status_code,
None,
session_id,
None, // provider_type
is_streaming,
) {
@@ -1,7 +1,15 @@
//! Proxy Session - request session management
//!
//! Creates a session context for each proxied request and tracks state and metadata across the request lifecycle.
//!
//! ## Session ID extraction
//!
//! Supports extracting a Session ID from client requests to correlate multiple requests in the same conversation:
//! - Claude: extracted from `metadata.user_id` (format: `user_xxx_session_yyy`) or `metadata.session_id`
//! - Codex: extracted from `previous_response_id` or the `session_id` header
//! - Otherwise: a new UUID is generated
use axum::http::HeaderMap;
use std::time::Instant;
use uuid::Uuid;
@@ -176,6 +184,179 @@ impl ProxySession {
}
}
// ============================================================================
// Session ID extractor
// ============================================================================
/// Session ID source
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionIdSource {
/// Extracted from metadata.user_id (Claude)
MetadataUserId,
/// Extracted from metadata.session_id
MetadataSessionId,
/// Extracted from headers (Codex)
Header,
/// Extracted from previous_response_id (Codex)
PreviousResponseId,
/// Newly generated
Generated,
}
/// Session ID extraction result
#[derive(Debug, Clone)]
pub struct SessionIdResult {
/// The extracted or generated Session ID
pub session_id: String,
/// Session ID source
pub source: SessionIdSource,
/// Whether the ID was provided by the client (not newly generated)
pub client_provided: bool,
}
/// Extract or generate a Session ID from the request
///
/// Lightweight implementation: only extracts a session_id for logging; no full Session management.
///
/// ## Extraction priority
///
/// ### Claude requests
/// 1. `metadata.user_id` (format: `user_xxx_session_yyy`) → extract the `yyy` part
/// 2. `metadata.session_id` → used as-is
/// 3. Generate a new UUID
///
/// ### Codex requests
/// 1. Headers: `session_id` or `x-session-id`
/// 2. `metadata.session_id`
/// 3. `previous_response_id` (conversation continuation)
/// 4. Generate a new UUID
///
/// ## Example
///
/// ```ignore
/// let result = extract_session_id(&headers, &body, "claude");
/// println!("Session ID: {} (from {:?})", result.session_id, result.source);
/// ```
pub fn extract_session_id(
headers: &HeaderMap,
body: &serde_json::Value,
client_format: &str,
) -> SessionIdResult {
// Special handling for Codex requests
if client_format == "codex" || client_format == "openai" {
if let Some(result) = extract_codex_session(headers, body) {
return result;
}
}
// Claude requests: extract from metadata
if let Some(result) = extract_from_metadata(body) {
return result;
}
// Fallback: generate a new Session ID
generate_new_session_id()
}
/// Extract a Codex Session ID
fn extract_codex_session(headers: &HeaderMap, body: &serde_json::Value) -> Option<SessionIdResult> {
// 1. Extract from headers
for header_name in &["session_id", "x-session-id"] {
if let Some(value) = headers.get(*header_name) {
if let Ok(session_id) = value.to_str() {
// Codex Session IDs are usually long (UUID format)
if session_id.len() > 20 {
return Some(SessionIdResult {
session_id: format!("codex_{session_id}"),
source: SessionIdSource::Header,
client_provided: true,
});
}
}
}
}
// 2. Extract from body.metadata.session_id
if let Some(session_id) = body
.get("metadata")
.and_then(|m| m.get("session_id"))
.and_then(|v| v.as_str())
{
if session_id.len() > 10 {
return Some(SessionIdResult {
session_id: format!("codex_{session_id}"),
source: SessionIdSource::MetadataSessionId,
client_provided: true,
});
}
}
// 3. Extract from previous_response_id (conversation continuation)
if let Some(prev_id) = body.get("previous_response_id").and_then(|v| v.as_str()) {
if prev_id.len() > 10 {
return Some(SessionIdResult {
session_id: format!("codex_{prev_id}"),
source: SessionIdSource::PreviousResponseId,
client_provided: true,
});
}
}
None
}
/// Extract a Session ID from metadata (Claude)
fn extract_from_metadata(body: &serde_json::Value) -> Option<SessionIdResult> {
let metadata = body.get("metadata")?;
// 1. Extract from metadata.user_id (format: user_xxx_session_yyy)
if let Some(user_id) = metadata.get("user_id").and_then(|v| v.as_str()) {
if let Some(session_id) = parse_session_from_user_id(user_id) {
return Some(SessionIdResult {
session_id,
source: SessionIdSource::MetadataUserId,
client_provided: true,
});
}
}
// 2. Extract directly from metadata.session_id
if let Some(session_id) = metadata.get("session_id").and_then(|v| v.as_str()) {
if !session_id.is_empty() {
return Some(SessionIdResult {
session_id: session_id.to_string(),
source: SessionIdSource::MetadataSessionId,
client_provided: true,
});
}
}
None
}
/// Parse a session_id out of a user_id
///
/// Format: `user_identifier_session_actual_session_id`
fn parse_session_from_user_id(user_id: &str) -> Option<String> {
// Look for the "_session_" separator
if let Some(pos) = user_id.find("_session_") {
let session_id = &user_id[pos + 9..]; // "_session_" is 9 characters long
if !session_id.is_empty() {
return Some(session_id.to_string());
}
}
None
}
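The separator parse above can be exercised on its own. A self-contained copy of the rule (hypothetical `session_suffix` name), using the separator's length instead of the magic `9`:

```rust
/// Extract the substring after the first "_session_" separator, if non-empty.
fn session_suffix(user_id: &str) -> Option<String> {
    let pos = user_id.find("_session_")?;
    let suffix = &user_id[pos + "_session_".len()..];
    if suffix.is_empty() {
        None
    } else {
        Some(suffix.to_string())
    }
}
```

Note the caveat from the tests below: any `_session_` occurrence matches, so `"no_session_marker"` yields `"marker"` even though it was never meant as a session ID.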
/// Generate a new Session ID
fn generate_new_session_id() -> SessionIdResult {
SessionIdResult {
session_id: Uuid::new_v4().to_string(),
source: SessionIdSource::Generated,
client_provided: false,
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -295,4 +476,92 @@ mod tests {
assert_eq!(ClientFormat::GeminiCli.as_str(), "gemini_cli");
assert_eq!(ClientFormat::Unknown.as_str(), "unknown");
}
// ========== Session ID extraction tests ==========
#[test]
fn test_extract_session_from_claude_metadata_user_id() {
let headers = HeaderMap::new();
let body = json!({
"model": "claude-3-5-sonnet",
"messages": [{"role": "user", "content": "Hello"}],
"metadata": {
"user_id": "user_john_doe_session_abc123def456"
}
});
let result = extract_session_id(&headers, &body, "claude");
assert_eq!(result.session_id, "abc123def456");
assert_eq!(result.source, SessionIdSource::MetadataUserId);
assert!(result.client_provided);
}
#[test]
fn test_extract_session_from_claude_metadata_session_id() {
let headers = HeaderMap::new();
let body = json!({
"model": "claude-3-5-sonnet",
"messages": [{"role": "user", "content": "Hello"}],
"metadata": {
"session_id": "my-session-123"
}
});
let result = extract_session_id(&headers, &body, "claude");
assert_eq!(result.session_id, "my-session-123");
assert_eq!(result.source, SessionIdSource::MetadataSessionId);
assert!(result.client_provided);
}
#[test]
fn test_extract_session_from_codex_previous_response_id() {
let headers = HeaderMap::new();
let body = json!({
"input": "Write a function",
"previous_response_id": "resp_abc123def456789"
});
let result = extract_session_id(&headers, &body, "codex");
assert_eq!(result.session_id, "codex_resp_abc123def456789");
assert_eq!(result.source, SessionIdSource::PreviousResponseId);
assert!(result.client_provided);
}
#[test]
fn test_extract_session_generates_new_when_not_found() {
let headers = HeaderMap::new();
let body = json!({
"model": "claude-3-5-sonnet",
"messages": [{"role": "user", "content": "Hello"}]
});
let result = extract_session_id(&headers, &body, "claude");
assert!(!result.session_id.is_empty());
assert_eq!(result.source, SessionIdSource::Generated);
assert!(!result.client_provided);
}
#[test]
fn test_parse_session_from_user_id() {
assert_eq!(
parse_session_from_user_id("user_john_session_abc123"),
Some("abc123".to_string())
);
assert_eq!(
parse_session_from_user_id("my_app_session_xyz789"),
Some("xyz789".to_string())
);
// Note: "_session_" is the separator, so the string below matches
assert_eq!(
parse_session_from_user_id("no_session_marker"),
Some("marker".to_string())
);
// Cases without the "_session_" separator
assert_eq!(parse_session_from_user_id("user_john_abc123"), None);
assert_eq!(parse_session_from_user_id("_session_"), None);
}
}
@@ -35,6 +35,11 @@ impl CostCalculator {
/// - `usage`: token usage
/// - `pricing`: model pricing
/// - `cost_multiplier`: cost multiplier (provider-specific)
///
/// # Calculation logic
/// - input_cost: (input_tokens - cache_read_tokens) × input price
/// - cache_read_cost: cache_read_tokens × cache read price
/// - This avoids double-billing the cached portion
pub fn calculate(
usage: &TokenUsage,
pricing: &ModelPricing,
@@ -42,7 +47,10 @@ impl CostCalculator {
) -> CostBreakdown {
let million = Decimal::from(1_000_000);
let input_cost = Decimal::from(usage.input_tokens) * pricing.input_cost_per_million
// Compute the token count actually billable at the input price (minus cache hits)
let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens);
let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million
/ million
* cost_multiplier;
let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million
@@ -113,8 +121,8 @@ mod tests {
let cost = CostCalculator::calculate(&usage, &pricing, multiplier);
// input: 1000 * 3.0 / 1M = 0.003
assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap());
// input: (1000 - 200) * 3.0 / 1M = 0.0024 (only the non-cached portion)
assert_eq!(cost.input_cost, Decimal::from_str("0.0024").unwrap());
// output: 500 * 15.0 / 1M = 0.0075
assert_eq!(cost.output_cost, Decimal::from_str("0.0075").unwrap());
// cache_read: 200 * 0.3 / 1M = 0.00006
@@ -124,8 +132,8 @@ mod tests {
cost.cache_creation_cost,
Decimal::from_str("0.000375").unwrap()
);
// total: 0.003 + 0.0075 + 0.00006 + 0.000375 = 0.010935
assert_eq!(cost.total_cost, Decimal::from_str("0.010935").unwrap());
// total: 0.0024 + 0.0075 + 0.00006 + 0.000375 = 0.010335
assert_eq!(cost.total_cost, Decimal::from_str("0.010335").unwrap());
}
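The updated total can be checked with plain arithmetic. A sketch using `f64` instead of `Decimal`, with the per-million prices from the test above (the cache-creation figures of 100 tokens at 3.75/M are an assumption, inferred from the asserted 0.000375):

```rust
/// Cache-aware cost: the cached portion of input is billed at the cache-read
/// price instead of the full input price.
fn total_cost(
    input: u64, output: u64, cache_read: u64, cache_creation: u64,
    in_price: f64, out_price: f64, cache_read_price: f64, cache_creation_price: f64,
) -> f64 {
    let m = 1_000_000.0;
    // Subtract cache hits before applying the input price (clamped at zero)
    let billable_input = input.saturating_sub(cache_read);
    billable_input as f64 * in_price / m
        + output as f64 * out_price / m
        + cache_read as f64 * cache_read_price / m
        + cache_creation as f64 * cache_creation_price / m
}
```

With input 1000, output 500, cache_read 200: 800×3.0/1M + 500×15.0/1M + 200×0.3/1M + 100×3.75/1M = 0.0024 + 0.0075 + 0.00006 + 0.000375 = 0.010335, matching the corrected assertion.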
#[test]
@@ -163,13 +163,21 @@ impl TokenUsage {
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let cached_tokens = usage
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.or_else(|| {
usage
.get("input_tokens_details")
.and_then(|d| d.get("cached_tokens"))
.and_then(|v| v.as_u64())
})
.unwrap_or(0) as u32;
Some(Self {
input_tokens: input_tokens? as u32,
output_tokens: output_tokens? as u32,
cache_read_tokens: usage
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
cache_read_tokens: cached_tokens,
cache_creation_tokens: usage
.get("cache_creation_input_tokens")
.and_then(|v| v.as_u64())
@@ -188,16 +196,27 @@ impl TokenUsage {
let input_tokens = usage.get("input_tokens")?.as_u64()? as u32;
let output_tokens = usage.get("output_tokens")?.as_u64()? as u32;
// Get cached_tokens (may be in input_tokens_details)
// Get cached_tokens (may be in cache_read_input_tokens or input_tokens_details)
let cached_tokens = usage
.get("input_tokens_details")
.and_then(|d| d.get("cached_tokens"))
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.or_else(|| {
usage
.get("input_tokens_details")
.and_then(|d| d.get("cached_tokens"))
.and_then(|v| v.as_u64())
})
.unwrap_or(0) as u32;
// Adjust input_tokens: subtract cached_tokens
let adjusted_input = input_tokens.saturating_sub(cached_tokens);
// Extract the model name from the response
let model = body
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
Some(Self {
input_tokens: adjusted_input,
output_tokens,
@@ -206,7 +225,7 @@ impl TokenUsage {
.get("cache_creation_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
model: None,
model,
})
}
@@ -220,7 +239,7 @@ impl TokenUsage {
if event_type == "response.completed" {
if let Some(response) = event.get("response") {
log::debug!("[Codex] Found response.completed event, parsing usage");
return Self::from_codex_response(response);
return Self::from_codex_response_adjusted(response);
}
}
}
@@ -229,6 +248,51 @@ impl TokenUsage {
None
}
/// Smart Codex response parsing - auto-detects OpenAI or Codex format
///
/// Codex supports two API formats:
/// - `/v1/responses`: uses input_tokens/output_tokens
/// - `/v1/chat/completions`: uses prompt_tokens/completion_tokens (OpenAI format)
///
/// Note: records the raw input_tokens; cached_tokens are subtracted at cost calculation time
pub fn from_codex_response_auto(body: &Value) -> Option<Self> {
let usage = body.get("usage")?;
// Detect the format: OpenAI uses prompt_tokens, Codex uses input_tokens
if usage.get("prompt_tokens").is_some() {
log::debug!("[Codex] Detected OpenAI format (prompt_tokens)");
Self::from_openai_response(body)
} else if usage.get("input_tokens").is_some() {
log::debug!("[Codex] Detected Codex format (input_tokens)");
// Use the non-adjusted version; record raw input_tokens
Self::from_codex_response(body)
} else {
log::debug!("[Codex] Unrecognized response format, usage: {usage:?}");
None
}
}
/// Smart Codex streaming response parsing - auto-detects OpenAI or Codex format
pub fn from_codex_stream_events_auto(events: &[Value]) -> Option<Self> {
log::debug!("[Codex] Smart-parsing stream events, {} event(s) total", events.len());
// First try the Codex Responses API format (response.completed event)
for event in events {
if let Some(event_type) = event.get("type").and_then(|v| v.as_str()) {
if event_type == "response.completed" {
if let Some(response) = event.get("response") {
log::debug!("[Codex] Found response.completed event");
return Self::from_codex_response_auto(response);
}
}
}
}
// Fall back to the OpenAI Chat Completions format (the last chunk contains usage)
log::debug!("[Codex] Trying OpenAI streaming format");
Self::from_openai_stream_events(events)
}
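The auto-detection boils down to which token keys the `usage` object carries. A minimal sketch of that dispatch over plain key lists (the `UsageFormat` enum and `detect_format` helper are illustrative names, not part of the codebase):

```rust
#[derive(Debug, PartialEq)]
enum UsageFormat {
    OpenAi,  // prompt_tokens / completion_tokens
    Codex,   // input_tokens / output_tokens
    Unknown,
}

/// Mirrors the detection order above: prompt_tokens wins, then input_tokens.
fn detect_format(usage_keys: &[&str]) -> UsageFormat {
    if usage_keys.contains(&"prompt_tokens") {
        UsageFormat::OpenAi
    } else if usage_keys.contains(&"input_tokens") {
        UsageFormat::Codex
    } else {
        UsageFormat::Unknown
    }
}
```

Checking `prompt_tokens` first matters: a payload carrying both key families would otherwise be billed under the wrong parser.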
/// Parse an OpenAI Chat Completions API response (prompt_tokens, completion_tokens)
pub fn from_openai_response(body: &Value) -> Option<Self> {
let usage = body.get("usage")?;
@@ -284,9 +348,16 @@ impl TokenUsage {
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let prompt_tokens = usage.get("promptTokenCount")?.as_u64()? as u32;
let total_tokens = usage.get("totalTokenCount")?.as_u64()? as u32;
// output tokens = total tokens - input tokens
// This includes candidatesTokenCount + thoughtsTokenCount
let output_tokens = total_tokens.saturating_sub(prompt_tokens);
Some(Self {
input_tokens: usage.get("promptTokenCount")?.as_u64()? as u32,
output_tokens: usage.get("candidatesTokenCount")?.as_u64()? as u32,
input_tokens: prompt_tokens,
output_tokens,
cache_read_tokens: usage
.get("cachedContentTokenCount")
.and_then(|v| v.as_u64())
@@ -300,20 +371,25 @@ impl TokenUsage {
#[allow(dead_code)]
pub fn from_gemini_stream_chunks(chunks: &[Value]) -> Option<Self> {
let mut total_input = 0u32;
let mut total_output = 0u32;
let mut total_tokens = 0u32;
let mut total_cache_read = 0u32;
let mut model: Option<String> = None;
for chunk in chunks {
if let Some(usage) = chunk.get("usageMetadata") {
// Input tokens (usually unchanged across chunks)
total_input = usage
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32;
total_output += usage
.get("candidatesTokenCount")
// Total tokens (input + output + thinking)
total_tokens = usage
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32;
// Cache read tokens
total_cache_read = usage
.get("cachedContentTokenCount")
.and_then(|v| v.as_u64())
@@ -328,6 +404,9 @@ impl TokenUsage {
}
}
// output tokens = total tokens - input tokens
let total_output = total_tokens.saturating_sub(total_input);
if total_input > 0 || total_output > 0 {
Some(Self {
input_tokens: total_input,
@@ -466,15 +545,18 @@ mod tests {
let response = json!({
"modelVersion": "gemini-3-pro-high",
"usageMetadata": {
"promptTokenCount": 100,
"promptTokenCount": 8383,
"candidatesTokenCount": 50,
"thoughtsTokenCount": 114,
"totalTokenCount": 8547,
"cachedContentTokenCount": 20
}
});
let usage = TokenUsage::from_gemini_response(&response).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.input_tokens, 8383);
// output_tokens = totalTokenCount - promptTokenCount = 8547 - 8383 = 164
assert_eq!(usage.output_tokens, 164);
assert_eq!(usage.cache_read_tokens, 20);
assert_eq!(usage.cache_creation_tokens, 0);
assert_eq!(usage.model, Some("gemini-3-pro-high".to_string()));
@@ -486,19 +568,78 @@ mod tests {
let response = json!({
"usageMetadata": {
"promptTokenCount": 100,
"candidatesTokenCount": 50,
"totalTokenCount": 150,
"cachedContentTokenCount": 20
}
});
let usage = TokenUsage::from_gemini_response(&response).unwrap();
assert_eq!(usage.input_tokens, 100);
// output_tokens = totalTokenCount - promptTokenCount = 150 - 100 = 50
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_read_tokens, 20);
assert_eq!(usage.cache_creation_tokens, 0);
assert_eq!(usage.model, None);
}
#[test]
fn test_gemini_response_with_thoughts() {
// Tests a real response containing thoughtsTokenCount
// This is the actual scenario reported by a user
let response = json!({
"candidates": [
{
"content": {
"parts": [
{
"text": "",
"thoughtSignature": "EvcECvQE..."
}
],
"role": "model"
},
"finishReason": "STOP"
}
],
"modelVersion": "gemini-3-pro-high",
"responseId": "yupTafqLDu-PjMcPhrOx4QQ",
"usageMetadata": {
"candidatesTokenCount": 50,
"promptTokenCount": 8383,
"thoughtsTokenCount": 114,
"totalTokenCount": 8547
}
});
let usage = TokenUsage::from_gemini_response(&response).unwrap();
assert_eq!(usage.input_tokens, 8383);
// output_tokens = totalTokenCount - promptTokenCount
// = 8547 - 8383 = 164 (includes candidatesTokenCount 50 + thoughtsTokenCount 114)
assert_eq!(usage.output_tokens, 164);
assert_eq!(usage.cache_read_tokens, 0);
assert_eq!(usage.cache_creation_tokens, 0);
assert_eq!(usage.model, Some("gemini-3-pro-high".to_string()));
}
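The Gemini change in the tests above reduces to one derivation: thinking tokens are not part of `candidatesTokenCount`, so output is recovered as total minus prompt, clamped at zero. A standalone sketch:

```rust
/// Gemini reports thinking tokens (thoughtsTokenCount) outside
/// candidatesTokenCount, so output is derived as
/// totalTokenCount - promptTokenCount, clamped at zero.
fn gemini_output_tokens(total: u32, prompt: u32) -> u32 {
    total.saturating_sub(prompt)
}
```

`saturating_sub` keeps the derivation safe if an upstream ever reports a total smaller than the prompt count.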
#[test]
fn test_codex_response_parsing_cached_tokens_in_details() {
let response = json!({
"usage": {
"input_tokens": 1000,
"output_tokens": 500,
"input_tokens_details": {
"cached_tokens": 300
}
}
});
let usage = TokenUsage::from_codex_response(&response).unwrap();
// Non-adjusted mode: input_tokens keeps its raw value, but cache hits are still recorded
assert_eq!(usage.input_tokens, 1000);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_read_tokens, 300);
}
#[test]
fn test_codex_response_adjusted() {
let response = json!({
@@ -534,6 +675,22 @@ mod tests {
assert_eq!(usage.cache_read_tokens, 0);
}
#[test]
fn test_codex_response_adjusted_cache_read_input_tokens() {
let response = json!({
"usage": {
"input_tokens": 1000,
"output_tokens": 500,
"cache_read_input_tokens": 200
}
});
let usage = TokenUsage::from_codex_response_adjusted(&response).unwrap();
assert_eq!(usage.input_tokens, 800);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_read_tokens, 200);
}
#[test]
fn test_codex_response_adjusted_saturating_sub() {
// Boundary case where cached_tokens > input_tokens
@@ -615,4 +772,110 @@ mod tests {
assert_eq!(usage.cache_read_tokens, 50);
assert_eq!(usage.model, Some("claude-sonnet-4-20250514".to_string()));
}
// ============================================================================
// Smart Codex parsing tests
// ============================================================================
#[test]
fn test_codex_response_auto_openai_format() {
// OpenAI format (prompt_tokens/completion_tokens)
let response = json!({
"model": "gpt-4o",
"usage": {
"prompt_tokens": 1000,
"completion_tokens": 500,
"prompt_tokens_details": {
"cached_tokens": 200
}
}
});
let usage = TokenUsage::from_codex_response_auto(&response).unwrap();
assert_eq!(usage.input_tokens, 1000);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_read_tokens, 200);
assert_eq!(usage.model, Some("gpt-4o".to_string()));
}
#[test]
fn test_codex_response_auto_codex_format() {
// Codex format (input_tokens/output_tokens)
let response = json!({
"model": "o3",
"usage": {
"input_tokens": 1000,
"output_tokens": 500,
"input_tokens_details": {
"cached_tokens": 300
}
}
});
let usage = TokenUsage::from_codex_response_auto(&response).unwrap();
// Record the raw input_tokens without adjustment
assert_eq!(usage.input_tokens, 1000);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_read_tokens, 300);
assert_eq!(usage.model, Some("o3".to_string()));
}
#[test]
fn test_codex_stream_events_auto_codex_format() {
// Codex Responses API streaming format (response.completed event)
let events = vec![
json!({
"type": "response.created",
"response": {
"id": "resp_123"
}
}),
json!({
"type": "response.completed",
"response": {
"model": "o3",
"usage": {
"input_tokens": 1000,
"output_tokens": 500,
"input_tokens_details": {
"cached_tokens": 200
}
}
}
}),
];
let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap();
// Record the raw input_tokens without adjustment
assert_eq!(usage.input_tokens, 1000);
assert_eq!(usage.output_tokens, 500);
assert_eq!(usage.cache_read_tokens, 200);
assert_eq!(usage.model, Some("o3".to_string()));
}
#[test]
fn test_codex_stream_events_auto_openai_format() {
// OpenAI Chat Completions streaming format (the final chunk carries usage)
let events = vec![
json!({
"id": "chatcmpl-123",
"model": "gpt-4o",
"choices": [{"delta": {"content": "Hello"}}]
}),
json!({
"id": "chatcmpl-123",
"model": "gpt-4o",
"choices": [{"delta": {}}],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}),
];
let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.model, Some("gpt-4o".to_string()));
}
}
+11 -8
@@ -217,9 +217,12 @@ impl ProviderService {
.flatten()
.is_some();
let is_proxy_running = futures::executor::block_on(state.proxy_service.is_running());
let live_taken_over = state
.proxy_service
.detect_takeover_in_live_config_for_app(&app_type);
// Hot-switch only when BOTH: this app is taken over AND proxy server is actually running
let should_hot_switch = is_app_taken_over && is_proxy_running;
let should_hot_switch = (is_app_taken_over || live_taken_over) && is_proxy_running;
if should_hot_switch {
// Proxy takeover mode: hot-switch only, don't write Live config
@@ -736,15 +739,15 @@ impl ProviderService {
// Delete the generated child providers
if let Some(p) = provider {
if p.apps.claude {
let claude_id = format!("universal-claude-{}", id);
let claude_id = format!("universal-claude-{id}");
let _ = state.db.delete_provider("claude", &claude_id);
}
if p.apps.codex {
let codex_id = format!("universal-codex-{}", id);
let codex_id = format!("universal-codex-{id}");
let _ = state.db.delete_provider("codex", &codex_id);
}
if p.apps.gemini {
let gemini_id = format!("universal-gemini-{}", id);
let gemini_id = format!("universal-gemini-{id}");
let _ = state.db.delete_provider("gemini", &gemini_id);
}
}
@@ -757,7 +760,7 @@ impl ProviderService {
let provider = state
.db
.get_universal_provider(id)?
.ok_or_else(|| AppError::Message(format!("统一供应商 {} 不存在", id)))?;
.ok_or_else(|| AppError::Message(format!("统一供应商 {id} 不存在")))?;
// Sync to Claude
if let Some(mut claude_provider) = provider.to_claude_provider() {
@@ -770,7 +773,7 @@ impl ProviderService {
state.db.save_provider("claude", &claude_provider)?;
} else {
// If Claude is disabled, delete the corresponding child provider
let claude_id = format!("universal-claude-{}", id);
let claude_id = format!("universal-claude-{id}");
let _ = state.db.delete_provider("claude", &claude_id);
}
@@ -784,7 +787,7 @@ impl ProviderService {
}
state.db.save_provider("codex", &codex_provider)?;
} else {
let codex_id = format!("universal-codex-{}", id);
let codex_id = format!("universal-codex-{id}");
let _ = state.db.delete_provider("codex", &codex_id);
}
@@ -798,7 +801,7 @@ impl ProviderService {
}
state.db.save_provider("gemini", &gemini_provider)?;
} else {
let gemini_id = format!("universal-gemini-{}", id);
let gemini_id = format!("universal-gemini-{id}");
let _ = state.db.delete_provider("gemini", &gemini_id);
}
+20 -7
@@ -193,7 +193,7 @@ impl ProxyService {
self.start().await?;
}
// 2) Already taken over: return directly (idempotent)
// 2) Already taken over: return directly (idempotent); but if the backup or the placeholder is missing, the takeover must be rebuilt
let current_config = self
.db
.get_proxy_config_for_app(app_type_str)
@@ -201,7 +201,22 @@ impl ProxyService {
.map_err(|e| format!("获取 {app_type_str} 配置失败: {e}"))?;
if current_config.enabled {
return Ok(());
let has_backup = match self.db.get_live_backup(app_type_str).await {
Ok(v) => v.is_some(),
Err(e) => {
log::warn!("读取 {app_type_str} 备份失败(将继续重建接管): {e}");
false
}
};
let live_taken_over = self.detect_takeover_in_live_config_for_app(&app);
if has_backup || live_taken_over {
return Ok(());
}
log::warn!(
"{app_type_str} 标记为已接管,但缺少备份或占位符,正在重新接管并补齐备份"
);
}
// 3) Back up the Live config (strict: error if the target app does not exist)
@@ -1063,7 +1078,7 @@ impl ProxyService {
}
}
fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool {
pub fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool {
match app_type {
AppType::Claude => match self.read_claude_live() {
Ok(config) => Self::is_claude_live_taken_over(&config),
@@ -1257,10 +1272,8 @@ impl ProxyService {
/// Check whether Live takeover mode is active
pub async fn is_takeover_active(&self) -> Result<bool, String> {
self.db
.is_live_takeover_active()
.await
.map_err(|e| format!("检查接管状态失败: {e}"))
let status = self.get_takeover_status().await?;
Ok(status.claude || status.codex || status.gemini)
}
/// Recover from an abnormal exit (called at startup)
+166 -255
@@ -4,7 +4,7 @@
use crate::database::{lock_conn, Database};
use crate::error::AppError;
use chrono::{Duration, Local, TimeZone};
use chrono::{Local, TimeZone};
use rusqlite::{params, Connection, OptionalExtension};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -181,145 +181,114 @@ impl Database {
Ok(result)
}
/// Get daily trends
pub fn get_daily_trends(&self, days: u32) -> Result<Vec<DailyStats>, AppError> {
/// Get daily trends (sliding window: hourly buckets for <=24h, daily for >24h; window matches the summary)
pub fn get_daily_trends(
&self,
start_date: Option<i64>,
end_date: Option<i64>,
) -> Result<Vec<DailyStats>, AppError> {
let conn = lock_conn!(self.conn);
if days <= 1 {
let today = Local::now().date_naive();
let start_of_today = today.and_hms_opt(0, 0, 0).unwrap();
// Use earliest() to handle ambiguous times at DST transitions; fall back to now minus one day
let start_ts = Local
.from_local_datetime(&start_of_today)
.earliest()
.unwrap_or_else(|| Local::now() - Duration::days(1))
.timestamp();
let end_ts = end_date.unwrap_or_else(|| Local::now().timestamp());
let mut start_ts = start_date.unwrap_or_else(|| end_ts - 24 * 60 * 60);
let sql = "SELECT
strftime('%Y-%m-%dT%H:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket,
COUNT(*) as request_count,
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
FROM proxy_request_logs
WHERE created_at >= ?
GROUP BY bucket
ORDER BY bucket ASC";
let mut stmt = conn.prepare(sql)?;
let rows = stmt.query_map([start_ts], |row| {
Ok(DailyStats {
date: row.get(0)?,
request_count: row.get::<_, i64>(1)? as u64,
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
total_tokens: row.get::<_, i64>(3)? as u64,
total_input_tokens: row.get::<_, i64>(4)? as u64,
total_output_tokens: row.get::<_, i64>(5)? as u64,
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
})
})?;
let mut buckets: HashMap<String, DailyStats> = HashMap::new();
for row in rows {
let stat = row?;
buckets.insert(stat.date.clone(), stat);
}
let mut stats = Vec::new();
for hour in 0..24 {
let bucket = today
.and_hms_opt(hour, 0, 0)
.unwrap()
.format("%Y-%m-%dT%H:00:00")
.to_string();
if let Some(stat) = buckets.remove(&bucket) {
stats.push(stat);
} else {
stats.push(DailyStats {
date: bucket,
request_count: 0,
total_cost: "0.000000".to_string(),
total_tokens: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cache_creation_tokens: 0,
total_cache_read_tokens: 0,
});
}
}
Ok(stats)
} else {
let today = Local::now().date_naive();
let start_day = today - Duration::days((days.saturating_sub(1)) as i64);
let start_of_window = start_day.and_hms_opt(0, 0, 0).unwrap();
// Use earliest() to handle ambiguous times at DST transitions; fall back to now minus `days` days
let start_ts = Local
.from_local_datetime(&start_of_window)
.earliest()
.unwrap_or_else(|| Local::now() - Duration::days(days as i64))
.timestamp();
let sql = "SELECT
strftime('%Y-%m-%dT00:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket,
COUNT(*) as request_count,
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
FROM proxy_request_logs
WHERE created_at >= ?
GROUP BY bucket
ORDER BY bucket ASC";
let mut stmt = conn.prepare(sql)?;
let rows = stmt.query_map([start_ts], |row| {
Ok(DailyStats {
date: row.get(0)?,
request_count: row.get::<_, i64>(1)? as u64,
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
total_tokens: row.get::<_, i64>(3)? as u64,
total_input_tokens: row.get::<_, i64>(4)? as u64,
total_output_tokens: row.get::<_, i64>(5)? as u64,
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
})
})?;
let mut map = HashMap::new();
for row in rows {
let stat = row?;
map.insert(stat.date.clone(), stat);
}
let mut stats = Vec::new();
for i in 0..days {
let day = start_day + Duration::days(i as i64);
let key = day.format("%Y-%m-%dT00:00:00").to_string();
if let Some(stat) = map.remove(&key) {
stats.push(stat);
} else {
stats.push(DailyStats {
date: key,
request_count: 0,
total_cost: "0.000000".to_string(),
total_tokens: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cache_creation_tokens: 0,
total_cache_read_tokens: 0,
});
}
}
Ok(stats)
if start_ts >= end_ts {
start_ts = end_ts - 24 * 60 * 60;
}
let duration = end_ts - start_ts;
let bucket_seconds: i64 = if duration <= 24 * 60 * 60 {
60 * 60
} else {
24 * 60 * 60
};
let mut bucket_count: i64 = if duration <= 0 {
1
} else {
((duration as f64) / bucket_seconds as f64).ceil() as i64
};
// Pin a 24-hour window to exactly 24 hourly buckets to avoid floating-point drift
if bucket_seconds == 60 * 60 {
bucket_count = 24;
}
if bucket_count < 1 {
bucket_count = 1;
}
let sql = "
SELECT
CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx,
COUNT(*) as request_count,
COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost,
COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens
FROM proxy_request_logs
WHERE created_at >= ?1 AND created_at <= ?2
GROUP BY bucket_idx
ORDER BY bucket_idx ASC";
let mut stmt = conn.prepare(sql)?;
let rows = stmt.query_map(params![start_ts, end_ts, bucket_seconds], |row| {
Ok((
row.get::<_, i64>(0)?,
DailyStats {
date: String::new(),
request_count: row.get::<_, i64>(1)? as u64,
total_cost: format!("{:.6}", row.get::<_, f64>(2)?),
total_tokens: row.get::<_, i64>(3)? as u64,
total_input_tokens: row.get::<_, i64>(4)? as u64,
total_output_tokens: row.get::<_, i64>(5)? as u64,
total_cache_creation_tokens: row.get::<_, i64>(6)? as u64,
total_cache_read_tokens: row.get::<_, i64>(7)? as u64,
},
))
})?;
let mut map: HashMap<i64, DailyStats> = HashMap::new();
for row in rows {
let (mut bucket_idx, stat) = row?;
if bucket_idx < 0 {
continue;
}
if bucket_idx >= bucket_count {
bucket_idx = bucket_count - 1;
}
map.insert(bucket_idx, stat);
}
let mut stats = Vec::with_capacity(bucket_count as usize);
for i in 0..bucket_count {
let bucket_start_ts = start_ts + i * bucket_seconds;
let bucket_start = Local
.timestamp_opt(bucket_start_ts, 0)
.single()
.unwrap_or_else(Local::now);
let date = bucket_start.format("%Y-%m-%dT%H:%M:%S").to_string();
if let Some(mut stat) = map.remove(&i) {
stat.date = date;
stats.push(stat);
} else {
stats.push(DailyStats {
date,
request_count: 0,
total_cost: "0.000000".to_string(),
total_tokens: 0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cache_creation_tokens: 0,
total_cache_read_tokens: 0,
});
}
}
Ok(stats)
}
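The rebuilt `get_daily_trends` above groups rows by integer bucket index rather than by formatted date strings. A minimal TypeScript sketch of the same bucketing arithmetic (helper names are illustrative, not from the codebase):

```typescript
const HOUR = 60 * 60;
const DAY = 24 * HOUR;

// Windows of <= 24h use hourly buckets; longer windows use daily buckets.
function bucketSeconds(startTs: number, endTs: number): number {
  return endTs - startTs <= DAY ? HOUR : DAY;
}

// A 24h window is pinned to exactly 24 buckets to avoid float drift.
function bucketCount(startTs: number, endTs: number): number {
  const size = bucketSeconds(startTs, endTs);
  if (size === HOUR) return 24;
  return Math.max(1, Math.ceil((endTs - startTs) / size));
}

// Bucket index of one log row, clamped into [0, count) like the Rust loop.
function bucketIndex(createdAt: number, startTs: number, endTs: number): number {
  const idx = Math.floor((createdAt - startTs) / bucketSeconds(startTs, endTs));
  return Math.min(Math.max(idx, 0), bucketCount(startTs, endTs) - 1);
}
```

Clamping `created_at == end_ts` into the last bucket mirrors the `bucket_idx >= bucket_count` branch in the SQL post-processing above.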
/// Get provider statistics
@@ -829,89 +798,46 @@ impl Database {
}
}
/// Normalize a model name: strip the provider prefix and replace dots with hyphens
/// e.g. anthropic/claude-haiku-4.5 → claude-haiku-4-5
fn normalize_model_id(model_id: &str) -> String {
// 1. Strip the provider prefix (e.g. anthropic/, openai/)
let stripped = if let Some(pos) = model_id.find('/') {
&model_id[pos + 1..]
} else {
model_id
};
// 2. Replace dots with hyphens (e.g. claude-haiku-4.5 → claude-haiku-4-5)
stripped.replace('.', "-")
}
pub(crate) fn find_model_pricing_row(
conn: &Connection,
model_id: &str,
) -> Result<Option<(String, String, String, String)>, AppError> {
// 0. Normalize the model name (strip prefix + dots to hyphens)
// e.g. anthropic/claude-haiku-4.5 → claude-haiku-4-5
let normalized = normalize_model_id(model_id);
// 1) Strip the provider prefix (before '/') and the colon suffix (after ':'), e.g. moonshotai/kimi-k2-0905:exa → kimi-k2-0905
let without_prefix = model_id
.rsplit_once('/')
.map(|(_, rest)| rest)
.unwrap_or(model_id);
let cleaned = without_prefix
.split(':')
.next()
.map(str::trim)
.unwrap_or(without_prefix);
// 1. Exact match (try the raw name first, then the normalized name)
for id in [model_id, normalized.as_str()] {
let exact = conn
.query_row(
"SELECT input_cost_per_million, output_cost_per_million,
cache_read_cost_per_million, cache_creation_cost_per_million
FROM model_pricing
WHERE model_id = ?1",
[id],
|row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
))
},
)
.optional()
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
// 2) Exact match on the cleaned name
let exact = conn
.query_row(
"SELECT input_cost_per_million, output_cost_per_million,
cache_read_cost_per_million, cache_creation_cost_per_million
FROM model_pricing
WHERE model_id = ?1",
[cleaned],
|row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
))
},
)
.optional()
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
if exact.is_some() {
if id != model_id {
log::info!("模型 {model_id} 标准化后精确匹配到: {id}");
}
return Ok(exact);
}
if exact.is_none() {
log::warn!("模型 {model_id}(清洗后: {cleaned})未找到定价信息,成本将记录为 0");
}
// 2. Progressive suffix-stripping match (claude-haiku-4-5-20250929 → claude-haiku-4-5 → claude-haiku-4 → claude-haiku)
// Use the normalized name for suffix matching
let mut current = normalized;
while let Some(pos) = current.rfind('-') {
current = current[..pos].to_string();
let result = conn
.query_row(
"SELECT input_cost_per_million, output_cost_per_million,
cache_read_cost_per_million, cache_creation_cost_per_million
FROM model_pricing
WHERE model_id = ?1",
[&current],
|row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
))
},
)
.optional()
.map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?;
if result.is_some() {
log::info!("模型 {model_id} 通过删除后缀匹配到: {current}");
return Ok(result);
}
}
log::warn!("模型 {model_id} 未找到定价信息,成本将记录为 0");
Ok(None)
Ok(exact)
}
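The simplified `find_model_pricing_row` now does a single cleaning pass instead of normalization plus suffix stripping. A TypeScript sketch of the cleaning rule (the helper name `cleanModelId` is hypothetical):

```typescript
// Strip the provider prefix (everything up to the last '/') and the ':'
// variant suffix before looking the model up in the pricing table.
function cleanModelId(modelId: string): string {
  const slash = modelId.lastIndexOf("/");
  const withoutPrefix = slash >= 0 ? modelId.slice(slash + 1) : modelId;
  return withoutPrefix.split(":")[0].trim();
}
```

Unlike the removed suffix-stripping loop, this only ever tries the one cleaned name, so a date-suffixed id must exist verbatim in `model_pricing`.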
#[cfg(test)]
@@ -991,54 +917,39 @@ mod tests {
let db = Database::memory()?;
let conn = lock_conn!(db.conn);
// Test exact match
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5")?;
assert!(result.is_some(), "应该能精确匹配 claude-sonnet-4-5");
// Seed extra pricing data to cover the prefix/suffix cleaning cases
conn.execute(
"INSERT OR REPLACE INTO model_pricing (
model_id, display_name, input_cost_per_million, output_cost_per_million,
cache_read_cost_per_million, cache_creation_cost_per_million
) VALUES (?, ?, ?, ?, ?, ?)",
params![
"claude-haiku-4.5",
"Claude Haiku 4.5",
"1.0",
"2.0",
"0.0",
"0.0"
],
)?;
// Test a model name with a provider prefix (anthropic/claude-haiku-4.5 → claude-haiku-4-5)
let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?;
assert!(
result.is_some(),
"应该能匹配带前缀的模型 anthropic/claude-haiku-4.5"
);
// Test a model name with a provider prefix plus dots
let result = find_model_pricing_row(&conn, "anthropic/claude-sonnet-4.5")?;
assert!(
result.is_some(),
"应该能匹配带前缀的模型 anthropic/claude-sonnet-4.5"
);
// Test progressive suffix-stripping match - date suffix
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20241022")?;
assert!(
result.is_some(),
"应该能通过删除后缀匹配 claude-sonnet-4-5-20241022"
);
// Test progressive suffix-stripping match - multiple suffixes
let result = find_model_pricing_row(&conn, "claude-haiku-4-5-20240229-preview")?;
assert!(
result.is_some(),
"应该能通过删除后缀匹配 claude-haiku-4-5-20240229-preview"
);
// Test a GPT model
let result = find_model_pricing_row(&conn, "gpt-5-2024-11-20")?;
assert!(result.is_some(), "应该能通过删除后缀匹配 gpt-5-2024-11-20");
// Test a Gemini model
let result = find_model_pricing_row(&conn, "gemini-2.5-flash-exp")?;
assert!(
result.is_some(),
"应该能通过删除后缀匹配 gemini-2.5-flash-exp"
);
// Test the claude-sonnet-4-5 naming format
// Test exact match (seed_model_pricing already seeds claude-sonnet-4-5-20250929)
let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20250929")?;
assert!(
result.is_some(),
"应该能通过删除后缀匹配 claude-sonnet-4-5-20250929"
"应该能精确匹配 claude-sonnet-4-5-20250929"
);
// Cleaning: strip the prefix and the colon suffix
let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?;
assert!(
result.is_some(),
"带前缀的模型 anthropic/claude-haiku-4.5 应能匹配到 claude-haiku-4.5"
);
let result = find_model_pricing_row(&conn, "moonshotai/kimi-k2-0905:exa")?;
assert!(
result.is_some(),
"带前缀+冒号后缀的模型应清洗后匹配到 kimi-k2-0905"
);
// Test a model that does not exist
@@ -142,12 +142,13 @@ export function AutoFailoverConfigPanel({
min="0"
max="10"
value={formData.maxRetries}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
maxRetries: parseInt(e.target.value) || 3,
})
}
maxRetries: isNaN(val) ? 0 : val,
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -168,12 +169,13 @@ export function AutoFailoverConfigPanel({
min="1"
max="20"
value={formData.circuitFailureThreshold}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
circuitFailureThreshold: parseInt(e.target.value) || 5,
})
}
circuitFailureThreshold: isNaN(val) ? 1 : Math.max(1, val),
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -206,12 +208,13 @@ export function AutoFailoverConfigPanel({
min="0"
max="180"
value={formData.streamingFirstByteTimeout}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
streamingFirstByteTimeout: parseInt(e.target.value) || 30,
})
}
streamingFirstByteTimeout: isNaN(val) ? 0 : val,
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -232,12 +235,13 @@ export function AutoFailoverConfigPanel({
min="0"
max="600"
value={formData.streamingIdleTimeout}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
streamingIdleTimeout: parseInt(e.target.value) || 60,
})
}
streamingIdleTimeout: isNaN(val) ? 0 : val,
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -258,12 +262,13 @@ export function AutoFailoverConfigPanel({
min="0"
max="1800"
value={formData.nonStreamingTimeout}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
nonStreamingTimeout: parseInt(e.target.value) || 300,
})
}
nonStreamingTimeout: isNaN(val) ? 0 : val,
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -293,12 +298,13 @@ export function AutoFailoverConfigPanel({
min="1"
max="10"
value={formData.circuitSuccessThreshold}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
circuitSuccessThreshold: parseInt(e.target.value) || 2,
})
}
circuitSuccessThreshold: isNaN(val) ? 1 : Math.max(1, val),
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -319,12 +325,13 @@ export function AutoFailoverConfigPanel({
min="10"
max="300"
value={formData.circuitTimeoutSeconds}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
circuitTimeoutSeconds: parseInt(e.target.value) || 60,
})
}
circuitTimeoutSeconds: isNaN(val) ? 10 : Math.max(10, val),
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -346,13 +353,13 @@ export function AutoFailoverConfigPanel({
max="100"
step="5"
value={Math.round(formData.circuitErrorRateThreshold * 100)}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
circuitErrorRateThreshold:
(parseInt(e.target.value) || 50) / 100,
})
}
circuitErrorRateThreshold: isNaN(val) ? 0.5 : val / 100,
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
@@ -373,12 +380,13 @@ export function AutoFailoverConfigPanel({
min="5"
max="100"
value={formData.circuitMinRequests}
onChange={(e) =>
onChange={(e) => {
const val = parseInt(e.target.value);
setFormData({
...formData,
circuitMinRequests: parseInt(e.target.value) || 10,
})
}
circuitMinRequests: isNaN(val) ? 5 : Math.max(5, val),
});
}}
disabled={isDisabled}
/>
<p className="text-xs text-muted-foreground">
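Each `onChange` handler above repeats the same parse-then-clamp pattern with a per-field lower bound. A possible shared helper (hypothetical, not part of the patch) could collapse them:

```typescript
// Parse a numeric input field: fall back when the value is not a number,
// and enforce the field's lower bound (matches the isNaN/Math.max pattern).
function clampInt(raw: string, min: number, fallback: number = min): number {
  const val = parseInt(raw, 10);
  return Number.isNaN(val) ? fallback : Math.max(min, val);
}
```

For example, `circuitFailureThreshold: clampInt(e.target.value, 1)` or `circuitTimeoutSeconds: clampInt(e.target.value, 10)` would reproduce the handlers above.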
+34 -18
@@ -62,6 +62,28 @@ export function RequestLogTable() {
});
};
// Convert a Unix timestamp to the local-time datetime-local format
const timestampToLocalDatetime = (timestamp: number): string => {
const date = new Date(timestamp * 1000);
const year = date.getFullYear();
const month = String(date.getMonth() + 1).padStart(2, "0");
const day = String(date.getDate()).padStart(2, "0");
const hours = String(date.getHours()).padStart(2, "0");
const minutes = String(date.getMinutes()).padStart(2, "0");
return `${year}-${month}-${day}T${hours}:${minutes}`;
};
// Convert a datetime-local value to a Unix timestamp
const localDatetimeToTimestamp = (datetime: string): number | undefined => {
if (!datetime) return undefined;
// Validate the format is complete (YYYY-MM-DDTHH:mm)
if (datetime.length < 16) return undefined;
const timestamp = new Date(datetime).getTime();
// Validate it is a real date
if (isNaN(timestamp)) return undefined;
return Math.floor(timestamp / 1000);
};
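The previous `toISOString().slice(0, 16)` rendered the timestamp in UTC, shifting the displayed time by the local offset; these helpers keep both directions in local time. A standalone sketch of the round trip (names mirror the component; reimplemented here for illustration, assuming a fixed local timezone at the chosen instant):

```typescript
function timestampToLocalDatetime(timestamp: number): string {
  const d = new Date(timestamp * 1000);
  const pad = (n: number) => String(n).padStart(2, "0");
  return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}T${pad(d.getHours())}:${pad(d.getMinutes())}`;
}

function localDatetimeToTimestamp(datetime: string): number | undefined {
  if (!datetime || datetime.length < 16) return undefined; // incomplete YYYY-MM-DDTHH:mm
  const ms = new Date(datetime).getTime(); // parsed as local time
  return Number.isNaN(ms) ? undefined : Math.floor(ms / 1000);
}

// Converting to a local string and back truncates to the minute, regardless of offset.
const roundTrip = (ts: number) => localDatetimeToTimestamp(timestampToLocalDatetime(ts));
```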
const dateLocale =
i18n.language === "zh"
? "zh-CN"
@@ -153,19 +175,16 @@ export function RequestLogTable() {
className="h-8 w-[200px] bg-background"
value={
tempFilters.startDate
? new Date(tempFilters.startDate * 1000)
.toISOString()
.slice(0, 16)
? timestampToLocalDatetime(tempFilters.startDate)
: ""
}
onChange={(e) =>
onChange={(e) => {
const timestamp = localDatetimeToTimestamp(e.target.value);
setTempFilters({
...tempFilters,
startDate: e.target.value
? Math.floor(new Date(e.target.value).getTime() / 1000)
: undefined,
})
}
startDate: timestamp,
});
}}
/>
<span>-</span>
<Input
@@ -173,19 +192,16 @@ export function RequestLogTable() {
className="h-8 w-[200px] bg-background"
value={
tempFilters.endDate
? new Date(tempFilters.endDate * 1000)
.toISOString()
.slice(0, 16)
? timestampToLocalDatetime(tempFilters.endDate)
: ""
}
onChange={(e) =>
onChange={(e) => {
const timestamp = localDatetimeToTimestamp(e.target.value);
setTempFilters({
...tempFilters,
endDate: e.target.value
? Math.floor(new Date(e.target.value).getTime() / 1000)
: undefined,
})
}
endDate: timestamp,
});
}}
/>
</div>
+1 -7
@@ -12,13 +12,7 @@ interface UsageSummaryCardsProps {
export function UsageSummaryCards({ days }: UsageSummaryCardsProps) {
const { t } = useTranslation();
const { startDate, endDate } = useMemo(() => {
const end = Math.floor(Date.now() / 1000);
const start = end - days * 24 * 60 * 60;
return { startDate: start, endDate: end };
}, [days]);
const { data: summary, isLoading } = useUsageSummary(startDate, endDate);
const { data: summary, isLoading } = useUsageSummary(days);
const stats = useMemo(() => {
const totalRequests = summary?.totalRequests ?? 0;
+43 -19
@@ -41,7 +41,12 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
return {
rawDate: stat.date,
label: isToday
? pointDate.toLocaleTimeString(dateLocale, { hour: "2-digit" })
? pointDate.toLocaleString(dateLocale, {
month: "2-digit",
day: "2-digit",
hour: "2-digit",
minute: "2-digit",
})
: pointDate.toLocaleDateString(dateLocale, {
month: "2-digit",
day: "2-digit",
@@ -49,28 +54,13 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
hour: pointDate.getHours(),
inputTokens: stat.totalInputTokens,
outputTokens: stat.totalOutputTokens,
cacheCreationTokens: stat.totalCacheCreationTokens,
cacheReadTokens: stat.totalCacheReadTokens,
cost: parseFloat(stat.totalCost),
};
}) || [];
const hourlyData = (() => {
if (!isToday) return chartData;
const map = new Map<number, (typeof chartData)[number]>();
chartData.forEach((point) => {
map.set(point.hour ?? 0, point);
});
return Array.from({ length: 24 }, (_, hour) => {
const bucket = map.get(hour);
return {
label: `${hour.toString().padStart(2, "0")}:00`,
inputTokens: bucket?.inputTokens ?? 0,
outputTokens: bucket?.outputTokens ?? 0,
cost: bucket?.cost ?? 0,
};
});
})();
const displayData = isToday ? hourlyData : chartData;
const displayData = chartData;
const CustomTooltip = ({ active, payload, label }: any) => {
if (active && payload && payload.length) {
@@ -131,6 +121,20 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
<stop offset="5%" stopColor="#22c55e" stopOpacity={0.2} />
<stop offset="95%" stopColor="#22c55e" stopOpacity={0} />
</linearGradient>
<linearGradient
id="colorCacheCreation"
x1="0"
y1="0"
x2="0"
y2="1"
>
<stop offset="5%" stopColor="#f97316" stopOpacity={0.2} />
<stop offset="95%" stopColor="#f97316" stopOpacity={0} />
</linearGradient>
<linearGradient id="colorCacheRead" x1="0" y1="0" x2="0" y2="1">
<stop offset="5%" stopColor="#a855f7" stopOpacity={0.2} />
<stop offset="95%" stopColor="#a855f7" stopOpacity={0} />
</linearGradient>
</defs>
<CartesianGrid
strokeDasharray="3 3"
@@ -182,6 +186,26 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) {
fill="url(#colorOutput)"
strokeWidth={2}
/>
<Area
yAxisId="tokens"
type="monotone"
dataKey="cacheCreationTokens"
name={t("usage.cacheCreationTokens", "缓存创建")}
stroke="#f97316"
fillOpacity={1}
fill="url(#colorCacheCreation)"
strokeWidth={2}
/>
<Area
yAxisId="tokens"
type="monotone"
dataKey="cacheReadTokens"
name={t("usage.cacheReadTokens", "缓存命中")}
stroke="#a855f7"
fillOpacity={1}
fill="url(#colorCacheRead)"
strokeWidth={2}
/>
<Area
yAxisId="cost"
type="monotone"
+7 -7
@@ -423,7 +423,7 @@
"cost": "Cost",
"perMillion": "(per million)",
"trends": "Usage Trends",
"rangeToday": "Today (hourly)",
"rangeToday": "Last 24 hours (hourly)",
"rangeLast7Days": "Last 7 days",
"rangeLast30Days": "Last 30 days",
"totalTokens": "Total Tokens",
@@ -436,8 +436,8 @@
"billingModel": "Billing Model",
"inputTokens": "Input",
"outputTokens": "Output",
"cacheReadTokens": "Cache Read",
"cacheCreationTokens": "Cache Write",
"cacheReadTokens": "Cache Hit",
"cacheCreationTokens": "Cache Creation",
"timingInfo": "Duration/TTFT",
"status": "Status",
"noData": "No data",
@@ -453,8 +453,8 @@
"displayName": "Display Name",
"inputCost": "Input Cost",
"outputCost": "Output Cost",
"cacheReadCost": "Cache Read",
"cacheWriteCost": "Cache Write",
"cacheReadCost": "Cache Hit",
"cacheWriteCost": "Cache Creation",
"deleteConfirmTitle": "Confirm Delete",
"deleteConfirmDesc": "Are you sure you want to delete this model pricing? This action cannot be undone.",
"queryFailed": "Query failed",
@@ -481,8 +481,8 @@
"timeRange": "Time Range",
"input": "Input",
"output": "Output",
"cacheWrite": "Write",
"cacheRead": "Read"
"cacheWrite": "Creation",
"cacheRead": "Hit"
},
"usageScript": {
"title": "Configure Usage Query",
+7 -7
@@ -423,7 +423,7 @@
"cost": "コスト",
"perMillion": "(100万あたり)",
"trends": "利用トレンド",
"rangeToday": "今日 (時間別)",
"rangeToday": "直近24時間 (時間別)",
"rangeLast7Days": "過去7日間",
"rangeLast30Days": "過去30日間",
"totalTokens": "総トークン数",
@@ -436,8 +436,8 @@
"billingModel": "課金モデル",
"inputTokens": "入力",
"outputTokens": "出力",
"cacheReadTokens": "キャッシュ読取",
"cacheCreationTokens": "キャッシュ書込",
"cacheReadTokens": "キャッシュヒット",
"cacheCreationTokens": "キャッシュ作成",
"timingInfo": "応答時間/TTFT",
"status": "ステータス",
"noData": "データなし",
@@ -453,8 +453,8 @@
"displayName": "表示名",
"inputCost": "入力コスト",
"outputCost": "出力コスト",
"cacheReadCost": "キャッシュ読取",
"cacheWriteCost": "キャッシュ書込",
"cacheReadCost": "キャッシュヒット",
"cacheWriteCost": "キャッシュ作成",
"deleteConfirmTitle": "削除の確認",
"deleteConfirmDesc": "このモデル料金を削除しますか?この操作は元に戻せません。",
"queryFailed": "照会に失敗しました",
@@ -481,8 +481,8 @@
"timeRange": "期間",
"input": "Input",
"output": "Output",
"cacheWrite": "Write",
"cacheRead": "Read"
"cacheWrite": "作成",
"cacheRead": "ヒット"
},
"usageScript": {
"title": "利用状況を設定",
+7 -7
@@ -423,7 +423,7 @@
"cost": "成本",
"perMillion": "(每百万)",
"trends": "使用趋势",
"rangeToday": "今天 (按小时)",
"rangeToday": "过去 24 小时 (按小时)",
"rangeLast7Days": "过去 7 天",
"rangeLast30Days": "过去 30 天",
"totalTokens": "总 Token 数",
@@ -436,8 +436,8 @@
"billingModel": "计费模型",
"inputTokens": "输入",
"outputTokens": "输出",
"cacheReadTokens": "缓存读取",
"cacheCreationTokens": "缓存写入",
"cacheReadTokens": "缓存命中",
"cacheCreationTokens": "缓存创建",
"timingInfo": "用时/首字",
"status": "状态",
"noData": "暂无数据",
@@ -453,8 +453,8 @@
"displayName": "显示名称",
"inputCost": "输入成本",
"outputCost": "输出成本",
"cacheReadCost": "缓存读取",
"cacheWriteCost": "缓存写入",
"cacheReadCost": "缓存命中",
"cacheWriteCost": "缓存创建",
"deleteConfirmTitle": "确认删除",
"deleteConfirmDesc": "确定要删除此模型定价配置吗?此操作无法撤销。",
"queryFailed": "查询失败",
@@ -481,8 +481,8 @@
"timeRange": "时间范围",
"input": "Input",
"output": "Output",
"cacheWrite": "Write",
"cacheRead": "Read"
"cacheWrite": "创建",
"cacheRead": "命中"
},
"usageScript": {
"title": "配置用量查询",
+5 -2
@@ -49,8 +49,11 @@ export const usageApi = {
return invoke("get_usage_summary", { startDate, endDate });
},
getUsageTrends: async (days: number): Promise<DailyStats[]> => {
return invoke("get_usage_trends", { days });
getUsageTrends: async (
startDate?: number,
endDate?: number,
): Promise<DailyStats[]> => {
return invoke("get_usage_trends", { startDate, endDate });
},
getProviderStats: async (): Promise<ProviderStats[]> => {
+27 -6
@@ -5,8 +5,7 @@ import type { LogFilters } from "@/types/usage";
// Query keys
export const usageKeys = {
all: ["usage"] as const,
summary: (startDate?: number, endDate?: number) =>
[...usageKeys.all, "summary", startDate, endDate] as const,
summary: (days: number) => [...usageKeys.all, "summary", days] as const,
trends: (days: number) => [...usageKeys.all, "trends", days] as const,
providerStats: () => [...usageKeys.all, "provider-stats"] as const,
modelStats: () => [...usageKeys.all, "model-stats"] as const,
@@ -19,18 +18,34 @@ export const usageKeys = {
[...usageKeys.all, "limits", providerId, appType] as const,
};
const getWindow = (days: number) => {
const endDate = Math.floor(Date.now() / 1000);
const startDate = endDate - days * 24 * 60 * 60;
return { startDate, endDate };
};
// Hooks
export function useUsageSummary(startDate?: number, endDate?: number) {
export function useUsageSummary(days: number) {
return useQuery({
queryKey: usageKeys.summary(startDate, endDate),
queryFn: () => usageApi.getUsageSummary(startDate, endDate),
queryKey: usageKeys.summary(days),
queryFn: () => {
const { startDate, endDate } = getWindow(days);
return usageApi.getUsageSummary(startDate, endDate);
},
refetchInterval: 30000, // auto-refresh every 30 seconds
refetchIntervalInBackground: false, // don't refetch in the background
});
}
export function useUsageTrends(days: number) {
return useQuery({
queryKey: usageKeys.trends(days),
queryFn: () => usageApi.getUsageTrends(days),
queryFn: () => {
const { startDate, endDate } = getWindow(days);
return usageApi.getUsageTrends(startDate, endDate);
},
refetchInterval: 30000, // auto-refresh every 30 seconds
refetchIntervalInBackground: false,
});
}
@@ -38,6 +53,8 @@ export function useProviderStats() {
return useQuery({
queryKey: usageKeys.providerStats(),
queryFn: usageApi.getProviderStats,
refetchInterval: 30000, // auto-refresh every 30 seconds
refetchIntervalInBackground: false,
});
}
@@ -45,6 +62,8 @@ export function useModelStats() {
return useQuery({
queryKey: usageKeys.modelStats(),
queryFn: usageApi.getModelStats,
refetchInterval: 30000, // auto-refresh every 30 seconds
refetchIntervalInBackground: false,
});
}
@@ -56,6 +75,8 @@ export function useRequestLogs(
return useQuery({
queryKey: usageKeys.logs(filters, page, pageSize),
queryFn: () => usageApi.getRequestLogs(filters, page, pageSize),
refetchInterval: 30000, // auto-refresh every 30 seconds
refetchIntervalInBackground: false,
});
}