From c5d3732b9fc70dae4e71cfe976b0e389d2e37cb8 Mon Sep 17 00:00:00 2001 From: YoVinchen Date: Sun, 11 Jan 2026 16:45:05 +0800 Subject: [PATCH] fix(proxy): harden error handling and input validation - Handle RwLock poisoning in settings.rs with unwrap_or_else - Add fallback for dirs::home_dir() in config modules - Normalize localhost to 127.0.0.1 in ProxyPanel - Format IPv6 addresses with brackets for valid URLs - Strict port validation with pure digit regex - Treat NaN as validation failure in config panels - Log warning on cost_multiplier parse failure - Align timeoutSeconds range to [0, 300] across all panels --- src-tauri/src/codex_config.rs | 10 ++++- src-tauri/src/config.rs | 16 ++++--- src-tauri/src/gemini_config.rs | 12 ++++-- src-tauri/src/proxy/handlers.rs | 7 +++- src-tauri/src/proxy/response_processor.rs | 7 +++- src-tauri/src/settings.rs | 18 ++++++-- .../proxy/AutoFailoverConfigPanel.tsx | 42 +++++++++---------- .../proxy/CircuitBreakerConfigPanel.tsx | 28 +++++++------ src/components/proxy/ProxyPanel.tsx | 37 ++++++++++++---- 9 files changed, 118 insertions(+), 59 deletions(-) diff --git a/src-tauri/src/codex_config.rs b/src-tauri/src/codex_config.rs index 1d74b433..040e5f0c 100644 --- a/src-tauri/src/codex_config.rs +++ b/src-tauri/src/codex_config.rs @@ -9,13 +9,21 @@ use serde_json::Value; use std::fs; use std::path::Path; +/// 获取用户主目录,带回退和日志 +fn get_home_dir() -> PathBuf { + dirs::home_dir().unwrap_or_else(|| { + log::warn!("无法获取用户主目录,回退到当前目录"); + PathBuf::from(".") + }) +} + /// 获取 Codex 配置目录路径 pub fn get_codex_config_dir() -> PathBuf { if let Some(custom) = crate::settings::get_codex_override_dir() { return custom; } - dirs::home_dir().expect("无法获取用户主目录").join(".codex") + get_home_dir().join(".codex") } /// 获取 Codex auth.json 路径 diff --git a/src-tauri/src/config.rs b/src-tauri/src/config.rs index 0b3cc9af..6eae3a2c 100644 --- a/src-tauri/src/config.rs +++ b/src-tauri/src/config.rs @@ -5,22 +5,26 @@ use std::path::{Path, PathBuf}; use crate::error::AppError; +/// 获取用户主目录,带回退和日志 +fn get_home_dir() -> PathBuf { + dirs::home_dir().unwrap_or_else(|| { + log::warn!("无法获取用户主目录,回退到当前目录"); + PathBuf::from(".") + }) +} + /// 获取 Claude Code 配置目录路径 pub fn get_claude_config_dir() -> PathBuf { if let Some(custom) = crate::settings::get_claude_override_dir() { return custom; } - dirs::home_dir() - .expect("无法获取用户主目录") - .join(".claude") + get_home_dir().join(".claude") } /// 默认 Claude MCP 配置文件路径 (~/.claude.json) pub fn get_default_claude_mcp_path() -> PathBuf { - dirs::home_dir() - .expect("无法获取用户主目录") - .join(".claude.json") + get_home_dir().join(".claude.json") } fn derive_mcp_path_from_override(dir: &Path) -> Option { diff --git a/src-tauri/src/gemini_config.rs b/src-tauri/src/gemini_config.rs index 8e453bab..6ea47666 100644 --- a/src-tauri/src/gemini_config.rs +++ b/src-tauri/src/gemini_config.rs @@ -5,15 +5,21 @@ use std::collections::HashMap; use std::fs; use std::path::PathBuf; +/// 获取用户主目录,带回退和日志 +fn get_home_dir() -> PathBuf { + dirs::home_dir().unwrap_or_else(|| { + log::warn!("无法获取用户主目录,回退到当前目录"); + PathBuf::from(".") + }) +} + /// 获取 Gemini 配置目录路径(支持设置覆盖) pub fn get_gemini_dir() -> PathBuf { if let Some(custom) = crate::settings::get_gemini_override_dir() { return custom; } - dirs::home_dir() - .expect("无法获取用户主目录") - .join(".gemini") + get_home_dir().join(".gemini") } /// 获取 Gemini .env 文件路径 diff --git a/src-tauri/src/proxy/handlers.rs b/src-tauri/src/proxy/handlers.rs index 286113dc..664b849b 100644 --- a/src-tauri/src/proxy/handlers.rs +++ b/src-tauri/src/proxy/handlers.rs @@ -456,7 +456,12 @@ async fn log_usage( Ok(Some(p)) => { if let Some(meta) = p.meta { if let Some(cm) = meta.cost_multiplier { - Decimal::from_str(&cm).unwrap_or(Decimal::from(1)) + Decimal::from_str(&cm).unwrap_or_else(|e| { + log::warn!( + "cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}" + ); + Decimal::from(1) + }) } else { Decimal::from(1) } diff --git a/src-tauri/src/proxy/response_processor.rs b/src-tauri/src/proxy/response_processor.rs index b393211f..f6a2eeb8 100644 --- a/src-tauri/src/proxy/response_processor.rs +++ b/src-tauri/src/proxy/response_processor.rs @@ -372,7 +372,12 @@ async fn log_usage_internal( Ok(Some(p)) => { if let Some(meta) = p.meta { if let Some(cm) = meta.cost_multiplier { - Decimal::from_str(&cm).unwrap_or(Decimal::from(1)) + Decimal::from_str(&cm).unwrap_or_else(|e| { + log::warn!( + "cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}" + ); + Decimal::from(1) + }) } else { Decimal::from(1) } diff --git a/src-tauri/src/settings.rs b/src-tauri/src/settings.rs index 3eb49547..21e15d52 100644 --- a/src-tauri/src/settings.rs +++ b/src-tauri/src/settings.rs @@ -194,14 +194,23 @@ fn resolve_override_path(raw: &str) -> PathBuf { } pub fn get_settings() -> AppSettings { - settings_store().read().expect("读取设置锁失败").clone() + settings_store() + .read() + .unwrap_or_else(|e| { + log::warn!("设置锁已毒化,使用恢复值: {e}"); + e.into_inner() + }) + .clone() } pub fn update_settings(mut new_settings: AppSettings) -> Result<(), AppError> { new_settings.normalize_paths(); save_settings_file(&new_settings)?; - let mut guard = settings_store().write().expect("写入设置锁失败"); + let mut guard = settings_store().write().unwrap_or_else(|e| { + log::warn!("设置锁已毒化,使用恢复值: {e}"); + e.into_inner() + }); *guard = new_settings; Ok(()) } @@ -210,7 +219,10 @@ pub fn update_settings(mut new_settings: AppSettings) -> Result<(), AppError> { /// 用于导入配置等场景,确保内存缓存与文件同步 pub fn reload_settings() -> Result<(), AppError> { let fresh_settings = AppSettings::load_from_file(); - let mut guard = settings_store().write().expect("写入设置锁失败"); + let mut guard = settings_store().write().unwrap_or_else(|e| { + log::warn!("设置锁已毒化,使用恢复值: {e}"); + e.into_inner() + }); *guard = fresh_settings; Ok(()) } diff --git a/src/components/proxy/AutoFailoverConfigPanel.tsx b/src/components/proxy/AutoFailoverConfigPanel.tsx index 2436c6ee..4186e358 100644 --- a/src/components/proxy/AutoFailoverConfigPanel.tsx +++ b/src/components/proxy/AutoFailoverConfigPanel.tsx @@ -56,10 +56,12 @@ export function AutoFailoverConfigPanel({ const handleSave = async () => { if (!config) return; - // 解析数字,空值使用默认值,0 是有效值 - const parseNum = (val: string, defaultVal: number) => { - const n = parseInt(val); - return isNaN(n) ? defaultVal : n; + // 解析数字,返回 NaN 表示无效输入 + const parseNum = (val: string) => { + const trimmed = val.trim(); + // 必须是纯数字 + if (!/^-?\d+$/.test(trimmed)) return NaN; + return parseInt(trimmed); }; // 定义各字段的有效范围 @@ -70,38 +72,32 @@ export function AutoFailoverConfigPanel({ nonStreamingTimeout: { min: 0, max: 1800 }, circuitFailureThreshold: { min: 1, max: 20 }, circuitSuccessThreshold: { min: 1, max: 10 }, - circuitTimeoutSeconds: { min: 10, max: 300 }, + circuitTimeoutSeconds: { min: 0, max: 300 }, circuitErrorRateThreshold: { min: 0, max: 100 }, circuitMinRequests: { min: 5, max: 100 }, }; // 解析原始值 const raw = { - maxRetries: parseNum(formData.maxRetries, 3), - streamingFirstByteTimeout: parseNum( - formData.streamingFirstByteTimeout, - 30, - ), - streamingIdleTimeout: parseNum(formData.streamingIdleTimeout, 60), - nonStreamingTimeout: parseNum(formData.nonStreamingTimeout, 300), - circuitFailureThreshold: parseNum(formData.circuitFailureThreshold, 5), - circuitSuccessThreshold: parseNum(formData.circuitSuccessThreshold, 2), - circuitTimeoutSeconds: parseNum(formData.circuitTimeoutSeconds, 60), - circuitErrorRateThreshold: parseNum( - formData.circuitErrorRateThreshold, - 50, - ), - circuitMinRequests: parseNum(formData.circuitMinRequests, 10), + maxRetries: parseNum(formData.maxRetries), + streamingFirstByteTimeout: parseNum(formData.streamingFirstByteTimeout), + streamingIdleTimeout: parseNum(formData.streamingIdleTimeout), + nonStreamingTimeout: parseNum(formData.nonStreamingTimeout), + circuitFailureThreshold: parseNum(formData.circuitFailureThreshold), + circuitSuccessThreshold: parseNum(formData.circuitSuccessThreshold), + circuitTimeoutSeconds: parseNum(formData.circuitTimeoutSeconds), + circuitErrorRateThreshold: parseNum(formData.circuitErrorRateThreshold), + circuitMinRequests: parseNum(formData.circuitMinRequests), }; - // 校验是否超出范围 + // 校验是否超出范围(NaN 也视为无效) const errors: string[] = []; const checkRange = ( value: number, range: { min: number; max: number }, label: string, ) => { - if (value < range.min || value > range.max) { + if (isNaN(value) || value < range.min || value > range.max) { errors.push(`${label}: ${range.min}-${range.max}`); } }; @@ -424,7 +420,7 @@ export function AutoFailoverConfigPanel({ diff --git a/src/components/proxy/CircuitBreakerConfigPanel.tsx b/src/components/proxy/CircuitBreakerConfigPanel.tsx index 9a0a55cc..a5776799 100644 --- a/src/components/proxy/CircuitBreakerConfigPanel.tsx +++ b/src/components/proxy/CircuitBreakerConfigPanel.tsx @@ -41,38 +41,40 @@ export function CircuitBreakerConfigPanel() { }, [config]); const handleSave = async () => { - // 解析数字,空值使用默认值,0 是有效值 - const parseNum = (val: string, defaultVal: number) => { - const n = parseInt(val); - return isNaN(n) ? defaultVal : n; + // 解析数字,返回 NaN 表示无效输入 + const parseNum = (val: string) => { + const trimmed = val.trim(); + // 必须是纯数字 + if (!/^-?\d+$/.test(trimmed)) return NaN; + return parseInt(trimmed); }; // 定义各字段的有效范围 const ranges = { failureThreshold: { min: 1, max: 20 }, successThreshold: { min: 1, max: 10 }, - timeoutSeconds: { min: 10, max: 300 }, + timeoutSeconds: { min: 0, max: 300 }, errorRateThreshold: { min: 0, max: 100 }, minRequests: { min: 5, max: 100 }, }; // 解析原始值 const raw = { - failureThreshold: parseNum(formData.failureThreshold, 5), - successThreshold: parseNum(formData.successThreshold, 2), - timeoutSeconds: parseNum(formData.timeoutSeconds, 60), - errorRateThreshold: parseNum(formData.errorRateThreshold, 50), - minRequests: parseNum(formData.minRequests, 10), + failureThreshold: parseNum(formData.failureThreshold), + successThreshold: parseNum(formData.successThreshold), + timeoutSeconds: parseNum(formData.timeoutSeconds), + errorRateThreshold: parseNum(formData.errorRateThreshold), + minRequests: parseNum(formData.minRequests), }; - // 校验是否超出范围 + // 校验是否超出范围(NaN 也视为无效) const errors: string[] = []; const checkRange = ( value: number, range: { min: number; max: number }, label: string, ) => { - if (value < range.min || value > range.max) { + if (isNaN(value) || value < range.min || value > range.max) { errors.push(`${label}: ${range.min}-${range.max}`); } }; @@ -183,7 +185,7 @@ export function CircuitBreakerConfigPanel() { diff --git a/src/components/proxy/ProxyPanel.tsx b/src/components/proxy/ProxyPanel.tsx index 702bd3cb..5bee8590 100644 --- a/src/components/proxy/ProxyPanel.tsx +++ b/src/components/proxy/ProxyPanel.tsx @@ -106,11 +106,13 @@ export function ProxyPanel() { // 校验地址格式(简单的 IP 地址或 localhost 校验) const addressTrimmed = listenAddress.trim(); const ipv4Regex = /^(\d{1,3}\.){3}\d{1,3}$/; + // 规范化 localhost 为 127.0.0.1 + const normalizedAddress = + addressTrimmed === "localhost" ? "127.0.0.1" : addressTrimmed; const isValidAddress = - addressTrimmed === "localhost" || - addressTrimmed === "0.0.0.0" || - (ipv4Regex.test(addressTrimmed) && - addressTrimmed.split(".").every((n) => { + normalizedAddress === "0.0.0.0" || + (ipv4Regex.test(normalizedAddress) && + normalizedAddress.split(".").every((n) => { const num = parseInt(n); return num >= 0 && num <= 255; })); @@ -124,7 +126,17 @@ export function ProxyPanel() { return; } - const port = parseInt(listenPort); + // 严格校验端口:必须是纯数字 + const portTrimmed = listenPort.trim(); + if (!/^\d+$/.test(portTrimmed)) { + toast.error( + t("proxy.settings.invalidPort", { + defaultValue: "端口无效,请输入 1024-65535 之间的数字", + }), + ); + return; + } + const port = parseInt(portTrimmed); if (isNaN(port) || port < 1024 || port > 65535) { toast.error( t("proxy.settings.invalidPort", { @@ -136,9 +148,11 @@ export function ProxyPanel() { try { await updateGlobalConfig.mutateAsync({ ...globalConfig, - listenAddress: addressTrimmed, + listenAddress: normalizedAddress, listenPort: port, }); + // 同步更新本地状态为规范化后的值 + setListenAddress(normalizedAddress); toast.success( t("proxy.settings.configSaved", { defaultValue: "代理配置已保存" }), { closeButton: true }, @@ -164,6 +178,13 @@ export function ProxyPanel() { } }; + // 格式化地址用于 URL(IPv6 需要方括号) + const formatAddressForUrl = (address: string, port: number): string => { + const isIPv6 = address.includes(":"); + const host = isIPv6 ? `[${address}]` : address; + return `http://${host}:${port}`; + }; + return ( <>
@@ -178,14 +199,14 @@ export function ProxyPanel() {

- http://{status.address}:{status.port} + {formatAddressForUrl(status.address, status.port)}