From 5376ea042bae882457e27d370a7ffaedef068b57 Mon Sep 17 00:00:00 2001 From: Dex Miller Date: Wed, 31 Dec 2025 22:57:00 +0800 Subject: [PATCH] Feat/usage improvements (#508) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * i18n: update cache terminology across all languages - Change 'Cache Read' to 'Cache Hit' in all languages - Change 'Cache Write' to 'Cache Creation' in all languages - Update zh: 缓存读取 → 缓存命中, 缓存写入 → 缓存创建 - Update en: Cache Read → Cache Hit, Cache Write → Cache Creation - Update ja: キャッシュ読取 → キャッシュヒット, キャッシュ書込 → キャッシュ作成 Affected keys: cacheReadTokens, cacheCreationTokens, cacheReadCost, cacheWriteCost, cacheRead, cacheWrite * feat(usage): add cache metrics to trend chart - Add cache creation tokens visualization (orange line) - Add cache hit tokens visualization (purple line) - Add gradient definitions for new cache metrics - Include cache data in hourly aggregation - Display cache metrics alongside input/output tokens This provides better visibility into cache usage patterns over time. * fix(usage): fix timezone handling in datetime picker - Add timestampToLocalDatetime() to convert Unix timestamp to local datetime - Add localDatetimeToTimestamp() with validation for incomplete input - Fix issue where typing hours/minutes would jump to previous day - Validate datetime format completeness before conversion - Use local timezone instead of UTC for datetime-local input This resolves the issue where users couldn't fine-tune time selection and the input would jump unexpectedly when editing hours or minutes. * feat(usage): add auto-refresh for usage statistics - Add 30-second auto-refresh interval for all usage queries - Disable background refresh to save resources - Apply to: summary, trends, provider stats, model stats, request logs - Queries automatically update when tab is active - Pause refresh when user switches to another tab This keeps usage data fresh without manual refresh. * fix(proxy): improve usage logging and cache token parsing - Log requests even when usage parsing fails (with default values) - Add detailed debug logging for usage metrics - Support cache_read_input_tokens field in Codex responses - Fallback to input_tokens_details.cached_tokens if needed - Add test case for cached_tokens in input_tokens_details - Ensure all requests are tracked in database for analytics This fixes missing request logs when API responses lack usage data and improves cache token detection across different response formats. * style(rust): use inline format args in format! macros - Replace format!("...", var) with format!("...{var}") - Update universal provider ID formatting - Update error message formatting - Update config.toml generation in Codex provider Fixes clippy::uninlined_format_args warnings. * feat(proxy): enhance provider router logging - Add debug logs for failover queue provider count - Log circuit breaker state for each provider check - Add logs for missing current provider scenarios - Log when no current provider is configured - Use inline format args for better readability This improves debugging of provider selection and failover behavior. * feat(database): update model pricing data - Update Claude models to full version format (e.g. claude-opus-4-5-20251101) - Add GPT-5.2 series model pricing (10 models) - Add GPT-5.1 series model pricing (10 models) - Add GPT-5 series model pricing (12 models) - Add Gemini 3 series model pricing (2 models) - Update Gemini 2.5 series model ID format (use dot separator) - Unify display names by removing thinking level suffixes * fix(usage): correct Gemini output token calculation Fix Gemini API output token parsing to use totalTokenCount - promptTokenCount instead of candidatesTokenCount alone. This ensures thoughtsTokenCount is included in output statistics. - Update from_gemini_response to calculate output from total - input - Update from_gemini_stream_chunks with same logic for consistency - Fix from_codex_stream_events to use adjusted token calculation - Add test case for responses with thoughtsTokenCount - Update existing tests to match new calculation logic * fix(usage): correct cache token billing and add Codex format auto-detection - Avoid double-billing cache tokens by subtracting from input before calculation - Add smart Codex parser that auto-detects OpenAI vs Codex API format - Extract model name from Codex responses for accurate tracking * fix(proxy): improve takeover detection with live config check - Add live config takeover detection for hot-switch decision - Rebuild takeover when backup is missing or placeholder remains - Make detect_takeover_in_live_config_for_app public - Fix is_takeover_active to use actual takeover status * refactor(usage): simplify model pricing lookup by removing suffix fallback Replace complex suffix-stripping fallback with direct prefix/suffix cleanup. Model IDs are now cleaned by removing vendor prefix (before /) and colon suffix (after :), then matched exactly against pricing table. * feat(database): add Chinese AI model pricing data Add pricing for domestic AI models (CNY/1M tokens): - Doubao-Seed-Code (ByteDance) - DeepSeek V3/V3.1/V3.2 - Kimi K2/K2-Thinking/K2-Turbo (Moonshot) - MiniMax M2/M2.1/M2.1-Lightning - GLM-4.6/4.7 (Zhipu) - Mimo V2 Flash (Xiaomi) Also fix test case to use correct model ID and remove invalid currency column. * refactor(proxy): improve header forwarding with blacklist approach Change from whitelist to blacklist mode for request header forwarding. Only skip headers that will be overridden (auth, host, content-length). This preserves client's original headers and improves compatibility. * fix(proxy): bypass timeout and retry configs when failover is disabled When auto_failover_enabled is false, timeout and retry configurations should not affect normal request flow. This change ensures: - create_forwarder: passes 0 for all timeout/retry params when failover is disabled, effectively bypassing these checks - streaming_timeout_config: returns 0 for both first_byte_timeout and idle_timeout when failover is disabled This prevents unnecessary timeout errors and retry attempts when users have explicitly disabled the failover feature. * fix(proxy): handle zero value input in failover config fields * refactor(proxy): remove retry logic and add enabled check for failover * refactor(proxy): distinguish circuit-open from no-provider errors * Align usage stats to sliding windows * feat(proxy): add body and header filtering for upstream requests * feat(proxy): enable transparent passthrough for headers - Passthrough anthropic-beta header as-is from client - Passthrough anthropic-version header from client - Passthrough client IP headers (x-forwarded-for, x-real-ip) by default - Filter private params (underscore-prefixed fields) from request body - No database changes required * feat(proxy): extract session ID from client requests for logging - Add SessionIdExtractor to parse session ID from Claude/Codex requests - Support extraction from metadata.user_id, headers, previous_response_id - Pass session_id through RequestContext to usage logger - Enable request correlation by session in proxy_request_logs --- src-tauri/src/commands/proxy.rs | 13 +- src-tauri/src/commands/usage.rs | 5 +- src-tauri/src/database/schema.rs | 258 +++++++++-- src-tauri/src/error.rs | 4 + src-tauri/src/provider.rs | 9 +- src-tauri/src/proxy/body_filter.rs | 303 +++++++++++++ src-tauri/src/proxy/error.rs | 12 + src-tauri/src/proxy/error_mapper.rs | 8 + src-tauri/src/proxy/failover_switch.rs | 15 + src-tauri/src/proxy/forwarder.rs | 284 ++++++++---- src-tauri/src/proxy/handler_config.rs | 18 +- src-tauri/src/proxy/handler_context.rs | 84 +++- src-tauri/src/proxy/handlers.rs | 15 +- src-tauri/src/proxy/mod.rs | 5 +- src-tauri/src/proxy/provider_router.rs | 40 +- src-tauri/src/proxy/response_processor.rs | 59 ++- src-tauri/src/proxy/session.rs | 269 +++++++++++ src-tauri/src/proxy/usage/calculator.rs | 18 +- src-tauri/src/proxy/usage/parser.rs | 299 ++++++++++++- src-tauri/src/services/provider/mod.rs | 19 +- src-tauri/src/services/proxy.rs | 27 +- src-tauri/src/services/usage_stats.rs | 421 +++++++----------- .../proxy/AutoFailoverConfigPanel.tsx | 82 ++-- src/components/usage/RequestLogTable.tsx | 52 ++- src/components/usage/UsageSummaryCards.tsx | 8 +- src/components/usage/UsageTrendChart.tsx | 62 ++- src/i18n/locales/en.json | 14 +- src/i18n/locales/ja.json | 14 +- src/i18n/locales/zh.json | 14 +- src/lib/api/usage.ts | 7 +- src/lib/query/usage.ts | 33 +- 31 files changed, 1888 insertions(+), 583 deletions(-) create mode 100644 src-tauri/src/proxy/body_filter.rs diff --git a/src-tauri/src/commands/proxy.rs b/src-tauri/src/commands/proxy.rs index 37be402cf..56f587842 100644 --- a/src-tauri/src/commands/proxy.rs +++ b/src-tauri/src/commands/proxy.rs @@ -184,17 +184,16 @@ pub async fn reset_circuit_breaker( .await?; // 3. 检查是否应该切回优先级更高的供应商(从 proxy_config 表读取) - let auto_failover_enabled = match db.get_proxy_config_for_app(&app_type).await { - Ok(config) => config.auto_failover_enabled, + // 只有当该应用已被代理接管(enabled=true)且开启了自动故障转移时才执行 + let (app_enabled, auto_failover_enabled) = match db.get_proxy_config_for_app(&app_type).await { + Ok(config) => (config.enabled, config.auto_failover_enabled), Err(e) => { - log::error!( - "[{app_type}] Failed to read proxy_config for auto_failover_enabled: {e}, defaulting to disabled" - ); - false + log::error!("[{app_type}] Failed to read proxy_config: {e}, defaulting to disabled"); + (false, false) } }; - if auto_failover_enabled && state.proxy_service.is_running().await { + if app_enabled && auto_failover_enabled && state.proxy_service.is_running().await { // 获取当前供应商 ID let current_id = db .get_current_provider(&app_type) diff --git a/src-tauri/src/commands/usage.rs b/src-tauri/src/commands/usage.rs index e872feb5d..4e527ebbe 100644 --- a/src-tauri/src/commands/usage.rs +++ b/src-tauri/src/commands/usage.rs @@ -19,9 +19,10 @@ pub fn get_usage_summary( #[tauri::command] pub fn get_usage_trends( state: State<'_, AppState>, - days: u32, + start_date: Option, + end_date: Option, ) -> Result, AppError> { - state.db.get_daily_trends(days) + state.db.get_daily_trends(start_date, end_date) } /// 获取 Provider 统计 diff --git a/src-tauri/src/database/schema.rs b/src-tauri/src/database/schema.rs index 52187ae47..7071b26c0 100644 --- a/src-tauri/src/database/schema.rs +++ b/src-tauri/src/database/schema.rs @@ -765,9 +765,9 @@ impl Database { /// 注意: model_id 使用短横线格式(如 claude-haiku-4-5),与 API 返回的模型名称标准化后一致 fn seed_model_pricing(conn: &Connection) -> Result<(), AppError> { let pricing_data = [ - // Claude 4.5 系列 + // Claude 4.5 系列 (Latest Models) ( - "claude-opus-4-5", + "claude-opus-4-5-20251101", "Claude Opus 4.5", "5", "25", @@ -775,7 +775,7 @@ impl Database { "6.25", ), ( - "claude-sonnet-4-5", + "claude-sonnet-4-5-20250929", "Claude Sonnet 4.5", "3", "15", @@ -783,16 +783,24 @@ impl Database { "3.75", ), ( - "claude-haiku-4-5", + "claude-haiku-4-5-20251001", "Claude Haiku 4.5", "1", "5", "0.10", "1.25", ), - // Claude 4.1 系列 + // Claude 4 系列 (Legacy Models) ( - "claude-opus-4-1", + "claude-opus-4-20250514", + "Claude Opus 4", + "15", + "75", + "1.50", + "18.75", + ), + ( + "claude-opus-4-1-20250805", "Claude Opus 4.1", "15", "75", @@ -800,17 +808,8 @@ impl Database { "18.75", ), ( - "claude-sonnet-4-1", - "Claude Sonnet 4.1", - "3", - "15", - "0.30", - "3.75", - ), - // Claude 3.7 系列 - ( - "claude-sonnet-3-7", - "Claude Sonnet 3.7", + "claude-sonnet-4-20250514", + "Claude Sonnet 4", "3", "15", "0.30", @@ -818,38 +817,167 @@ impl Database { ), // Claude 3.5 系列 ( - "claude-sonnet-3-5", - "Claude Sonnet 3.5", - "3", - "15", - "0.30", - "3.75", - ), - ( - "claude-haiku-3-5", - "Claude Haiku 3.5", + "claude-3-5-haiku-20241022", + "Claude 3.5 Haiku", "0.80", "4", "0.08", "1", ), - // GPT-5 系列(model_id 使用短横线格式) + ( + "claude-3-5-sonnet-20241022", + "Claude 3.5 Sonnet", + "3", + "15", + "0.30", + "3.75", + ), + // GPT-5.2 系列 + ("gpt-5.2", "GPT-5.2", "1.75", "14", "0.175", "0"), + ("gpt-5.2-low", "GPT-5.2", "1.75", "14", "0.175", "0"), + ("gpt-5.2-medium", "GPT-5.2", "1.75", "14", "0.175", "0"), + ("gpt-5.2-high", "GPT-5.2", "1.75", "14", "0.175", "0"), + ("gpt-5.2-xhigh", "GPT-5.2", "1.75", "14", "0.175", "0"), + ("gpt-5.2-codex", "GPT-5.2 Codex", "1.75", "14", "0.175", "0"), + ( + "gpt-5.2-codex-low", + "GPT-5.2 Codex", + "1.75", + "14", + "0.175", + "0", + ), + ( + "gpt-5.2-codex-medium", + "GPT-5.2 Codex", + "1.75", + "14", + "0.175", + "0", + ), + ( + "gpt-5.2-codex-high", + "GPT-5.2 Codex", + "1.75", + "14", + "0.175", + "0", + ), + ( + "gpt-5.2-codex-xhigh", + "GPT-5.2 Codex", + "1.75", + "14", + "0.175", + "0", + ), + // GPT-5.1 系列 + ("gpt-5.1", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5.1-low", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5.1-medium", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5.1-high", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5.1-minimal", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5.1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"), + ( + "gpt-5.1-codex-mini", + "GPT-5.1 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5.1-codex-max", + "GPT-5.1 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5.1-codex-max-high", + "GPT-5.1 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5.1-codex-max-xhigh", + "GPT-5.1 Codex", + "1.25", + "10", + "0.125", + "0", + ), + // GPT-5 系列 ("gpt-5", "GPT-5", "1.25", "10", "0.125", "0"), - ("gpt-5-1", "GPT-5.1", "1.25", "10", "0.125", "0"), + ("gpt-5-low", "GPT-5", "1.25", "10", "0.125", "0"), + ("gpt-5-medium", "GPT-5", "1.25", "10", "0.125", "0"), + ("gpt-5-high", "GPT-5", "1.25", "10", "0.125", "0"), + ("gpt-5-minimal", "GPT-5", "1.25", "10", "0.125", "0"), ("gpt-5-codex", "GPT-5 Codex", "1.25", "10", "0.125", "0"), - ("gpt-5-1-codex", "GPT-5.1 Codex", "1.25", "10", "0.125", "0"), + ("gpt-5-codex-low", "GPT-5 Codex", "1.25", "10", "0.125", "0"), + ( + "gpt-5-codex-medium", + "GPT-5 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5-codex-high", + "GPT-5 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5-codex-mini", + "GPT-5 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5-codex-mini-medium", + "GPT-5 Codex", + "1.25", + "10", + "0.125", + "0", + ), + ( + "gpt-5-codex-mini-high", + "GPT-5 Codex", + "1.25", + "10", + "0.125", + "0", + ), // Gemini 3 系列 ( "gemini-3-pro-preview", "Gemini 3 Pro Preview", "2", "12", - "0", + "0.2", "0", ), - // Gemini 2.5 系列(model_id 使用短横线格式) ( - "gemini-2-5-pro", + "gemini-3-flash-preview", + "Gemini 3 Flash Preview", + "0.5", + "3", + "0.05", + "0", + ), + // Gemini 2.5 系列 + ( + "gemini-2.5-pro", "Gemini 2.5 Pro", "1.25", "10", @@ -857,13 +985,75 @@ impl Database { "0", ), ( - "gemini-2-5-flash", + "gemini-2.5-flash", "Gemini 2.5 Flash", "0.3", "2.5", "0.03", "0", ), + // ====== 国产模型 (CNY/1M tokens) ====== + // Doubao (字节跳动) + ( + "doubao-seed-code", + "Doubao Seed Code", + "1.20", + "8.00", + "0.24", + "0", + ), + // DeepSeek 系列 + ( + "deepseek-v3.2", + "DeepSeek V3.2", + "2.00", + "3.00", + "0.40", + "0", + ), + ( + "deepseek-v3.1", + "DeepSeek V3.1", + "4.00", + "12.00", + "0.80", + "0", + ), + ("deepseek-v3", "DeepSeek V3", "2.00", "8.00", "0.40", "0"), + // Kimi (月之暗面) + ( + "kimi-k2-thinking", + "Kimi K2 Thinking", + "4.00", + "16.00", + "1.00", + "0", + ), + ("kimi-k2-0905", "Kimi K2", "4.00", "16.00", "1.00", "0"), + ( + "kimi-k2-turbo", + "Kimi K2 Turbo", + "8.00", + "58.00", + "1.00", + "0", + ), + // MiniMax 系列 + ("minimax-m2.1", "MiniMax M2.1", "2.10", "8.40", "0.21", "0"), + ( + "minimax-m2.1-lightning", + "MiniMax M2.1 Lightning", + "2.10", + "16.80", + "0.21", + "0", + ), + ("minimax-m2", "MiniMax M2", "2.10", "8.40", "0.21", "0"), + // GLM (智谱) + ("glm-4.7", "GLM-4.7", "2.00", "8.00", "0.40", "0"), + ("glm-4.6", "GLM-4.6", "2.00", "8.00", "0.40", "0"), + // Mimo (小米) + ("mimo-v2-flash", "Mimo V2 Flash", "0", "0", "0", "0"), ]; for (model_id, display_name, input, output, cache_read, cache_creation) in pricing_data { diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index e9eafd35b..9d8c622f1 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -52,6 +52,10 @@ pub enum AppError { }, #[error("数据库错误: {0}")] Database(String), + #[error("所有供应商已熔断,无可用渠道")] + AllProvidersCircuitOpen, + #[error("未配置供应商")] + NoProvidersConfigured, } impl AppError { diff --git a/src-tauri/src/provider.rs b/src-tauri/src/provider.rs index 3a4d1fa1f..13eebce51 100644 --- a/src-tauri/src/provider.rs +++ b/src-tauri/src/provider.rs @@ -386,16 +386,15 @@ impl UniversalProvider { // 生成 Codex 的 config.toml 内容 let config_toml = format!( r#"model_provider = "newapi" -model = "{}" -model_reasoning_effort = "{}" +model = "{model}" +model_reasoning_effort = "{reasoning_effort}" disable_response_storage = true [model_providers.newapi] name = "NewAPI" -base_url = "{}" +base_url = "{codex_base_url}" wire_api = "responses" -requires_openai_auth = true"#, - model, reasoning_effort, codex_base_url +requires_openai_auth = true"# ); let settings_config = serde_json::json!({ diff --git a/src-tauri/src/proxy/body_filter.rs b/src-tauri/src/proxy/body_filter.rs new file mode 100644 index 000000000..fc12ef64c --- /dev/null +++ b/src-tauri/src/proxy/body_filter.rs @@ -0,0 +1,303 @@ +//! 请求体过滤模块 +//! +//! 过滤不应透传到上游的私有参数,防止内部信息泄露。 +//! +//! ## 过滤规则 +//! - 以 `_` 开头的字段被视为私有参数,会被递归过滤 +//! - 支持白名单机制,允许透传特定的 `_` 前缀字段 +//! - 支持嵌套对象和数组的深度过滤 +//! +//! ## 使用场景 +//! - `_internal_id`: 内部追踪 ID +//! - `_debug_mode`: 调试标记 +//! - `_session_token`: 会话令牌 +//! - `_client_version`: 客户端版本 + +use serde_json::Value; +use std::collections::HashSet; + +/// 过滤私有参数(以 `_` 开头的字段) +/// +/// 递归遍历 JSON 结构,移除所有以下划线开头的字段。 +/// +/// # Arguments +/// * `body` - 原始请求体 +/// +/// # Returns +/// 过滤后的请求体 +/// +/// # Example +/// ```ignore +/// let input = json!({ +/// "model": "claude-3", +/// "_internal_id": "abc123", +/// "messages": [{"role": "user", "content": "hello", "_token": "secret"}] +/// }); +/// let output = filter_private_params(input); +/// // output 中不包含 _internal_id 和 _token +/// ``` +#[cfg(test)] +pub fn filter_private_params(body: Value) -> Value { + filter_private_params_with_whitelist(body, &[]) +} + +/// 过滤私有参数(支持白名单) +/// +/// 递归遍历 JSON 结构,移除所有以下划线开头的字段, +/// 但保留白名单中指定的字段。 +/// +/// # Arguments +/// * `body` - 原始请求体 +/// * `whitelist` - 白名单字段列表(不过滤这些字段) +/// +/// # Returns +/// 过滤后的请求体 +/// +/// # Example +/// ```ignore +/// let input = json!({ +/// "model": "claude-3", +/// "_metadata": {"key": "value"}, // 白名单中,保留 +/// "_internal_id": "abc123" // 不在白名单中,过滤 +/// }); +/// let output = filter_private_params_with_whitelist(input, &["_metadata"]); +/// // output 包含 _metadata,不包含 _internal_id +/// ``` +pub fn filter_private_params_with_whitelist(body: Value, whitelist: &[String]) -> Value { + let whitelist_set: HashSet<&str> = whitelist.iter().map(|s| s.as_str()).collect(); + filter_recursive_with_whitelist(body, &mut Vec::new(), &whitelist_set) +} + +/// 递归过滤实现 +#[cfg(test)] +fn filter_recursive(value: Value, removed_keys: &mut Vec) -> Value { + filter_recursive_with_whitelist(value, removed_keys, &HashSet::new()) +} + +/// 递归过滤实现(支持白名单) +fn filter_recursive_with_whitelist( + value: Value, + removed_keys: &mut Vec, + whitelist: &HashSet<&str>, +) -> Value { + match value { + Value::Object(map) => { + let filtered: serde_json::Map = map + .into_iter() + .filter_map(|(key, val)| { + // 以 _ 开头且不在白名单中的字段被过滤 + if key.starts_with('_') && !whitelist.contains(key.as_str()) { + removed_keys.push(key); + None + } else { + Some(( + key, + filter_recursive_with_whitelist(val, removed_keys, whitelist), + )) + } + }) + .collect(); + + // 仅在有过滤时记录日志(避免每次请求都打印) + if !removed_keys.is_empty() { + log::debug!("[BodyFilter] 过滤私有参数: {removed_keys:?}"); + removed_keys.clear(); + } + + Value::Object(filtered) + } + Value::Array(arr) => Value::Array( + arr.into_iter() + .map(|v| filter_recursive_with_whitelist(v, removed_keys, whitelist)) + .collect(), + ), + other => other, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_filter_top_level_private_params() { + let input = json!({ + "model": "claude-3", + "_internal_id": "abc123", + "_debug": true, + "max_tokens": 1024 + }); + + let output = filter_private_params(input); + + assert!(output.get("model").is_some()); + assert!(output.get("max_tokens").is_some()); + assert!(output.get("_internal_id").is_none()); + assert!(output.get("_debug").is_none()); + } + + #[test] + fn test_filter_nested_private_params() { + let input = json!({ + "model": "claude-3", + "messages": [ + { + "role": "user", + "content": "hello", + "_session_token": "secret" + } + ], + "metadata": { + "user_id": "user-1", + "_tracking_id": "track-1" + } + }); + + let output = filter_private_params(input); + + // 顶级字段保留 + assert!(output.get("model").is_some()); + assert!(output.get("messages").is_some()); + assert!(output.get("metadata").is_some()); + + // messages 数组中的私有参数被过滤 + let messages = output.get("messages").unwrap().as_array().unwrap(); + assert!(messages[0].get("role").is_some()); + assert!(messages[0].get("content").is_some()); + assert!(messages[0].get("_session_token").is_none()); + + // metadata 对象中的私有参数被过滤 + let metadata = output.get("metadata").unwrap(); + assert!(metadata.get("user_id").is_some()); + assert!(metadata.get("_tracking_id").is_none()); + } + + #[test] + fn test_filter_deeply_nested() { + let input = json!({ + "level1": { + "level2": { + "level3": { + "keep": "value", + "_remove": "secret" + } + } + } + }); + + let output = filter_private_params(input); + + let level3 = output + .get("level1") + .unwrap() + .get("level2") + .unwrap() + .get("level3") + .unwrap(); + + assert!(level3.get("keep").is_some()); + assert!(level3.get("_remove").is_none()); + } + + #[test] + fn test_filter_array_of_objects() { + let input = json!({ + "items": [ + {"id": 1, "_secret": "a"}, + {"id": 2, "_secret": "b"}, + {"id": 3, "_secret": "c"} + ] + }); + + let output = filter_private_params(input); + let items = output.get("items").unwrap().as_array().unwrap(); + + for item in items { + assert!(item.get("id").is_some()); + assert!(item.get("_secret").is_none()); + } + } + + #[test] + fn test_no_private_params() { + let input = json!({ + "model": "claude-3", + "messages": [{"role": "user", "content": "hello"}] + }); + + let output = filter_private_params(input.clone()); + + // 无私有参数时,输出应与输入相同 + assert_eq!(input, output); + } + + #[test] + fn test_empty_object() { + let input = json!({}); + let output = filter_private_params(input); + assert_eq!(output, json!({})); + } + + #[test] + fn test_primitive_values() { + // 原始值不应被修改 + assert_eq!(filter_private_params(json!(42)), json!(42)); + assert_eq!(filter_private_params(json!("string")), json!("string")); + assert_eq!(filter_private_params(json!(true)), json!(true)); + assert_eq!(filter_private_params(json!(null)), json!(null)); + } + + #[test] + fn test_whitelist_preserves_private_params() { + let input = json!({ + "model": "claude-3", + "_metadata": {"key": "value"}, + "_internal_id": "abc123", + "_stream_options": {"include_usage": true} + }); + + let whitelist = vec!["_metadata".to_string(), "_stream_options".to_string()]; + let output = filter_private_params_with_whitelist(input, &whitelist); + + // 白名单中的字段保留 + assert!(output.get("_metadata").is_some()); + assert!(output.get("_stream_options").is_some()); + // 不在白名单中的私有字段被过滤 + assert!(output.get("_internal_id").is_none()); + // 普通字段保留 + assert!(output.get("model").is_some()); + } + + #[test] + fn test_whitelist_nested() { + let input = json!({ + "data": { + "_allowed": "keep", + "_forbidden": "remove", + "normal": "value" + } + }); + + let whitelist = vec!["_allowed".to_string()]; + let output = filter_private_params_with_whitelist(input, &whitelist); + + let data = output.get("data").unwrap(); + assert!(data.get("_allowed").is_some()); + assert!(data.get("_forbidden").is_none()); + assert!(data.get("normal").is_some()); + } + + #[test] + fn test_empty_whitelist_same_as_default() { + let input = json!({ + "model": "claude-3", + "_internal_id": "abc123" + }); + + let output1 = filter_private_params(input.clone()); + let output2 = filter_private_params_with_whitelist(input, &[]); + + assert_eq!(output1, output2); + } +} diff --git a/src-tauri/src/proxy/error.rs b/src-tauri/src/proxy/error.rs index 316ffdc3f..1b0531a3a 100644 --- a/src-tauri/src/proxy/error.rs +++ b/src-tauri/src/proxy/error.rs @@ -23,6 +23,12 @@ pub enum ProxyError { #[error("无可用的Provider")] NoAvailableProvider, + #[error("所有供应商已熔断,无可用渠道")] + AllProvidersCircuitOpen, + + #[error("未配置供应商")] + NoProvidersConfigured, + #[allow(dead_code)] #[error("Provider不健康: {0}")] ProviderUnhealthy(String), @@ -111,6 +117,12 @@ impl IntoResponse for ProxyError { ProxyError::NoAvailableProvider => { (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) } + ProxyError::AllProvidersCircuitOpen => { + (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) + } + ProxyError::NoProvidersConfigured => { + (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) + } ProxyError::ProviderUnhealthy(_) => { (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) } diff --git a/src-tauri/src/proxy/error_mapper.rs b/src-tauri/src/proxy/error_mapper.rs index b8b44e9a3..f98c9608a 100644 --- a/src-tauri/src/proxy/error_mapper.rs +++ b/src-tauri/src/proxy/error_mapper.rs @@ -27,6 +27,12 @@ pub fn map_proxy_error_to_status(error: &ProxyError) -> u16 { // 无可用 Provider:503 Service Unavailable ProxyError::NoAvailableProvider => 503, + // 所有供应商已熔断:503 Service Unavailable + ProxyError::AllProvidersCircuitOpen => 503, + + // 未配置供应商:503 Service Unavailable + ProxyError::NoProvidersConfigured => 503, + // 重试耗尽:503 Service Unavailable ProxyError::MaxRetriesExceeded => 503, @@ -57,6 +63,8 @@ pub fn get_error_message(error: &ProxyError) -> String { ProxyError::Timeout(msg) => format!("请求超时: {msg}"), ProxyError::ForwardFailed(msg) => format!("转发失败: {msg}"), ProxyError::NoAvailableProvider => "无可用 Provider".to_string(), + ProxyError::AllProvidersCircuitOpen => "所有供应商已熔断,无可用渠道".to_string(), + ProxyError::NoProvidersConfigured => "未配置供应商".to_string(), ProxyError::MaxRetriesExceeded => "所有 Provider 都失败,重试耗尽".to_string(), ProxyError::ProviderUnhealthy(msg) => format!("Provider 不健康: {msg}"), ProxyError::DatabaseError(msg) => format!("数据库错误: {msg}"), diff --git a/src-tauri/src/proxy/failover_switch.rs b/src-tauri/src/proxy/failover_switch.rs index b1ac5b746..e0bb8a00d 100644 --- a/src-tauri/src/proxy/failover_switch.rs +++ b/src-tauri/src/proxy/failover_switch.rs @@ -81,6 +81,21 @@ impl FailoverSwitchManager { provider_id: &str, provider_name: &str, ) -> Result { + // 检查该应用是否已被代理接管(enabled=true) + // 只有被接管的应用才允许执行故障转移切换 + let app_enabled = match self.db.get_proxy_config_for_app(app_type).await { + Ok(config) => config.enabled, + Err(e) => { + log::warn!("[Failover] 无法读取 {app_type} 配置: {e},跳过切换"); + return Ok(false); + } + }; + + if !app_enabled { + log::info!("[Failover] {app_type} 未被代理接管(enabled=false),跳过切换"); + return Ok(false); + } + log::info!("[Failover] 开始切换供应商: {app_type} -> {provider_name} ({provider_id})"); // 1. 更新数据库 is_current diff --git a/src-tauri/src/proxy/forwarder.rs b/src-tauri/src/proxy/forwarder.rs index 334182f92..742ebdb42 100644 --- a/src-tauri/src/proxy/forwarder.rs +++ b/src-tauri/src/proxy/forwarder.rs @@ -1,8 +1,9 @@ //! 请求转发器 //! -//! 负责将请求转发到上游Provider,支持重试和故障转移 +//! 负责将请求转发到上游Provider,支持故障转移 use super::{ + body_filter::filter_private_params_with_whitelist, error::*, failover_switch::FailoverSwitchManager, provider_router::ProviderRouter, @@ -17,6 +18,71 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; +/// Headers 黑名单 - 不透传到上游的 Headers +/// +/// 参考 Claude Code Hub 设计,过滤以下类别: +/// 1. 认证类(会被覆盖) +/// 2. 连接类(由 HTTP 客户端管理) +/// 3. 代理转发类 +/// 4. CDN/云服务商特定头 +/// 5. 请求追踪类 +/// 6. 浏览器特定头(可能被上游检测) +/// +/// 注意:客户端 IP 类(x-forwarded-for, x-real-ip)默认透传 +const HEADER_BLACKLIST: &[&str] = &[ + // 认证类(会被覆盖) + "authorization", + "x-api-key", + // 连接类 + "host", + "content-length", + "connection", + "transfer-encoding", + // 编码类(会被覆盖为 identity) + "accept-encoding", + // 代理转发类(保留 x-forwarded-for 和 x-real-ip) + "x-forwarded-host", + "x-forwarded-port", + "x-forwarded-proto", + "forwarded", + // CDN/云服务商特定头 + "cf-connecting-ip", + "cf-ipcountry", + "cf-ray", + "cf-visitor", + "true-client-ip", + "fastly-client-ip", + "x-azure-clientip", + "x-azure-fdid", + "x-azure-ref", + "akamai-origin-hop", + "x-akamai-config-log-detail", + // 请求追踪类 + "x-request-id", + "x-correlation-id", + "x-trace-id", + "x-amzn-trace-id", + "x-b3-traceid", + "x-b3-spanid", + "x-b3-parentspanid", + "x-b3-sampled", + "traceparent", + "tracestate", + // 浏览器特定头(可能被上游检测为非 CLI 请求) + "sec-fetch-mode", + "sec-fetch-site", + "sec-fetch-dest", + "sec-ch-ua", + "sec-ch-ua-mobile", + "sec-ch-ua-platform", + "accept-language", + // anthropic-beta 单独处理,避免重复 + "anthropic-beta", + // 客户端 IP 单独处理(默认透传) + "x-forwarded-for", + "x-real-ip", +]; + pub struct ForwardResult { pub response: Response, pub provider: Provider, @@ -31,8 +97,6 @@ pub struct RequestForwarder { client: Client, /// 共享的 ProviderRouter(持有熔断器状态) router: Arc, - /// 单个 Provider 内的最大重试次数 - max_retries: u8, status: Arc>, current_providers: Arc>>, /// 故障转移切换管理器 @@ -48,7 +112,6 @@ impl RequestForwarder { pub fn new( router: Arc, non_streaming_timeout: u64, - max_retries: u8, status: Arc>, current_providers: Arc>>, failover_manager: Arc, @@ -77,7 +140,6 @@ impl RequestForwarder { Self { client, router, - max_retries, status, current_providers, failover_manager, @@ -86,59 +148,6 @@ impl RequestForwarder { } } - /// 对单个 Provider 执行请求(带重试) - /// - /// 在同一个 Provider 上最多重试 max_retries 次,使用指数退避 - async fn forward_with_provider_retry( - &self, - provider: &Provider, - endpoint: &str, - body: &Value, - headers: &axum::http::HeaderMap, - adapter: &dyn ProviderAdapter, - ) -> Result { - let mut last_error = None; - - for attempt in 0..=self.max_retries { - if attempt > 0 { - // 指数退避:100ms, 200ms, 400ms, ... - let delay_ms = 100 * 2u64.pow(attempt as u32 - 1); - log::info!( - "[{}] 重试第 {}/{} 次(等待 {}ms)", - adapter.name(), - attempt, - self.max_retries, - delay_ms - ); - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - } - - match self - .forward(provider, endpoint, body, headers, adapter) - .await - { - Ok(response) => return Ok(response), - Err(e) => { - // 只有“同一 Provider 内可重试”的错误才继续重试 - if !self.should_retry_same_provider(&e) { - return Err(e); - } - - log::debug!( - "[{}] Provider {} 第 {} 次请求失败: {}", - adapter.name(), - provider.name, - attempt + 1, - e - ); - last_error = Some(e); - } - } - } - - Err(last_error.unwrap_or(ProxyError::MaxRetriesExceeded)) - } - /// 转发请求(带故障转移) /// /// # Arguments @@ -224,9 +233,9 @@ impl RequestForwarder { let start = Instant::now(); - // 转发请求(带单 Provider 内重试) + // 转发请求(每个 Provider 只尝试一次,重试由客户端控制) match self - .forward_with_provider_retry(provider, endpoint, &body, &headers, adapter.as_ref()) + .forward(provider, endpoint, &body, &headers, adapter.as_ref()) .await { Ok(response) => { @@ -477,6 +486,28 @@ impl RequestForwarder { mapped_body }; + // 过滤私有参数(以 `_` 开头的字段),防止内部信息泄露到上游 + // 默认使用空白名单,过滤所有 _ 前缀字段 + let filtered_body = filter_private_params_with_whitelist(request_body, &[]); + + // ========== 请求体日志(截断显示) ========== + let body_str = serde_json::to_string_pretty(&filtered_body) + .unwrap_or_else(|_| filtered_body.to_string()); + let body_preview = if body_str.len() > 2000 { + format!( + "{}...\n[截断,总长度: {} 字符]", + &body_str[..2000], + body_str.len() + ) + } else { + body_str + }; + log::info!( + "[{}] ====== 最终请求体 ======\n{}", + adapter.name(), + body_preview + ); + log::info!( "[{}] 转发请求: {} -> {}", adapter.name(), @@ -487,28 +518,73 @@ impl RequestForwarder { // 构建请求 let mut request = self.client.post(&url); - // 只透传必要的 Headers(白名单模式) - let allowed_headers = [ - "accept", - "user-agent", - "x-request-id", - "x-stainless-arch", - "x-stainless-lang", - "x-stainless-os", - "x-stainless-package-version", - "x-stainless-runtime", - "x-stainless-runtime-version", - ]; + // ========== 详细 Headers 日志 ========== + log::info!("[{}] ====== 客户端原始 Headers ======", adapter.name()); + for (key, value) in headers { + log::info!( + "[{}] {}: {:?}", + adapter.name(), + key.as_str(), + value.to_str().unwrap_or("") + ); + } + + // 过滤黑名单 Headers,保护隐私并避免冲突 + let mut filtered_headers: Vec = Vec::new(); + let mut passed_headers: Vec<(String, String)> = Vec::new(); for (key, value) in headers { let key_str = key.as_str().to_lowercase(); - if allowed_headers.contains(&key_str.as_str()) { - request = request.header(key, value); + if HEADER_BLACKLIST.contains(&key_str.as_str()) { + filtered_headers.push(key_str); + continue; + } + let value_str = value.to_str().unwrap_or("").to_string(); + passed_headers.push((key.as_str().to_string(), value_str.clone())); + request = request.header(key, value); + } + + if !filtered_headers.is_empty() { + log::info!( + "[{}] ====== 被过滤的 Headers ({}) ======", + adapter.name(), + filtered_headers.len() + ); + for h in &filtered_headers { + log::info!("[{}] - {}", adapter.name(), h); } } - // 确保 Content-Type 是 json - request = request.header("Content-Type", "application/json"); + // 处理 anthropic-beta Header(透传) + // 参考 Claude Code Hub 的实现,直接透传客户端的 beta 标记 + if let Some(beta) = headers.get("anthropic-beta") { + if let Ok(beta_str) = beta.to_str() { + request = request.header("anthropic-beta", beta_str); + passed_headers.push(("anthropic-beta".to_string(), beta_str.to_string())); + log::info!("[{}] 透传 anthropic-beta: {}", adapter.name(), beta_str); + } + } + + // 客户端 IP 透传(默认开启) + if let Some(xff) = headers.get("x-forwarded-for") { + if let Ok(xff_str) = xff.to_str() { + request = request.header("x-forwarded-for", xff_str); + passed_headers.push(("x-forwarded-for".to_string(), xff_str.to_string())); + log::debug!("[{}] 透传 x-forwarded-for: {}", adapter.name(), xff_str); + } + } + if let Some(real_ip) = headers.get("x-real-ip") { + if let Ok(real_ip_str) = real_ip.to_str() { + request = request.header("x-real-ip", real_ip_str); + passed_headers.push(("x-real-ip".to_string(), real_ip_str.to_string())); + log::debug!("[{}] 透传 x-real-ip: {}", adapter.name(), real_ip_str); + } + } + + // 禁用压缩,避免 gzip 流式响应解析错误 + // 参考 CCH: undici 在连接提前关闭时会对不完整的 gzip 流抛出错误 + request = request.header("accept-encoding", "identity"); + passed_headers.push(("accept-encoding".to_string(), "identity".to_string())); // 使用适配器添加认证头 if let Some(auth) = adapter.extract_auth(provider) { @@ -519,6 +595,15 @@ impl RequestForwarder { auth.masked_key() ); request = adapter.add_auth_headers(request, &auth); + // 记录认证头(脱敏) + passed_headers.push(( + "authorization".to_string(), + format!("Bearer {}...", &auth.api_key[..8.min(auth.api_key.len())]), + )); + passed_headers.push(( + "x-api-key".to_string(), + format!("{}...", &auth.api_key[..8.min(auth.api_key.len())]), + )); } else { log::error!( "[{}] 未找到 API Key!Provider: {}", @@ -527,9 +612,34 @@ impl RequestForwarder { ); } + // anthropic-version 透传:优先使用客户端的版本号 + // 参考 Claude Code Hub:透传客户端值而非固定版本 + if let Some(version) = headers.get("anthropic-version") { + if let Ok(version_str) = version.to_str() { + // 覆盖适配器设置的默认版本 + request = request.header("anthropic-version", version_str); + passed_headers.push(("anthropic-version".to_string(), version_str.to_string())); + log::info!( + "[{}] 透传 anthropic-version: {}", + adapter.name(), + version_str + ); + } + } + + // ========== 最终发送的 Headers 日志 ========== + log::info!( + "[{}] ====== 最终发送的 Headers ({}) ======", + adapter.name(), + passed_headers.len() + ); + for (k, v) in &passed_headers { + log::info!("[{}] {}: {}", adapter.name(), k, v); + } + // 发送请求 log::info!("[{}] 发送请求到: {}", adapter.name(), url); - let response = request.json(&request_body).send().await.map_err(|e| { + let response = request.json(&filtered_body).send().await.map_err(|e| { log::error!("[{}] 请求失败: {}", adapter.name(), e); if e.is_timeout() { ProxyError::Timeout(format!("请求超时: {e}")) @@ -563,25 +673,6 @@ impl RequestForwarder { } } - /// 分类ProxyError - /// - /// 决定哪些错误应该触发故障转移到下一个 Provider - /// - /// 设计原则:既然用户配置了多个供应商,就应该让所有供应商都尝试一遍。 - /// 只有明确是客户端中断的情况才不重试。 - fn should_retry_same_provider(&self, error: &ProxyError) -> bool { - match error { - // 网络类错误:短暂抖动时同一 Provider 内重试有意义 - ProxyError::Timeout(_) => true, - ProxyError::ForwardFailed(_) => true, - // 上游 HTTP 错误:只对“可能瞬态”的状态码做同 Provider 重试(其余交给 failover) - ProxyError::UpstreamError { status, .. } => { - *status == 408 || *status == 429 || *status >= 500 - } - _ => false, - } - } - fn categorize_proxy_error(&self, error: &ProxyError) -> ErrorCategory { match error { // 网络和上游错误:都应该尝试下一个供应商 @@ -597,7 +688,6 @@ impl RequestForwarder { ProxyError::TransformError(_) => ErrorCategory::Retryable, ProxyError::AuthError(_) => ErrorCategory::Retryable, ProxyError::StreamIdleTimeout(_) => ErrorCategory::Retryable, - ProxyError::MaxRetriesExceeded => ErrorCategory::Retryable, // 无可用供应商:所有供应商都试过了,无法重试 ProxyError::NoAvailableProvider => ErrorCategory::NonRetryable, // 其他错误(数据库/内部错误等):不是换供应商能解决的问题 diff --git a/src-tauri/src/proxy/handler_config.rs b/src-tauri/src/proxy/handler_config.rs index ca2df0787..fcbeb1fa2 100644 --- a/src-tauri/src/proxy/handler_config.rs +++ b/src-tauri/src/proxy/handler_config.rs @@ -58,10 +58,10 @@ fn openai_model_extractor(events: &[Value], request_model: &str) -> String { .to_string() } -/// Codex Responses API 流式响应模型提取(优先使用 usage.model) -fn codex_model_extractor(events: &[Value], request_model: &str) -> String { +/// Codex 智能流式响应模型提取(自动检测格式) +fn codex_auto_model_extractor(events: &[Value], request_model: &str) -> String { // 首先尝试从解析的 usage 中获取模型 - if let Some(usage) = TokenUsage::from_codex_stream_events(events) { + if let Some(usage) = TokenUsage::from_codex_stream_events_auto(events) { if let Some(model) = usage.model { return model; } @@ -76,6 +76,10 @@ fn codex_model_extractor(events: &[Value], request_model: &str) -> String { None } }) + .or_else(|| { + // 再回退:从 OpenAI 格式事件中提取 + events.iter().find_map(|e| e.get("model")?.as_str()) + }) .unwrap_or(request_model) .to_string() } @@ -111,11 +115,11 @@ pub const OPENAI_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { app_type_str: "codex", }; -/// Codex Responses API 解析配置(用于 /v1/responses) +/// Codex 智能解析配置(自动检测 OpenAI 或 Codex 格式) pub const CODEX_PARSER_CONFIG: UsageParserConfig = UsageParserConfig { - stream_parser: TokenUsage::from_codex_stream_events, - response_parser: TokenUsage::from_codex_response, - model_extractor: codex_model_extractor, + stream_parser: TokenUsage::from_codex_stream_events_auto, + response_parser: TokenUsage::from_codex_response_auto, + model_extractor: codex_auto_model_extractor, app_type_str: "codex", }; diff --git a/src-tauri/src/proxy/handler_context.rs b/src-tauri/src/proxy/handler_context.rs index 86bce211d..773515f3d 100644 --- a/src-tauri/src/proxy/handler_context.rs +++ b/src-tauri/src/proxy/handler_context.rs @@ -5,8 +5,10 @@ use crate::app_config::AppType; use crate::provider::Provider; use crate::proxy::{ - forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig, ProxyError, + extract_session_id, forwarder::RequestForwarder, server::ProxyState, types::AppProxyConfig, + ProxyError, }; +use axum::http::HeaderMap; use std::time::Instant; /// 流式超时配置 @@ -26,6 +28,7 @@ pub struct StreamingTimeoutConfig { /// - 选中的 Provider 列表(用于故障转移) /// - 请求模型名称 /// - 日志标签 +/// - Session ID(用于日志关联) pub struct RequestContext { /// 请求开始时间 pub start_time: Instant, @@ -35,7 +38,7 @@ pub struct RequestContext { pub provider: Provider, /// 完整的 Provider 列表(用于故障转移) providers: Vec, - /// 请求开始时的“当前供应商”(用于判断是否需要同步 UI/托盘) + /// 请求开始时的"当前供应商"(用于判断是否需要同步 UI/托盘) /// /// 这里使用本地 settings 的设备级 current provider。 /// 代理模式下如果实际使用的 provider 与此不一致,会触发切换以确保 UI 始终准确。 @@ -49,6 +52,8 @@ pub struct RequestContext { /// 应用类型(预留,目前通过 app_type_str 使用) #[allow(dead_code)] pub app_type: AppType, + /// Session ID(从客户端请求提取或新生成) + pub session_id: String, } impl RequestContext { @@ -57,6 +62,7 @@ impl RequestContext { /// # Arguments /// * `state` - 代理服务器状态 /// * `body` - 请求体 JSON + /// * `headers` - 请求头(用于提取 Session ID) /// * `app_type` - 应用类型 /// * `tag` - 日志标签 /// * `app_type_str` - 应用类型字符串 @@ -66,6 +72,7 @@ impl RequestContext { pub async fn new( state: &ProxyState, body: &serde_json::Value, + headers: &HeaderMap, app_type: AppType, tag: &'static str, app_type_str: &'static str, @@ -89,13 +96,31 @@ impl RequestContext { .unwrap_or("unknown") .to_string(); + // 提取 Session ID + let session_result = extract_session_id(headers, body, app_type_str); + let session_id = session_result.session_id.clone(); + + log::debug!( + "[{}] Session ID: {} (from {:?}, client_provided: {})", + tag, + session_id, + session_result.source, + session_result.client_provided + ); + // 使用共享的 ProviderRouter 选择 Provider(熔断器状态跨请求保持) // 注意:只在这里调用一次,结果传递给 forwarder,避免重复消耗 HalfOpen 名额 let providers = state .provider_router .select_providers(app_type_str) .await - .map_err(|e| ProxyError::DatabaseError(e.to_string()))?; + .map_err(|e| match e { + crate::error::AppError::AllProvidersCircuitOpen => { + ProxyError::AllProvidersCircuitOpen + } + crate::error::AppError::NoProvidersConfigured => ProxyError::NoProvidersConfigured, + _ => ProxyError::DatabaseError(e.to_string()), + })?; let provider = providers .first() @@ -103,11 +128,12 @@ impl RequestContext { .ok_or(ProxyError::NoAvailableProvider)?; log::info!( - "[{}] Provider: {}, model: {}, failover chain: {} providers", + "[{}] Provider: {}, model: {}, failover chain: {} providers, session: {}", tag, provider.name, request_model, - providers.len() + providers.len(), + session_id ); Ok(Self { @@ -120,6 +146,7 @@ impl RequestContext { tag, app_type_str, app_type, + session_id, }) } @@ -148,18 +175,38 @@ impl RequestContext { /// 创建 RequestForwarder /// /// 使用共享的 ProviderRouter,确保熔断器状态跨请求保持 + /// + /// 配置生效规则: + /// - 故障转移开启:超时配置正常生效(0 表示禁用超时) + /// - 故障转移关闭:超时配置不生效(全部传入 0) pub fn create_forwarder(&self, state: &ProxyState) -> RequestForwarder { + let (non_streaming_timeout, first_byte_timeout, idle_timeout) = + if self.app_config.auto_failover_enabled { + // 故障转移开启:使用配置的值(0 = 禁用超时) + ( + self.app_config.non_streaming_timeout as u64, + self.app_config.streaming_first_byte_timeout as u64, + self.app_config.streaming_idle_timeout as u64, + ) + } else { + // 故障转移关闭:不启用超时配置 + log::info!( + "[{}] Failover disabled, timeout configs are bypassed", + self.tag + ); + (0, 0, 0) + }; + RequestForwarder::new( state.provider_router.clone(), - self.app_config.non_streaming_timeout as u64, - self.app_config.max_retries as u8, + non_streaming_timeout, state.status.clone(), state.current_providers.clone(), state.failover_manager.clone(), state.app_handle.clone(), self.current_provider_id.clone(), - self.app_config.streaming_first_byte_timeout as u64, - self.app_config.streaming_idle_timeout as u64, + first_byte_timeout, + idle_timeout, ) } @@ -177,11 +224,24 @@ impl RequestContext { } /// 获取流式超时配置 + /// + /// 配置生效规则: + /// - 故障转移开启:返回配置的值(0 表示禁用超时检查) + /// - 故障转移关闭:返回 0(禁用超时检查) #[inline] pub fn streaming_timeout_config(&self) -> StreamingTimeoutConfig { - StreamingTimeoutConfig { - first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64, - idle_timeout: self.app_config.streaming_idle_timeout as u64, + if self.app_config.auto_failover_enabled { + // 故障转移开启:使用配置的值(0 = 禁用超时) + StreamingTimeoutConfig { + first_byte_timeout: self.app_config.streaming_first_byte_timeout as u64, + idle_timeout: self.app_config.streaming_idle_timeout as u64, + } + } else { + // 故障转移关闭:禁用流式超时检查 + StreamingTimeoutConfig { + first_byte_timeout: 0, + idle_timeout: 0, + } } } } diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 20512dc70..0e29aff74 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -61,7 +61,8 @@ pub async fn handle_messages( headers: axum::http::HeaderMap, Json(body): Json, ) -> Result { - let mut ctx = RequestContext::new(&state, &body, AppType::Claude, "Claude", "claude").await?; + let mut ctx = + RequestContext::new(&state, &body, &headers, AppType::Claude, "Claude", "claude").await?; let is_stream = body .get("stream") @@ -305,7 +306,8 @@ pub async fn handle_chat_completions( ) -> Result { log::info!("[Codex] ====== /v1/chat/completions 请求开始 ======"); - let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?; + let mut ctx = + RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?; let is_stream = body .get("stream") @@ -353,7 +355,8 @@ pub async fn handle_responses( headers: axum::http::HeaderMap, Json(body): Json, ) -> Result { - let mut ctx = RequestContext::new(&state, &body, AppType::Codex, "Codex", "codex").await?; + let mut ctx = + RequestContext::new(&state, &body, &headers, AppType::Codex, "Codex", "codex").await?; let is_stream = body .get("stream") @@ -401,7 +404,7 @@ pub async fn handle_gemini( Json(body): Json, ) -> Result { // Gemini 的模型名称在 URI 中 - let mut ctx = RequestContext::new(&state, &body, AppType::Gemini, "Gemini", "gemini") + let mut ctx = RequestContext::new(&state, &body, &headers, AppType::Gemini, "Gemini", "gemini") .await? .with_model_from_uri(&uri); @@ -465,7 +468,7 @@ fn log_forward_error( let request_id = uuid::Uuid::new_v4().to_string(); if let Err(e) = logger.log_error_with_context( - request_id.clone(), + request_id, ctx.provider.id.clone(), ctx.app_type_str.to_string(), ctx.request_model.clone(), @@ -473,7 +476,7 @@ fn log_forward_error( error_message, ctx.latency_ms(), is_streaming, - Some(request_id), + Some(ctx.session_id.clone()), None, ) { log::warn!("记录失败请求日志失败: {e}"); diff --git a/src-tauri/src/proxy/mod.rs b/src-tauri/src/proxy/mod.rs index 69063dc08..5b0ac0216 100644 --- a/src-tauri/src/proxy/mod.rs +++ b/src-tauri/src/proxy/mod.rs @@ -2,6 +2,7 @@ //! //! 提供本地HTTP代理服务,支持多Provider故障转移和请求透传 +pub mod body_filter; pub mod circuit_breaker; pub mod error; pub mod error_mapper; @@ -33,7 +34,9 @@ pub use provider_router::ProviderRouter; #[allow(unused_imports)] pub use response_handler::{NonStreamHandler, ResponseType, StreamHandler}; #[allow(unused_imports)] -pub use session::{ClientFormat, ProxySession}; +pub use session::{ + extract_session_id, ClientFormat, ProxySession, SessionIdResult, SessionIdSource, +}; #[allow(unused_imports)] pub use types::{ProxyConfig, ProxyServerInfo, ProxyStatus}; diff --git a/src-tauri/src/proxy/provider_router.rs b/src-tauri/src/proxy/provider_router.rs index 1cbae1a65..12373587e 100644 --- a/src-tauri/src/proxy/provider_router.rs +++ b/src-tauri/src/proxy/provider_router.rs @@ -34,6 +34,8 @@ impl ProviderRouter { /// - 故障转移开启时:完全按照故障转移队列顺序返回,忽略当前供应商设置 pub async fn select_providers(&self, app_type: &str) -> Result, AppError> { let mut result = Vec::new(); + let mut total_providers = 0usize; + let mut circuit_open_count = 0usize; // 检查该应用的自动故障转移开关是否开启(从 proxy_config 表读取) let auto_failover_enabled = match self.db.get_proxy_config_for_app(app_type).await { @@ -53,18 +55,26 @@ impl ProviderRouter { if auto_failover_enabled { // 故障转移开启:使用 in_failover_queue 标记的供应商,按 sort_index 排序 let failover_providers = self.db.get_failover_providers(app_type)?; + total_providers = failover_providers.len(); + log::debug!("[{app_type}] Found {total_providers} failover queue provider(s)"); log::info!( - "[{}] Failover enabled, using queue order ({} items)", - app_type, - failover_providers.len() + "[{app_type}] Failover enabled, using queue order ({total_providers} items)" ); for provider in failover_providers { // 检查熔断器状态 let circuit_key = format!("{}:{}", app_type, provider.id); let breaker = self.get_or_create_circuit_breaker(&circuit_key).await; + let state = breaker.get_state().await; if breaker.is_available().await { + log::debug!( + "[{}] Queue provider available: {} ({}) (state: {:?})", + app_type, + provider.name, + provider.id, + state + ); log::info!( "[{}] Queue provider available: {} ({}) at sort_index {:?}", app_type, @@ -74,10 +84,12 @@ impl ProviderRouter { ); result.push(provider); } else { + circuit_open_count += 1; log::debug!( - "[{}] Queue provider {} circuit breaker open, skipping", + "[{}] Queue provider {} circuit breaker open (state: {:?}), skipping", app_type, - provider.name + provider.name, + state ); } } @@ -94,15 +106,27 @@ impl ProviderRouter { current.name, current.id ); + total_providers = 1; result.push(current); + } else { + log::debug!( + "[{app_type}] Current provider id {current_id} not found in database" + ); } + } else { + log::debug!("[{app_type}] No current provider configured"); } } if result.is_empty() { - return Err(AppError::Config(format!( - "No available provider for {app_type} (all circuit breakers open or no providers configured)" - ))); + // 区分两种情况:全部熔断 vs 未配置供应商 + if total_providers > 0 && circuit_open_count == total_providers { + log::warn!("[{app_type}] 所有 {total_providers} 个供应商均已熔断,无可用渠道"); + return Err(AppError::AllProvidersCircuitOpen); + } else { + log::warn!("[{app_type}] 未配置供应商或故障转移队列为空"); + return Err(AppError::NoProvidersConfigured); + } } log::info!( diff --git a/src-tauri/src/proxy/response_processor.rs b/src-tauri/src/proxy/response_processor.rs index 714302b3c..0df50e039 100644 --- a/src-tauri/src/proxy/response_processor.rs +++ b/src-tauri/src/proxy/response_processor.rs @@ -112,6 +112,19 @@ pub async fn handle_non_streaming( spawn_log_usage(state, ctx, usage, &model, status.as_u16(), false); } else { + let model = json_value + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or(&ctx.request_model) + .to_string(); + spawn_log_usage( + state, + ctx, + TokenUsage::default(), + &model, + status.as_u16(), + false, + ); log::debug!( "[{}] 未能解析 usage 信息,跳过记录", parser_config.app_type_str @@ -123,6 +136,14 @@ pub async fn handle_non_streaming( ctx.tag, body_bytes.len() ); + spawn_log_usage( + state, + ctx, + TokenUsage::default(), + &ctx.request_model, + status.as_u16(), + false, + ); } log::info!("[{}] ====== 请求结束 ======", ctx.tag); @@ -243,6 +264,7 @@ fn create_usage_collector( let start_time = ctx.start_time; let stream_parser = parser_config.stream_parser; let model_extractor = parser_config.model_extractor; + let session_id = ctx.session_id.clone(); SseUsageCollector::new(start_time, move |events, first_token_ms| { if let Some(usage) = stream_parser(&events) { @@ -251,6 +273,7 @@ fn create_usage_collector( let state = state.clone(); let provider_id = provider_id.clone(); + let session_id = session_id.clone(); tokio::spawn(async move { log_usage_internal( @@ -263,10 +286,32 @@ fn create_usage_collector( first_token_ms, true, // is_streaming status_code, + Some(session_id), ) .await; }); } else { + let model = model_extractor(&events, &request_model); + let latency_ms = start_time.elapsed().as_millis() as u64; + let state = state.clone(); + let provider_id = provider_id.clone(); + let session_id = session_id.clone(); + + tokio::spawn(async move { + log_usage_internal( + &state, + &provider_id, + app_type_str, + &model, + TokenUsage::default(), + latency_ms, + first_token_ms, + true, // is_streaming + status_code, + Some(session_id), + ) + .await; + }); log::debug!("[{tag}] 流式响应缺少 usage 统计,跳过消费记录"); } }) @@ -286,6 +331,7 @@ fn spawn_log_usage( let app_type_str = ctx.app_type_str.to_string(); let model = model.to_string(); let latency_ms = ctx.latency_ms(); + let session_id = ctx.session_id.clone(); tokio::spawn(async move { log_usage_internal( @@ -298,6 +344,7 @@ fn spawn_log_usage( None, is_streaming, status_code, + Some(session_id), ) .await; }); @@ -315,6 +362,7 @@ async fn log_usage_internal( first_token_ms: Option, is_streaming: bool, status_code: u16, + session_id: Option, ) { use super::usage::logger::UsageLogger; @@ -338,6 +386,15 @@ async fn log_usage_internal( let request_id = uuid::Uuid::new_v4().to_string(); + log::debug!( + "[{app_type}] 记录请求日志: id={request_id}, provider={provider_id}, model={model}, streaming={is_streaming}, status={status_code}, latency_ms={latency_ms}, first_token_ms={first_token_ms:?}, session={}, input={}, output={}, cache_read={}, cache_creation={}", + session_id.as_deref().unwrap_or("none"), + usage.input_tokens, + usage.output_tokens, + usage.cache_read_tokens, + usage.cache_creation_tokens + ); + if let Err(e) = logger.log_with_calculation( request_id, provider_id.to_string(), @@ -348,7 +405,7 @@ async fn log_usage_internal( latency_ms, first_token_ms, status_code, - None, + session_id, None, // provider_type is_streaming, ) { diff --git a/src-tauri/src/proxy/session.rs b/src-tauri/src/proxy/session.rs index 3e1d14e56..cf3f964fe 100644 --- a/src-tauri/src/proxy/session.rs +++ b/src-tauri/src/proxy/session.rs @@ -1,7 +1,15 @@ //! Proxy Session - 请求会话管理 //! //! 为每个代理请求创建会话上下文,在整个请求生命周期中跟踪状态和元数据。 +//! +//! ## Session ID 提取 +//! +//! 支持从客户端请求中提取 Session ID,用于关联同一对话的多个请求: +//! - Claude: 从 `metadata.user_id` (格式: `user_xxx_session_yyy`) 或 `metadata.session_id` 提取 +//! - Codex: 从 `previous_response_id` 或 headers 中的 `session_id` 提取 +//! - 其他: 生成新的 UUID +use axum::http::HeaderMap; use std::time::Instant; use uuid::Uuid; @@ -176,6 +184,179 @@ impl ProxySession { } } +// ============================================================================ +// Session ID 提取器 +// ============================================================================ + +/// Session ID 来源 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionIdSource { + /// 从 metadata.user_id 提取 (Claude) + MetadataUserId, + /// 从 metadata.session_id 提取 + MetadataSessionId, + /// 从 headers 提取 (Codex) + Header, + /// 从 previous_response_id 提取 (Codex) + PreviousResponseId, + /// 新生成 + Generated, +} + +/// Session ID 提取结果 +#[derive(Debug, Clone)] +pub struct SessionIdResult { + /// 提取或生成的 Session ID + pub session_id: String, + /// Session ID 来源 + pub source: SessionIdSource, + /// 是否为客户端提供的 ID(非新生成) + pub client_provided: bool, +} + +/// 从请求中提取或生成 Session ID +/// +/// 轻量化实现,仅提取 session_id 用于日志记录,不做复杂的 Session 管理。 +/// +/// ## 提取优先级 +/// +/// ### Claude 请求 +/// 1. `metadata.user_id` (格式: `user_xxx_session_yyy`) → 提取 `yyy` 部分 +/// 2. `metadata.session_id` → 直接使用 +/// 3. 生成新 UUID +/// +/// ### Codex 请求 +/// 1. Headers: `session_id` 或 `x-session-id` +/// 2. `metadata.session_id` +/// 3. `previous_response_id` (对话延续) +/// 4. 生成新 UUID +/// +/// ## 示例 +/// +/// ```ignore +/// let result = extract_session_id(&headers, &body, "claude"); +/// println!("Session ID: {} (from {:?})", result.session_id, result.source); +/// ``` +pub fn extract_session_id( + headers: &HeaderMap, + body: &serde_json::Value, + client_format: &str, +) -> SessionIdResult { + // Codex 请求特殊处理 + if client_format == "codex" || client_format == "openai" { + if let Some(result) = extract_codex_session(headers, body) { + return result; + } + } + + // Claude 请求:从 metadata 提取 + if let Some(result) = extract_from_metadata(body) { + return result; + } + + // 兜底:生成新 Session ID + generate_new_session_id() +} + +/// 提取 Codex Session ID +fn extract_codex_session(headers: &HeaderMap, body: &serde_json::Value) -> Option { + // 1. 从 headers 提取 + for header_name in &["session_id", "x-session-id"] { + if let Some(value) = headers.get(*header_name) { + if let Ok(session_id) = value.to_str() { + // Codex Session ID 通常较长(UUID 格式) + if session_id.len() > 20 { + return Some(SessionIdResult { + session_id: format!("codex_{session_id}"), + source: SessionIdSource::Header, + client_provided: true, + }); + } + } + } + } + + // 2. 从 body.metadata.session_id 提取 + if let Some(session_id) = body + .get("metadata") + .and_then(|m| m.get("session_id")) + .and_then(|v| v.as_str()) + { + if session_id.len() > 10 { + return Some(SessionIdResult { + session_id: format!("codex_{session_id}"), + source: SessionIdSource::MetadataSessionId, + client_provided: true, + }); + } + } + + // 3. 从 previous_response_id 提取(对话延续) + if let Some(prev_id) = body.get("previous_response_id").and_then(|v| v.as_str()) { + if prev_id.len() > 10 { + return Some(SessionIdResult { + session_id: format!("codex_{prev_id}"), + source: SessionIdSource::PreviousResponseId, + client_provided: true, + }); + } + } + + None +} + +/// 从 metadata 提取 Session ID (Claude) +fn extract_from_metadata(body: &serde_json::Value) -> Option { + let metadata = body.get("metadata")?; + + // 1. 从 metadata.user_id 提取(格式: user_xxx_session_yyy) + if let Some(user_id) = metadata.get("user_id").and_then(|v| v.as_str()) { + if let Some(session_id) = parse_session_from_user_id(user_id) { + return Some(SessionIdResult { + session_id, + source: SessionIdSource::MetadataUserId, + client_provided: true, + }); + } + } + + // 2. 直接从 metadata.session_id 提取 + if let Some(session_id) = metadata.get("session_id").and_then(|v| v.as_str()) { + if !session_id.is_empty() { + return Some(SessionIdResult { + session_id: session_id.to_string(), + source: SessionIdSource::MetadataSessionId, + client_provided: true, + }); + } + } + + None +} + +/// 从 user_id 解析 session_id +/// +/// 格式: `user_identifier_session_actual_session_id` +fn parse_session_from_user_id(user_id: &str) -> Option { + // 查找 "_session_" 分隔符 + if let Some(pos) = user_id.find("_session_") { + let session_id = &user_id[pos + 9..]; // "_session_" 长度为 9 + if !session_id.is_empty() { + return Some(session_id.to_string()); + } + } + None +} + +/// 生成新的 Session ID +fn generate_new_session_id() -> SessionIdResult { + SessionIdResult { + session_id: Uuid::new_v4().to_string(), + source: SessionIdSource::Generated, + client_provided: false, + } +} + #[cfg(test)] mod tests { use super::*; @@ -295,4 +476,92 @@ mod tests { assert_eq!(ClientFormat::GeminiCli.as_str(), "gemini_cli"); assert_eq!(ClientFormat::Unknown.as_str(), "unknown"); } + + // ========== Session ID 提取测试 ========== + + #[test] + fn test_extract_session_from_claude_metadata_user_id() { + let headers = HeaderMap::new(); + let body = json!({ + "model": "claude-3-5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "user_id": "user_john_doe_session_abc123def456" + } + }); + + let result = extract_session_id(&headers, &body, "claude"); + + assert_eq!(result.session_id, "abc123def456"); + assert_eq!(result.source, SessionIdSource::MetadataUserId); + assert!(result.client_provided); + } + + #[test] + fn test_extract_session_from_claude_metadata_session_id() { + let headers = HeaderMap::new(); + let body = json!({ + "model": "claude-3-5-sonnet", + "messages": [{"role": "user", "content": "Hello"}], + "metadata": { + "session_id": "my-session-123" + } + }); + + let result = extract_session_id(&headers, &body, "claude"); + + assert_eq!(result.session_id, "my-session-123"); + assert_eq!(result.source, SessionIdSource::MetadataSessionId); + assert!(result.client_provided); + } + + #[test] + fn test_extract_session_from_codex_previous_response_id() { + let headers = HeaderMap::new(); + let body = json!({ + "input": "Write a function", + "previous_response_id": "resp_abc123def456789" + }); + + let result = extract_session_id(&headers, &body, "codex"); + + assert_eq!(result.session_id, "codex_resp_abc123def456789"); + assert_eq!(result.source, SessionIdSource::PreviousResponseId); + assert!(result.client_provided); + } + + #[test] + fn test_extract_session_generates_new_when_not_found() { + let headers = HeaderMap::new(); + let body = json!({ + "model": "claude-3-5-sonnet", + "messages": [{"role": "user", "content": "Hello"}] + }); + + let result = extract_session_id(&headers, &body, "claude"); + + assert!(!result.session_id.is_empty()); + assert_eq!(result.source, SessionIdSource::Generated); + assert!(!result.client_provided); + } + + #[test] + fn test_parse_session_from_user_id() { + assert_eq!( + parse_session_from_user_id("user_john_session_abc123"), + Some("abc123".to_string()) + ); + assert_eq!( + parse_session_from_user_id("my_app_session_xyz789"), + Some("xyz789".to_string()) + ); + // 注意: "_session_" 是分隔符,所以下面的字符串会匹配 + assert_eq!( + parse_session_from_user_id("no_session_marker"), + Some("marker".to_string()) + ); + // 没有 "_session_" 分隔符的情况 + assert_eq!(parse_session_from_user_id("user_john_abc123"), None); + assert_eq!(parse_session_from_user_id("_session_"), None); + } } diff --git a/src-tauri/src/proxy/usage/calculator.rs b/src-tauri/src/proxy/usage/calculator.rs index 7295e13ae..80fd2c673 100644 --- a/src-tauri/src/proxy/usage/calculator.rs +++ b/src-tauri/src/proxy/usage/calculator.rs @@ -35,6 +35,11 @@ impl CostCalculator { /// - `usage`: Token 使用量 /// - `pricing`: 模型定价 /// - `cost_multiplier`: 成本倍数 (provider 自定义) + /// + /// # 计算逻辑 + /// - input_cost: (input_tokens - cache_read_tokens) × 输入价格 + /// - cache_read_cost: cache_read_tokens × 缓存读取价格 + /// - 这样避免缓存部分被重复计费 pub fn calculate( usage: &TokenUsage, pricing: &ModelPricing, @@ -42,7 +47,10 @@ impl CostCalculator { ) -> CostBreakdown { let million = Decimal::from(1_000_000); - let input_cost = Decimal::from(usage.input_tokens) * pricing.input_cost_per_million + // 计算实际需要按输入价格计费的 token 数(减去缓存命中部分) + let billable_input_tokens = usage.input_tokens.saturating_sub(usage.cache_read_tokens); + + let input_cost = Decimal::from(billable_input_tokens) * pricing.input_cost_per_million / million * cost_multiplier; let output_cost = Decimal::from(usage.output_tokens) * pricing.output_cost_per_million @@ -113,8 +121,8 @@ mod tests { let cost = CostCalculator::calculate(&usage, &pricing, multiplier); - // input: 1000 * 3.0 / 1M = 0.003 - assert_eq!(cost.input_cost, Decimal::from_str("0.003").unwrap()); + // input: (1000 - 200) * 3.0 / 1M = 0.0024 (只计算非缓存部分) + assert_eq!(cost.input_cost, Decimal::from_str("0.0024").unwrap()); // output: 500 * 15.0 / 1M = 0.0075 assert_eq!(cost.output_cost, Decimal::from_str("0.0075").unwrap()); // cache_read: 200 * 0.3 / 1M = 0.00006 @@ -124,8 +132,8 @@ mod tests { cost.cache_creation_cost, Decimal::from_str("0.000375").unwrap() ); - // total: 0.003 + 0.0075 + 0.00006 + 0.000375 = 0.010935 - assert_eq!(cost.total_cost, Decimal::from_str("0.010935").unwrap()); + // total: 0.0024 + 0.0075 + 0.00006 + 0.000375 = 0.010335 + assert_eq!(cost.total_cost, Decimal::from_str("0.010335").unwrap()); } #[test] diff --git a/src-tauri/src/proxy/usage/parser.rs b/src-tauri/src/proxy/usage/parser.rs index 33c1aaf16..0a33d0187 100644 --- a/src-tauri/src/proxy/usage/parser.rs +++ b/src-tauri/src/proxy/usage/parser.rs @@ -163,13 +163,21 @@ impl TokenUsage { .and_then(|v| v.as_str()) .map(|s| s.to_string()); + let cached_tokens = usage + .get("cache_read_input_tokens") + .and_then(|v| v.as_u64()) + .or_else(|| { + usage + .get("input_tokens_details") + .and_then(|d| d.get("cached_tokens")) + .and_then(|v| v.as_u64()) + }) + .unwrap_or(0) as u32; + Some(Self { input_tokens: input_tokens? as u32, output_tokens: output_tokens? as u32, - cache_read_tokens: usage - .get("cache_read_input_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32, + cache_read_tokens: cached_tokens, cache_creation_tokens: usage .get("cache_creation_input_tokens") .and_then(|v| v.as_u64()) @@ -188,16 +196,27 @@ impl TokenUsage { let input_tokens = usage.get("input_tokens")?.as_u64()? as u32; let output_tokens = usage.get("output_tokens")?.as_u64()? as u32; - // 获取 cached_tokens (可能在 input_tokens_details 中) + // 获取 cached_tokens (可能在 cache_read_input_tokens 或 input_tokens_details 中) let cached_tokens = usage - .get("input_tokens_details") - .and_then(|d| d.get("cached_tokens")) + .get("cache_read_input_tokens") .and_then(|v| v.as_u64()) + .or_else(|| { + usage + .get("input_tokens_details") + .and_then(|d| d.get("cached_tokens")) + .and_then(|v| v.as_u64()) + }) .unwrap_or(0) as u32; // 调整 input_tokens: 减去 cached_tokens let adjusted_input = input_tokens.saturating_sub(cached_tokens); + // 提取响应中的模型名称 + let model = body + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + Some(Self { input_tokens: adjusted_input, output_tokens, @@ -206,7 +225,7 @@ impl TokenUsage { .get("cache_creation_input_tokens") .and_then(|v| v.as_u64()) .unwrap_or(0) as u32, - model: None, + model, }) } @@ -220,7 +239,7 @@ impl TokenUsage { if event_type == "response.completed" { if let Some(response) = event.get("response") { log::debug!("[Codex] 找到 response.completed 事件,解析 usage"); - return Self::from_codex_response(response); + return Self::from_codex_response_adjusted(response); } } } @@ -229,6 +248,51 @@ impl TokenUsage { None } + /// 智能 Codex 响应解析 - 自动检测 OpenAI 或 Codex 格式 + /// + /// Codex 支持两种 API 格式: + /// - `/v1/responses`: 使用 input_tokens/output_tokens + /// - `/v1/chat/completions`: 使用 prompt_tokens/completion_tokens (OpenAI 格式) + /// + /// 注意:记录原始 input_tokens,费用计算时再减去 cached_tokens + pub fn from_codex_response_auto(body: &Value) -> Option { + let usage = body.get("usage")?; + + // 检测格式:OpenAI 使用 prompt_tokens,Codex 使用 input_tokens + if usage.get("prompt_tokens").is_some() { + log::debug!("[Codex] 检测到 OpenAI 格式 (prompt_tokens)"); + Self::from_openai_response(body) + } else if usage.get("input_tokens").is_some() { + log::debug!("[Codex] 检测到 Codex 格式 (input_tokens)"); + // 使用非调整版本,记录原始 input_tokens + Self::from_codex_response(body) + } else { + log::debug!("[Codex] 无法识别响应格式,usage: {usage:?}"); + None + } + } + + /// 智能 Codex 流式响应解析 - 自动检测 OpenAI 或 Codex 格式 + pub fn from_codex_stream_events_auto(events: &[Value]) -> Option { + log::debug!("[Codex] 智能解析流式事件,共 {} 个事件", events.len()); + + // 先尝试 Codex Responses API 格式 (response.completed 事件) + for event in events { + if let Some(event_type) = event.get("type").and_then(|v| v.as_str()) { + if event_type == "response.completed" { + if let Some(response) = event.get("response") { + log::debug!("[Codex] 找到 response.completed 事件"); + return Self::from_codex_response_auto(response); + } + } + } + } + + // 回退到 OpenAI Chat Completions 格式 (最后一个 chunk 包含 usage) + log::debug!("[Codex] 尝试 OpenAI 流式格式"); + Self::from_openai_stream_events(events) + } + /// 从 OpenAI Chat Completions API 响应解析 (prompt_tokens, completion_tokens) pub fn from_openai_response(body: &Value) -> Option { let usage = body.get("usage")?; @@ -284,9 +348,16 @@ impl TokenUsage { .and_then(|v| v.as_str()) .map(|s| s.to_string()); + let prompt_tokens = usage.get("promptTokenCount")?.as_u64()? as u32; + let total_tokens = usage.get("totalTokenCount")?.as_u64()? as u32; + + // 输出 tokens = 总 tokens - 输入 tokens + // 这包含了 candidatesTokenCount + thoughtsTokenCount + let output_tokens = total_tokens.saturating_sub(prompt_tokens); + Some(Self { - input_tokens: usage.get("promptTokenCount")?.as_u64()? as u32, - output_tokens: usage.get("candidatesTokenCount")?.as_u64()? as u32, + input_tokens: prompt_tokens, + output_tokens, cache_read_tokens: usage .get("cachedContentTokenCount") .and_then(|v| v.as_u64()) @@ -300,20 +371,25 @@ impl TokenUsage { #[allow(dead_code)] pub fn from_gemini_stream_chunks(chunks: &[Value]) -> Option { let mut total_input = 0u32; - let mut total_output = 0u32; + let mut total_tokens = 0u32; let mut total_cache_read = 0u32; let mut model: Option = None; for chunk in chunks { if let Some(usage) = chunk.get("usageMetadata") { + // 输入 tokens (通常在所有 chunk 中保持不变) total_input = usage .get("promptTokenCount") .and_then(|v| v.as_u64()) .unwrap_or(0) as u32; - total_output += usage - .get("candidatesTokenCount") + + // 总 tokens (包含输入 + 输出 + 思考) + total_tokens = usage + .get("totalTokenCount") .and_then(|v| v.as_u64()) .unwrap_or(0) as u32; + + // 缓存读取 tokens total_cache_read = usage .get("cachedContentTokenCount") .and_then(|v| v.as_u64()) @@ -328,6 +404,9 @@ impl TokenUsage { } } + // 输出 tokens = 总 tokens - 输入 tokens + let total_output = total_tokens.saturating_sub(total_input); + if total_input > 0 || total_output > 0 { Some(Self { input_tokens: total_input, @@ -466,15 +545,18 @@ mod tests { let response = json!({ "modelVersion": "gemini-3-pro-high", "usageMetadata": { - "promptTokenCount": 100, + "promptTokenCount": 8383, "candidatesTokenCount": 50, + "thoughtsTokenCount": 114, + "totalTokenCount": 8547, "cachedContentTokenCount": 20 } }); let usage = TokenUsage::from_gemini_response(&response).unwrap(); - assert_eq!(usage.input_tokens, 100); - assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.input_tokens, 8383); + // output_tokens = totalTokenCount - promptTokenCount = 8547 - 8383 = 164 + assert_eq!(usage.output_tokens, 164); assert_eq!(usage.cache_read_tokens, 20); assert_eq!(usage.cache_creation_tokens, 0); assert_eq!(usage.model, Some("gemini-3-pro-high".to_string())); @@ -486,19 +568,78 @@ mod tests { let response = json!({ "usageMetadata": { "promptTokenCount": 100, - "candidatesTokenCount": 50, + "totalTokenCount": 150, "cachedContentTokenCount": 20 } }); let usage = TokenUsage::from_gemini_response(&response).unwrap(); assert_eq!(usage.input_tokens, 100); + // output_tokens = totalTokenCount - promptTokenCount = 150 - 100 = 50 assert_eq!(usage.output_tokens, 50); assert_eq!(usage.cache_read_tokens, 20); assert_eq!(usage.cache_creation_tokens, 0); assert_eq!(usage.model, None); } + #[test] + fn test_gemini_response_with_thoughts() { + // 测试包含 thoughtsTokenCount 的实际响应 + // 这是用户报告的真实场景 + let response = json!({ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "", + "thoughtSignature": "EvcECvQE..." + } + ], + "role": "model" + }, + "finishReason": "STOP" + } + ], + "modelVersion": "gemini-3-pro-high", + "responseId": "yupTafqLDu-PjMcPhrOx4QQ", + "usageMetadata": { + "candidatesTokenCount": 50, + "promptTokenCount": 8383, + "thoughtsTokenCount": 114, + "totalTokenCount": 8547 + } + }); + + let usage = TokenUsage::from_gemini_response(&response).unwrap(); + assert_eq!(usage.input_tokens, 8383); + // output_tokens = totalTokenCount - promptTokenCount + // = 8547 - 8383 = 164 (包含 candidatesTokenCount 50 + thoughtsTokenCount 114) + assert_eq!(usage.output_tokens, 164); + assert_eq!(usage.cache_read_tokens, 0); + assert_eq!(usage.cache_creation_tokens, 0); + assert_eq!(usage.model, Some("gemini-3-pro-high".to_string())); + } + + #[test] + fn test_codex_response_parsing_cached_tokens_in_details() { + let response = json!({ + "usage": { + "input_tokens": 1000, + "output_tokens": 500, + "input_tokens_details": { + "cached_tokens": 300 + } + } + }); + + let usage = TokenUsage::from_codex_response(&response).unwrap(); + // 非调整模式:input_tokens 保持原值,但应记录缓存命中 + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.cache_read_tokens, 300); + } + #[test] fn test_codex_response_adjusted() { let response = json!({ @@ -534,6 +675,22 @@ mod tests { assert_eq!(usage.cache_read_tokens, 0); } + #[test] + fn test_codex_response_adjusted_cache_read_input_tokens() { + let response = json!({ + "usage": { + "input_tokens": 1000, + "output_tokens": 500, + "cache_read_input_tokens": 200 + } + }); + + let usage = TokenUsage::from_codex_response_adjusted(&response).unwrap(); + assert_eq!(usage.input_tokens, 800); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.cache_read_tokens, 200); + } + #[test] fn test_codex_response_adjusted_saturating_sub() { // 测试 cached_tokens > input_tokens 的边界情况 @@ -615,4 +772,110 @@ mod tests { assert_eq!(usage.cache_read_tokens, 50); assert_eq!(usage.model, Some("claude-sonnet-4-20250514".to_string())); } + + // ============================================================================ + // 智能 Codex 解析测试 + // ============================================================================ + + #[test] + fn test_codex_response_auto_openai_format() { + // OpenAI 格式 (prompt_tokens/completion_tokens) + let response = json!({ + "model": "gpt-4o", + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "prompt_tokens_details": { + "cached_tokens": 200 + } + } + }); + + let usage = TokenUsage::from_codex_response_auto(&response).unwrap(); + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.cache_read_tokens, 200); + assert_eq!(usage.model, Some("gpt-4o".to_string())); + } + + #[test] + fn test_codex_response_auto_codex_format() { + // Codex 格式 (input_tokens/output_tokens) + let response = json!({ + "model": "o3", + "usage": { + "input_tokens": 1000, + "output_tokens": 500, + "input_tokens_details": { + "cached_tokens": 300 + } + } + }); + + let usage = TokenUsage::from_codex_response_auto(&response).unwrap(); + // 记录原始 input_tokens,不调整 + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.cache_read_tokens, 300); + assert_eq!(usage.model, Some("o3".to_string())); + } + + #[test] + fn test_codex_stream_events_auto_codex_format() { + // Codex Responses API 流式格式 (response.completed 事件) + let events = vec![ + json!({ + "type": "response.created", + "response": { + "id": "resp_123" + } + }), + json!({ + "type": "response.completed", + "response": { + "model": "o3", + "usage": { + "input_tokens": 1000, + "output_tokens": 500, + "input_tokens_details": { + "cached_tokens": 200 + } + } + } + }), + ]; + + let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap(); + // 记录原始 input_tokens,不调整 + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.cache_read_tokens, 200); + assert_eq!(usage.model, Some("o3".to_string())); + } + + #[test] + fn test_codex_stream_events_auto_openai_format() { + // OpenAI Chat Completions 流式格式 (最后一个 chunk 包含 usage) + let events = vec![ + json!({ + "id": "chatcmpl-123", + "model": "gpt-4o", + "choices": [{"delta": {"content": "Hello"}}] + }), + json!({ + "id": "chatcmpl-123", + "model": "gpt-4o", + "choices": [{"delta": {}}], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50 + } + }), + ]; + + let usage = TokenUsage::from_codex_stream_events_auto(&events).unwrap(); + assert_eq!(usage.input_tokens, 100); + assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.model, Some("gpt-4o".to_string())); + } } diff --git a/src-tauri/src/services/provider/mod.rs b/src-tauri/src/services/provider/mod.rs index 599574414..91c3ab8e1 100644 --- a/src-tauri/src/services/provider/mod.rs +++ b/src-tauri/src/services/provider/mod.rs @@ -217,9 +217,12 @@ impl ProviderService { .flatten() .is_some(); let is_proxy_running = futures::executor::block_on(state.proxy_service.is_running()); + let live_taken_over = state + .proxy_service + .detect_takeover_in_live_config_for_app(&app_type); // Hot-switch only when BOTH: this app is taken over AND proxy server is actually running - let should_hot_switch = is_app_taken_over && is_proxy_running; + let should_hot_switch = (is_app_taken_over || live_taken_over) && is_proxy_running; if should_hot_switch { // Proxy takeover mode: hot-switch only, don't write Live config @@ -736,15 +739,15 @@ impl ProviderService { // 删除生成的子供应商 if let Some(p) = provider { if p.apps.claude { - let claude_id = format!("universal-claude-{}", id); + let claude_id = format!("universal-claude-{id}"); let _ = state.db.delete_provider("claude", &claude_id); } if p.apps.codex { - let codex_id = format!("universal-codex-{}", id); + let codex_id = format!("universal-codex-{id}"); let _ = state.db.delete_provider("codex", &codex_id); } if p.apps.gemini { - let gemini_id = format!("universal-gemini-{}", id); + let gemini_id = format!("universal-gemini-{id}"); let _ = state.db.delete_provider("gemini", &gemini_id); } } @@ -757,7 +760,7 @@ impl ProviderService { let provider = state .db .get_universal_provider(id)? - .ok_or_else(|| AppError::Message(format!("统一供应商 {} 不存在", id)))?; + .ok_or_else(|| AppError::Message(format!("统一供应商 {id} 不存在")))?; // 同步到 Claude if let Some(mut claude_provider) = provider.to_claude_provider() { @@ -770,7 +773,7 @@ impl ProviderService { state.db.save_provider("claude", &claude_provider)?; } else { // 如果禁用了 Claude,删除对应的子供应商 - let claude_id = format!("universal-claude-{}", id); + let claude_id = format!("universal-claude-{id}"); let _ = state.db.delete_provider("claude", &claude_id); } @@ -784,7 +787,7 @@ impl ProviderService { } state.db.save_provider("codex", &codex_provider)?; } else { - let codex_id = format!("universal-codex-{}", id); + let codex_id = format!("universal-codex-{id}"); let _ = state.db.delete_provider("codex", &codex_id); } @@ -798,7 +801,7 @@ impl ProviderService { } state.db.save_provider("gemini", &gemini_provider)?; } else { - let gemini_id = format!("universal-gemini-{}", id); + let gemini_id = format!("universal-gemini-{id}"); let _ = state.db.delete_provider("gemini", &gemini_id); } diff --git a/src-tauri/src/services/proxy.rs b/src-tauri/src/services/proxy.rs index 29b5b3157..4865f7ede 100644 --- a/src-tauri/src/services/proxy.rs +++ b/src-tauri/src/services/proxy.rs @@ -193,7 +193,7 @@ impl ProxyService { self.start().await?; } - // 2) 已接管则直接返回(幂等) + // 2) 已接管则直接返回(幂等);但如果缺少备份或占位符残留,需要重建接管 let current_config = self .db .get_proxy_config_for_app(app_type_str) @@ -201,7 +201,22 @@ impl ProxyService { .map_err(|e| format!("获取 {app_type_str} 配置失败: {e}"))?; if current_config.enabled { - return Ok(()); + let has_backup = match self.db.get_live_backup(app_type_str).await { + Ok(v) => v.is_some(), + Err(e) => { + log::warn!("读取 {app_type_str} 备份失败(将继续重建接管): {e}"); + false + } + }; + let live_taken_over = self.detect_takeover_in_live_config_for_app(&app); + + if has_backup || live_taken_over { + return Ok(()); + } + + log::warn!( + "{app_type_str} 标记为已接管,但缺少备份或占位符,正在重新接管并补齐备份" + ); } // 3) 备份 Live 配置(严格:目标 app 不存在则报错) @@ -1063,7 +1078,7 @@ impl ProxyService { } } - fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool { + pub fn detect_takeover_in_live_config_for_app(&self, app_type: &AppType) -> bool { match app_type { AppType::Claude => match self.read_claude_live() { Ok(config) => Self::is_claude_live_taken_over(&config), @@ -1257,10 +1272,8 @@ impl ProxyService { /// 检查是否处于 Live 接管模式 pub async fn is_takeover_active(&self) -> Result { - self.db - .is_live_takeover_active() - .await - .map_err(|e| format!("检查接管状态失败: {e}")) + let status = self.get_takeover_status().await?; + Ok(status.claude || status.codex || status.gemini) } /// 从异常退出中恢复(启动时调用) diff --git a/src-tauri/src/services/usage_stats.rs b/src-tauri/src/services/usage_stats.rs index 81cc30257..741648751 100644 --- a/src-tauri/src/services/usage_stats.rs +++ b/src-tauri/src/services/usage_stats.rs @@ -4,7 +4,7 @@ use crate::database::{lock_conn, Database}; use crate::error::AppError; -use chrono::{Duration, Local, TimeZone}; +use chrono::{Local, TimeZone}; use rusqlite::{params, Connection, OptionalExtension}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -181,145 +181,114 @@ impl Database { Ok(result) } - /// 获取每日趋势 - pub fn get_daily_trends(&self, days: u32) -> Result, AppError> { + /// 获取每日趋势(滑动窗口,<=24h 按小时,>24h 按天,窗口与汇总一致) + pub fn get_daily_trends( + &self, + start_date: Option, + end_date: Option, + ) -> Result, AppError> { let conn = lock_conn!(self.conn); - if days <= 1 { - let today = Local::now().date_naive(); - let start_of_today = today.and_hms_opt(0, 0, 0).unwrap(); - // 使用 earliest() 处理 DST 切换时的歧义时间,fallback 到当前时间减一天 - let start_ts = Local - .from_local_datetime(&start_of_today) - .earliest() - .unwrap_or_else(|| Local::now() - Duration::days(1)) - .timestamp(); + let end_ts = end_date.unwrap_or_else(|| Local::now().timestamp()); + let mut start_ts = start_date.unwrap_or_else(|| end_ts - 24 * 60 * 60); - let sql = "SELECT - strftime('%Y-%m-%dT%H:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket, - COUNT(*) as request_count, - COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost, - COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens - FROM proxy_request_logs - WHERE created_at >= ? - GROUP BY bucket - ORDER BY bucket ASC"; - - let mut stmt = conn.prepare(sql)?; - let rows = stmt.query_map([start_ts], |row| { - Ok(DailyStats { - date: row.get(0)?, - request_count: row.get::<_, i64>(1)? as u64, - total_cost: format!("{:.6}", row.get::<_, f64>(2)?), - total_tokens: row.get::<_, i64>(3)? as u64, - total_input_tokens: row.get::<_, i64>(4)? as u64, - total_output_tokens: row.get::<_, i64>(5)? as u64, - total_cache_creation_tokens: row.get::<_, i64>(6)? as u64, - total_cache_read_tokens: row.get::<_, i64>(7)? as u64, - }) - })?; - - let mut buckets: HashMap = HashMap::new(); - for row in rows { - let stat = row?; - buckets.insert(stat.date.clone(), stat); - } - - let mut stats = Vec::new(); - for hour in 0..24 { - let bucket = today - .and_hms_opt(hour, 0, 0) - .unwrap() - .format("%Y-%m-%dT%H:00:00") - .to_string(); - - if let Some(stat) = buckets.remove(&bucket) { - stats.push(stat); - } else { - stats.push(DailyStats { - date: bucket, - request_count: 0, - total_cost: "0.000000".to_string(), - total_tokens: 0, - total_input_tokens: 0, - total_output_tokens: 0, - total_cache_creation_tokens: 0, - total_cache_read_tokens: 0, - }); - } - } - Ok(stats) - } else { - let today = Local::now().date_naive(); - let start_day = today - Duration::days((days.saturating_sub(1)) as i64); - let start_of_window = start_day.and_hms_opt(0, 0, 0).unwrap(); - // 使用 earliest() 处理 DST 切换时的歧义时间,fallback 到当前时间减 days 天 - let start_ts = Local - .from_local_datetime(&start_of_window) - .earliest() - .unwrap_or_else(|| Local::now() - Duration::days(days as i64)) - .timestamp(); - - let sql = "SELECT - strftime('%Y-%m-%dT00:00:00', datetime(created_at, 'unixepoch', 'localtime')) as bucket, - COUNT(*) as request_count, - COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost, - COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens - FROM proxy_request_logs - WHERE created_at >= ? - GROUP BY bucket - ORDER BY bucket ASC"; - - let mut stmt = conn.prepare(sql)?; - let rows = stmt.query_map([start_ts], |row| { - Ok(DailyStats { - date: row.get(0)?, - request_count: row.get::<_, i64>(1)? as u64, - total_cost: format!("{:.6}", row.get::<_, f64>(2)?), - total_tokens: row.get::<_, i64>(3)? as u64, - total_input_tokens: row.get::<_, i64>(4)? as u64, - total_output_tokens: row.get::<_, i64>(5)? as u64, - total_cache_creation_tokens: row.get::<_, i64>(6)? as u64, - total_cache_read_tokens: row.get::<_, i64>(7)? as u64, - }) - })?; - - let mut map = HashMap::new(); - for row in rows { - let stat = row?; - map.insert(stat.date.clone(), stat); - } - - let mut stats = Vec::new(); - - for i in 0..days { - let day = start_day + Duration::days(i as i64); - let key = day.format("%Y-%m-%dT00:00:00").to_string(); - if let Some(stat) = map.remove(&key) { - stats.push(stat); - } else { - stats.push(DailyStats { - date: key, - request_count: 0, - total_cost: "0.000000".to_string(), - total_tokens: 0, - total_input_tokens: 0, - total_output_tokens: 0, - total_cache_creation_tokens: 0, - total_cache_read_tokens: 0, - }); - } - } - Ok(stats) + if start_ts >= end_ts { + start_ts = end_ts - 24 * 60 * 60; } + + let duration = end_ts - start_ts; + let bucket_seconds: i64 = if duration <= 24 * 60 * 60 { + 60 * 60 + } else { + 24 * 60 * 60 + }; + let mut bucket_count: i64 = if duration <= 0 { + 1 + } else { + ((duration as f64) / bucket_seconds as f64).ceil() as i64 + }; + + // 固定 24 小时窗口为 24 个小时桶,避免浮点误差 + if bucket_seconds == 60 * 60 { + bucket_count = 24; + } + + if bucket_count < 1 { + bucket_count = 1; + } + + let sql = " + SELECT + CAST((created_at - ?1) / ?3 AS INTEGER) as bucket_idx, + COUNT(*) as request_count, + COALESCE(SUM(CAST(total_cost_usd AS REAL)), 0) as total_cost, + COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens + FROM proxy_request_logs + WHERE created_at >= ?1 AND created_at <= ?2 + GROUP BY bucket_idx + ORDER BY bucket_idx ASC"; + + let mut stmt = conn.prepare(sql)?; + let rows = stmt.query_map(params![start_ts, end_ts, bucket_seconds], |row| { + Ok(( + row.get::<_, i64>(0)?, + DailyStats { + date: String::new(), + request_count: row.get::<_, i64>(1)? as u64, + total_cost: format!("{:.6}", row.get::<_, f64>(2)?), + total_tokens: row.get::<_, i64>(3)? as u64, + total_input_tokens: row.get::<_, i64>(4)? as u64, + total_output_tokens: row.get::<_, i64>(5)? as u64, + total_cache_creation_tokens: row.get::<_, i64>(6)? as u64, + total_cache_read_tokens: row.get::<_, i64>(7)? as u64, + }, + )) + })?; + + let mut map: HashMap = HashMap::new(); + for row in rows { + let (mut bucket_idx, stat) = row?; + if bucket_idx < 0 { + continue; + } + if bucket_idx >= bucket_count { + bucket_idx = bucket_count - 1; + } + map.insert(bucket_idx, stat); + } + + let mut stats = Vec::with_capacity(bucket_count as usize); + for i in 0..bucket_count { + let bucket_start_ts = start_ts + i * bucket_seconds; + let bucket_start = Local + .timestamp_opt(bucket_start_ts, 0) + .single() + .unwrap_or_else(Local::now); + + let date = bucket_start.format("%Y-%m-%dT%H:%M:%S").to_string(); + + if let Some(mut stat) = map.remove(&i) { + stat.date = date; + stats.push(stat); + } else { + stats.push(DailyStats { + date, + request_count: 0, + total_cost: "0.000000".to_string(), + total_tokens: 0, + total_input_tokens: 0, + total_output_tokens: 0, + total_cache_creation_tokens: 0, + total_cache_read_tokens: 0, + }); + } + } + + Ok(stats) } /// 获取 Provider 统计 @@ -829,89 +798,46 @@ impl Database { } } -/// 标准化模型名称:去除供应商前缀并将点号替换为短横线 -/// 例如:anthropic/claude-haiku-4.5 → claude-haiku-4-5 -fn normalize_model_id(model_id: &str) -> String { - // 1. 去除供应商前缀(如 anthropic/、openai/) - let stripped = if let Some(pos) = model_id.find('/') { - &model_id[pos + 1..] - } else { - model_id - }; - // 2. 将点号替换为短横线(如 claude-haiku-4.5 → claude-haiku-4-5) - stripped.replace('.', "-") -} - pub(crate) fn find_model_pricing_row( conn: &Connection, model_id: &str, ) -> Result, AppError> { - // 0. 标准化模型名称(去除前缀 + 点号转短横线) - // 例如:anthropic/claude-haiku-4.5 → claude-haiku-4-5 - let normalized = normalize_model_id(model_id); + // 1) 去除供应商前缀(/ 之前)与冒号后缀(: 之后),例如 moonshotai/kimi-k2-0905:exa → kimi-k2-0905 + let without_prefix = model_id + .rsplit_once('/') + .map(|(_, rest)| rest) + .unwrap_or(model_id); + let cleaned = without_prefix + .split(':') + .next() + .map(str::trim) + .unwrap_or(without_prefix); - // 1. 精确匹配(先尝试原始名称,再尝试标准化后的名称) - for id in [model_id, normalized.as_str()] { - let exact = conn - .query_row( - "SELECT input_cost_per_million, output_cost_per_million, - cache_read_cost_per_million, cache_creation_cost_per_million - FROM model_pricing - WHERE model_id = ?1", - [id], - |row| { - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - row.get::<_, String>(3)?, - )) - }, - ) - .optional() - .map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?; + // 2) 精确匹配清洗后的名称 + let exact = conn + .query_row( + "SELECT input_cost_per_million, output_cost_per_million, + cache_read_cost_per_million, cache_creation_cost_per_million + FROM model_pricing + WHERE model_id = ?1", + [cleaned], + |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + )) + }, + ) + .optional() + .map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?; - if exact.is_some() { - if id != model_id { - log::info!("模型 {model_id} 标准化后精确匹配到: {id}"); - } - return Ok(exact); - } + if exact.is_none() { + log::warn!("模型 {model_id}(清洗后: {cleaned})未找到定价信息,成本将记录为 0"); } - // 2. 逐步删除后缀匹配(claude-haiku-4-5-20250929 → claude-haiku-4-5 → claude-haiku-4 → claude-haiku) - // 使用标准化后的名称进行后缀匹配 - let mut current = normalized; - while let Some(pos) = current.rfind('-') { - current = current[..pos].to_string(); - - let result = conn - .query_row( - "SELECT input_cost_per_million, output_cost_per_million, - cache_read_cost_per_million, cache_creation_cost_per_million - FROM model_pricing - WHERE model_id = ?1", - [¤t], - |row| { - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - row.get::<_, String>(3)?, - )) - }, - ) - .optional() - .map_err(|e| AppError::Database(format!("查询模型定价失败: {e}")))?; - - if result.is_some() { - log::info!("模型 {model_id} 通过删除后缀匹配到: {current}"); - return Ok(result); - } - } - - log::warn!("模型 {model_id} 未找到定价信息,成本将记录为 0"); - Ok(None) + Ok(exact) } #[cfg(test)] @@ -991,54 +917,39 @@ mod tests { let db = Database::memory()?; let conn = lock_conn!(db.conn); - // 测试精确匹配 - let result = find_model_pricing_row(&conn, "claude-sonnet-4-5")?; - assert!(result.is_some(), "应该能精确匹配 claude-sonnet-4-5"); + // 准备额外定价数据,覆盖前缀/后缀清洗场景 + conn.execute( + "INSERT OR REPLACE INTO model_pricing ( + model_id, display_name, input_cost_per_million, output_cost_per_million, + cache_read_cost_per_million, cache_creation_cost_per_million + ) VALUES (?, ?, ?, ?, ?, ?)", + params![ + "claude-haiku-4.5", + "Claude Haiku 4.5", + "1.0", + "2.0", + "0.0", + "0.0" + ], + )?; - // 测试带供应商前缀的模型名称(anthropic/claude-haiku-4.5 → claude-haiku-4-5) - let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?; - assert!( - result.is_some(), - "应该能匹配带前缀的模型 anthropic/claude-haiku-4.5" - ); - - // 测试带供应商前缀 + 点号的模型名称 - let result = find_model_pricing_row(&conn, "anthropic/claude-sonnet-4.5")?; - assert!( - result.is_some(), - "应该能匹配带前缀的模型 anthropic/claude-sonnet-4.5" - ); - - // 测试逐步删除后缀匹配 - 日期后缀 - let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20241022")?; - assert!( - result.is_some(), - "应该能通过删除后缀匹配 claude-sonnet-4-5-20241022" - ); - - // 测试逐步删除后缀匹配 - 多个后缀 - let result = find_model_pricing_row(&conn, "claude-haiku-4-5-20240229-preview")?; - assert!( - result.is_some(), - "应该能通过删除后缀匹配 claude-haiku-4-5-20240229-preview" - ); - - // 测试 GPT 模型 - let result = find_model_pricing_row(&conn, "gpt-5-2024-11-20")?; - assert!(result.is_some(), "应该能通过删除后缀匹配 gpt-5-2024-11-20"); - - // 测试 Gemini 模型 - let result = find_model_pricing_row(&conn, "gemini-2.5-flash-exp")?; - assert!( - result.is_some(), - "应该能通过删除后缀匹配 gemini-2.5-flash-exp" - ); - - // 测试 claude-sonnet-4-5 命名格式 + // 测试精确匹配(seed_model_pricing 已预置 claude-sonnet-4-5-20250929) let result = find_model_pricing_row(&conn, "claude-sonnet-4-5-20250929")?; assert!( result.is_some(), - "应该能通过删除后缀匹配 claude-sonnet-4-5-20250929" + "应该能精确匹配 claude-sonnet-4-5-20250929" + ); + + // 清洗:去除前缀和冒号后缀 + let result = find_model_pricing_row(&conn, "anthropic/claude-haiku-4.5")?; + assert!( + result.is_some(), + "带前缀的模型 anthropic/claude-haiku-4.5 应能匹配到 claude-haiku-4.5" + ); + let result = find_model_pricing_row(&conn, "moonshotai/kimi-k2-0905:exa")?; + assert!( + result.is_some(), + "带前缀+冒号后缀的模型应清洗后匹配到 kimi-k2-0905" ); // 测试不存在的模型 diff --git a/src/components/proxy/AutoFailoverConfigPanel.tsx b/src/components/proxy/AutoFailoverConfigPanel.tsx index 9daeaed0d..30f55e136 100644 --- a/src/components/proxy/AutoFailoverConfigPanel.tsx +++ b/src/components/proxy/AutoFailoverConfigPanel.tsx @@ -142,12 +142,13 @@ export function AutoFailoverConfigPanel({ min="0" max="10" value={formData.maxRetries} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - maxRetries: parseInt(e.target.value) || 3, - }) - } + maxRetries: isNaN(val) ? 0 : val, + }); + }} disabled={isDisabled} />

@@ -168,12 +169,13 @@ export function AutoFailoverConfigPanel({ min="1" max="20" value={formData.circuitFailureThreshold} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - circuitFailureThreshold: parseInt(e.target.value) || 5, - }) - } + circuitFailureThreshold: isNaN(val) ? 1 : Math.max(1, val), + }); + }} disabled={isDisabled} />

@@ -206,12 +208,13 @@ export function AutoFailoverConfigPanel({ min="0" max="180" value={formData.streamingFirstByteTimeout} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - streamingFirstByteTimeout: parseInt(e.target.value) || 30, - }) - } + streamingFirstByteTimeout: isNaN(val) ? 0 : val, + }); + }} disabled={isDisabled} />

@@ -232,12 +235,13 @@ export function AutoFailoverConfigPanel({ min="0" max="600" value={formData.streamingIdleTimeout} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - streamingIdleTimeout: parseInt(e.target.value) || 60, - }) - } + streamingIdleTimeout: isNaN(val) ? 0 : val, + }); + }} disabled={isDisabled} />

@@ -258,12 +262,13 @@ export function AutoFailoverConfigPanel({ min="0" max="1800" value={formData.nonStreamingTimeout} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - nonStreamingTimeout: parseInt(e.target.value) || 300, - }) - } + nonStreamingTimeout: isNaN(val) ? 0 : val, + }); + }} disabled={isDisabled} />

@@ -293,12 +298,13 @@ export function AutoFailoverConfigPanel({ min="1" max="10" value={formData.circuitSuccessThreshold} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - circuitSuccessThreshold: parseInt(e.target.value) || 2, - }) - } + circuitSuccessThreshold: isNaN(val) ? 1 : Math.max(1, val), + }); + }} disabled={isDisabled} />

@@ -319,12 +325,13 @@ export function AutoFailoverConfigPanel({ min="10" max="300" value={formData.circuitTimeoutSeconds} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - circuitTimeoutSeconds: parseInt(e.target.value) || 60, - }) - } + circuitTimeoutSeconds: isNaN(val) ? 10 : Math.max(10, val), + }); + }} disabled={isDisabled} />

@@ -346,13 +353,13 @@ export function AutoFailoverConfigPanel({ max="100" step="5" value={Math.round(formData.circuitErrorRateThreshold * 100)} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - circuitErrorRateThreshold: - (parseInt(e.target.value) || 50) / 100, - }) - } + circuitErrorRateThreshold: isNaN(val) ? 0.5 : val / 100, + }); + }} disabled={isDisabled} />

@@ -373,12 +380,13 @@ export function AutoFailoverConfigPanel({ min="5" max="100" value={formData.circuitMinRequests} - onChange={(e) => + onChange={(e) => { + const val = parseInt(e.target.value); setFormData({ ...formData, - circuitMinRequests: parseInt(e.target.value) || 10, - }) - } + circuitMinRequests: isNaN(val) ? 5 : Math.max(5, val), + }); + }} disabled={isDisabled} />

diff --git a/src/components/usage/RequestLogTable.tsx b/src/components/usage/RequestLogTable.tsx index 6f6436833..92b7097c1 100644 --- a/src/components/usage/RequestLogTable.tsx +++ b/src/components/usage/RequestLogTable.tsx @@ -62,6 +62,28 @@ export function RequestLogTable() { }); }; + // 将 Unix 时间戳转换为本地时间的 datetime-local 格式 + const timestampToLocalDatetime = (timestamp: number): string => { + const date = new Date(timestamp * 1000); + const year = date.getFullYear(); + const month = String(date.getMonth() + 1).padStart(2, "0"); + const day = String(date.getDate()).padStart(2, "0"); + const hours = String(date.getHours()).padStart(2, "0"); + const minutes = String(date.getMinutes()).padStart(2, "0"); + return `${year}-${month}-${day}T${hours}:${minutes}`; + }; + + // 将 datetime-local 格式转换为 Unix 时间戳 + const localDatetimeToTimestamp = (datetime: string): number | undefined => { + if (!datetime) return undefined; + // 验证格式是否完整 (YYYY-MM-DDTHH:mm) + if (datetime.length < 16) return undefined; + const timestamp = new Date(datetime).getTime(); + // 验证是否为有效日期 + if (isNaN(timestamp)) return undefined; + return Math.floor(timestamp / 1000); + }; + const dateLocale = i18n.language === "zh" ? "zh-CN" @@ -153,19 +175,16 @@ export function RequestLogTable() { className="h-8 w-[200px] bg-background" value={ tempFilters.startDate - ? new Date(tempFilters.startDate * 1000) - .toISOString() - .slice(0, 16) + ? timestampToLocalDatetime(tempFilters.startDate) : "" } - onChange={(e) => + onChange={(e) => { + const timestamp = localDatetimeToTimestamp(e.target.value); setTempFilters({ ...tempFilters, - startDate: e.target.value - ? Math.floor(new Date(e.target.value).getTime() / 1000) - : undefined, - }) - } + startDate: timestamp, + }); + }} /> - + onChange={(e) => { + const timestamp = localDatetimeToTimestamp(e.target.value); setTempFilters({ ...tempFilters, - endDate: e.target.value - ? Math.floor(new Date(e.target.value).getTime() / 1000) - : undefined, - }) - } + endDate: timestamp, + }); + }} /> diff --git a/src/components/usage/UsageSummaryCards.tsx b/src/components/usage/UsageSummaryCards.tsx index 5ccae8651..f3a6176ba 100644 --- a/src/components/usage/UsageSummaryCards.tsx +++ b/src/components/usage/UsageSummaryCards.tsx @@ -12,13 +12,7 @@ interface UsageSummaryCardsProps { export function UsageSummaryCards({ days }: UsageSummaryCardsProps) { const { t } = useTranslation(); - const { startDate, endDate } = useMemo(() => { - const end = Math.floor(Date.now() / 1000); - const start = end - days * 24 * 60 * 60; - return { startDate: start, endDate: end }; - }, [days]); - - const { data: summary, isLoading } = useUsageSummary(startDate, endDate); + const { data: summary, isLoading } = useUsageSummary(days); const stats = useMemo(() => { const totalRequests = summary?.totalRequests ?? 0; diff --git a/src/components/usage/UsageTrendChart.tsx b/src/components/usage/UsageTrendChart.tsx index 3fe75d9ff..376014e02 100644 --- a/src/components/usage/UsageTrendChart.tsx +++ b/src/components/usage/UsageTrendChart.tsx @@ -41,7 +41,12 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) { return { rawDate: stat.date, label: isToday - ? pointDate.toLocaleTimeString(dateLocale, { hour: "2-digit" }) + ? pointDate.toLocaleString(dateLocale, { + month: "2-digit", + day: "2-digit", + hour: "2-digit", + minute: "2-digit", + }) : pointDate.toLocaleDateString(dateLocale, { month: "2-digit", day: "2-digit", @@ -49,28 +54,13 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) { hour: pointDate.getHours(), inputTokens: stat.totalInputTokens, outputTokens: stat.totalOutputTokens, + cacheCreationTokens: stat.totalCacheCreationTokens, + cacheReadTokens: stat.totalCacheReadTokens, cost: parseFloat(stat.totalCost), }; }) || []; - const hourlyData = (() => { - if (!isToday) return chartData; - const map = new Map(); - chartData.forEach((point) => { - map.set(point.hour ?? 0, point); - }); - return Array.from({ length: 24 }, (_, hour) => { - const bucket = map.get(hour); - return { - label: `${hour.toString().padStart(2, "0")}:00`, - inputTokens: bucket?.inputTokens ?? 0, - outputTokens: bucket?.outputTokens ?? 0, - cost: bucket?.cost ?? 0, - }; - }); - })(); - - const displayData = isToday ? hourlyData : chartData; + const displayData = chartData; const CustomTooltip = ({ active, payload, label }: any) => { if (active && payload && payload.length) { @@ -131,6 +121,20 @@ export function UsageTrendChart({ days }: UsageTrendChartProps) { + + + + + + + + + + => { - return invoke("get_usage_trends", { days }); + getUsageTrends: async ( + startDate?: number, + endDate?: number, + ): Promise => { + return invoke("get_usage_trends", { startDate, endDate }); }, getProviderStats: async (): Promise => { diff --git a/src/lib/query/usage.ts b/src/lib/query/usage.ts index 176fc3769..bf220c549 100644 --- a/src/lib/query/usage.ts +++ b/src/lib/query/usage.ts @@ -5,8 +5,7 @@ import type { LogFilters } from "@/types/usage"; // Query keys export const usageKeys = { all: ["usage"] as const, - summary: (startDate?: number, endDate?: number) => - [...usageKeys.all, "summary", startDate, endDate] as const, + summary: (days: number) => [...usageKeys.all, "summary", days] as const, trends: (days: number) => [...usageKeys.all, "trends", days] as const, providerStats: () => [...usageKeys.all, "provider-stats"] as const, modelStats: () => [...usageKeys.all, "model-stats"] as const, @@ -19,18 +18,34 @@ export const usageKeys = { [...usageKeys.all, "limits", providerId, appType] as const, }; +const getWindow = (days: number) => { + const endDate = Math.floor(Date.now() / 1000); + const startDate = endDate - days * 24 * 60 * 60; + return { startDate, endDate }; +}; + // Hooks -export function useUsageSummary(startDate?: number, endDate?: number) { +export function useUsageSummary(days: number) { return useQuery({ - queryKey: usageKeys.summary(startDate, endDate), - queryFn: () => usageApi.getUsageSummary(startDate, endDate), + queryKey: usageKeys.summary(days), + queryFn: () => { + const { startDate, endDate } = getWindow(days); + return usageApi.getUsageSummary(startDate, endDate); + }, + refetchInterval: 30000, // 每30秒自动刷新 + refetchIntervalInBackground: false, // 后台不刷新 }); } export function useUsageTrends(days: number) { return useQuery({ queryKey: usageKeys.trends(days), - queryFn: () => usageApi.getUsageTrends(days), + queryFn: () => { + const { startDate, endDate } = getWindow(days); + return usageApi.getUsageTrends(startDate, endDate); + }, + refetchInterval: 30000, // 每30秒自动刷新 + refetchIntervalInBackground: false, }); } @@ -38,6 +53,8 @@ export function useProviderStats() { return useQuery({ queryKey: usageKeys.providerStats(), queryFn: usageApi.getProviderStats, + refetchInterval: 30000, // 每30秒自动刷新 + refetchIntervalInBackground: false, }); } @@ -45,6 +62,8 @@ export function useModelStats() { return useQuery({ queryKey: usageKeys.modelStats(), queryFn: usageApi.getModelStats, + refetchInterval: 30000, // 每30秒自动刷新 + refetchIntervalInBackground: false, }); } @@ -56,6 +75,8 @@ export function useRequestLogs( return useQuery({ queryKey: usageKeys.logs(filters, page, pageSize), queryFn: () => usageApi.getRequestLogs(filters, page, pageSize), + refetchInterval: 30000, // 每30秒自动刷新 + refetchIntervalInBackground: false, }); }