mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-24 10:12:46 +08:00
fix(proxy): harden error handling and input validation
- Handle RwLock poisoning in settings.rs with unwrap_or_else - Add fallback for dirs::home_dir() in config modules - Normalize localhost to 127.0.0.1 in ProxyPanel - Format IPv6 addresses with brackets for valid URLs - Strict port validation with pure digit regex - Treat NaN as validation failure in config panels - Log warning on cost_multiplier parse failure - Align timeoutSeconds range to [0, 300] across all panels
This commit is contained in:
@@ -9,13 +9,21 @@ use serde_json::Value;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
/// 获取用户主目录,带回退和日志
|
||||
fn get_home_dir() -> PathBuf {
|
||||
dirs::home_dir().unwrap_or_else(|| {
|
||||
log::warn!("无法获取用户主目录,回退到当前目录");
|
||||
PathBuf::from(".")
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取 Codex 配置目录路径
|
||||
pub fn get_codex_config_dir() -> PathBuf {
|
||||
if let Some(custom) = crate::settings::get_codex_override_dir() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
dirs::home_dir().expect("无法获取用户主目录").join(".codex")
|
||||
get_home_dir().join(".codex")
|
||||
}
|
||||
|
||||
/// 获取 Codex auth.json 路径
|
||||
|
||||
@@ -5,22 +5,26 @@ use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::error::AppError;
|
||||
|
||||
/// 获取用户主目录,带回退和日志
|
||||
fn get_home_dir() -> PathBuf {
|
||||
dirs::home_dir().unwrap_or_else(|| {
|
||||
log::warn!("无法获取用户主目录,回退到当前目录");
|
||||
PathBuf::from(".")
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取 Claude Code 配置目录路径
|
||||
pub fn get_claude_config_dir() -> PathBuf {
|
||||
if let Some(custom) = crate::settings::get_claude_override_dir() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
dirs::home_dir()
|
||||
.expect("无法获取用户主目录")
|
||||
.join(".claude")
|
||||
get_home_dir().join(".claude")
|
||||
}
|
||||
|
||||
/// 默认 Claude MCP 配置文件路径 (~/.claude.json)
|
||||
pub fn get_default_claude_mcp_path() -> PathBuf {
|
||||
dirs::home_dir()
|
||||
.expect("无法获取用户主目录")
|
||||
.join(".claude.json")
|
||||
get_home_dir().join(".claude.json")
|
||||
}
|
||||
|
||||
fn derive_mcp_path_from_override(dir: &Path) -> Option<PathBuf> {
|
||||
|
||||
@@ -5,15 +5,21 @@ use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// 获取用户主目录,带回退和日志
|
||||
fn get_home_dir() -> PathBuf {
|
||||
dirs::home_dir().unwrap_or_else(|| {
|
||||
log::warn!("无法获取用户主目录,回退到当前目录");
|
||||
PathBuf::from(".")
|
||||
})
|
||||
}
|
||||
|
||||
/// 获取 Gemini 配置目录路径(支持设置覆盖)
|
||||
pub fn get_gemini_dir() -> PathBuf {
|
||||
if let Some(custom) = crate::settings::get_gemini_override_dir() {
|
||||
return custom;
|
||||
}
|
||||
|
||||
dirs::home_dir()
|
||||
.expect("无法获取用户主目录")
|
||||
.join(".gemini")
|
||||
get_home_dir().join(".gemini")
|
||||
}
|
||||
|
||||
/// 获取 Gemini .env 文件路径
|
||||
|
||||
@@ -456,7 +456,12 @@ async fn log_usage(
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or(Decimal::from(1))
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
|
||||
@@ -372,7 +372,12 @@ async fn log_usage_internal(
|
||||
Ok(Some(p)) => {
|
||||
if let Some(meta) = p.meta {
|
||||
if let Some(cm) = meta.cost_multiplier {
|
||||
Decimal::from_str(&cm).unwrap_or(Decimal::from(1))
|
||||
Decimal::from_str(&cm).unwrap_or_else(|e| {
|
||||
log::warn!(
|
||||
"cost_multiplier 解析失败 (provider_id={provider_id}): {cm} - {e}"
|
||||
);
|
||||
Decimal::from(1)
|
||||
})
|
||||
} else {
|
||||
Decimal::from(1)
|
||||
}
|
||||
|
||||
@@ -194,14 +194,23 @@ fn resolve_override_path(raw: &str) -> PathBuf {
|
||||
}
|
||||
|
||||
pub fn get_settings() -> AppSettings {
|
||||
settings_store().read().expect("读取设置锁失败").clone()
|
||||
settings_store()
|
||||
.read()
|
||||
.unwrap_or_else(|e| {
|
||||
log::warn!("设置锁已毒化,使用恢复值: {e}");
|
||||
e.into_inner()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn update_settings(mut new_settings: AppSettings) -> Result<(), AppError> {
|
||||
new_settings.normalize_paths();
|
||||
save_settings_file(&new_settings)?;
|
||||
|
||||
let mut guard = settings_store().write().expect("写入设置锁失败");
|
||||
let mut guard = settings_store().write().unwrap_or_else(|e| {
|
||||
log::warn!("设置锁已毒化,使用恢复值: {e}");
|
||||
e.into_inner()
|
||||
});
|
||||
*guard = new_settings;
|
||||
Ok(())
|
||||
}
|
||||
@@ -210,7 +219,10 @@ pub fn update_settings(mut new_settings: AppSettings) -> Result<(), AppError> {
|
||||
/// 用于导入配置等场景,确保内存缓存与文件同步
|
||||
pub fn reload_settings() -> Result<(), AppError> {
|
||||
let fresh_settings = AppSettings::load_from_file();
|
||||
let mut guard = settings_store().write().expect("写入设置锁失败");
|
||||
let mut guard = settings_store().write().unwrap_or_else(|e| {
|
||||
log::warn!("设置锁已毒化,使用恢复值: {e}");
|
||||
e.into_inner()
|
||||
});
|
||||
*guard = fresh_settings;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -56,10 +56,12 @@ export function AutoFailoverConfigPanel({
|
||||
|
||||
const handleSave = async () => {
|
||||
if (!config) return;
|
||||
// 解析数字,空值使用默认值,0 是有效值
|
||||
const parseNum = (val: string, defaultVal: number) => {
|
||||
const n = parseInt(val);
|
||||
return isNaN(n) ? defaultVal : n;
|
||||
// 解析数字,返回 NaN 表示无效输入
|
||||
const parseNum = (val: string) => {
|
||||
const trimmed = val.trim();
|
||||
// 必须是纯数字
|
||||
if (!/^-?\d+$/.test(trimmed)) return NaN;
|
||||
return parseInt(trimmed);
|
||||
};
|
||||
|
||||
// 定义各字段的有效范围
|
||||
@@ -70,38 +72,32 @@ export function AutoFailoverConfigPanel({
|
||||
nonStreamingTimeout: { min: 0, max: 1800 },
|
||||
circuitFailureThreshold: { min: 1, max: 20 },
|
||||
circuitSuccessThreshold: { min: 1, max: 10 },
|
||||
circuitTimeoutSeconds: { min: 10, max: 300 },
|
||||
circuitTimeoutSeconds: { min: 0, max: 300 },
|
||||
circuitErrorRateThreshold: { min: 0, max: 100 },
|
||||
circuitMinRequests: { min: 5, max: 100 },
|
||||
};
|
||||
|
||||
// 解析原始值
|
||||
const raw = {
|
||||
maxRetries: parseNum(formData.maxRetries, 3),
|
||||
streamingFirstByteTimeout: parseNum(
|
||||
formData.streamingFirstByteTimeout,
|
||||
30,
|
||||
),
|
||||
streamingIdleTimeout: parseNum(formData.streamingIdleTimeout, 60),
|
||||
nonStreamingTimeout: parseNum(formData.nonStreamingTimeout, 300),
|
||||
circuitFailureThreshold: parseNum(formData.circuitFailureThreshold, 5),
|
||||
circuitSuccessThreshold: parseNum(formData.circuitSuccessThreshold, 2),
|
||||
circuitTimeoutSeconds: parseNum(formData.circuitTimeoutSeconds, 60),
|
||||
circuitErrorRateThreshold: parseNum(
|
||||
formData.circuitErrorRateThreshold,
|
||||
50,
|
||||
),
|
||||
circuitMinRequests: parseNum(formData.circuitMinRequests, 10),
|
||||
maxRetries: parseNum(formData.maxRetries),
|
||||
streamingFirstByteTimeout: parseNum(formData.streamingFirstByteTimeout),
|
||||
streamingIdleTimeout: parseNum(formData.streamingIdleTimeout),
|
||||
nonStreamingTimeout: parseNum(formData.nonStreamingTimeout),
|
||||
circuitFailureThreshold: parseNum(formData.circuitFailureThreshold),
|
||||
circuitSuccessThreshold: parseNum(formData.circuitSuccessThreshold),
|
||||
circuitTimeoutSeconds: parseNum(formData.circuitTimeoutSeconds),
|
||||
circuitErrorRateThreshold: parseNum(formData.circuitErrorRateThreshold),
|
||||
circuitMinRequests: parseNum(formData.circuitMinRequests),
|
||||
};
|
||||
|
||||
// 校验是否超出范围
|
||||
// 校验是否超出范围(NaN 也视为无效)
|
||||
const errors: string[] = [];
|
||||
const checkRange = (
|
||||
value: number,
|
||||
range: { min: number; max: number },
|
||||
label: string,
|
||||
) => {
|
||||
if (value < range.min || value > range.max) {
|
||||
if (isNaN(value) || value < range.min || value > range.max) {
|
||||
errors.push(`${label}: ${range.min}-${range.max}`);
|
||||
}
|
||||
};
|
||||
@@ -424,7 +420,7 @@ export function AutoFailoverConfigPanel({
|
||||
<Input
|
||||
id={`timeoutSeconds-${appType}`}
|
||||
type="number"
|
||||
min="10"
|
||||
min="0"
|
||||
max="300"
|
||||
value={formData.circuitTimeoutSeconds}
|
||||
onChange={(e) =>
|
||||
|
||||
@@ -41,38 +41,40 @@ export function CircuitBreakerConfigPanel() {
|
||||
}, [config]);
|
||||
|
||||
const handleSave = async () => {
|
||||
// 解析数字,空值使用默认值,0 是有效值
|
||||
const parseNum = (val: string, defaultVal: number) => {
|
||||
const n = parseInt(val);
|
||||
return isNaN(n) ? defaultVal : n;
|
||||
// 解析数字,返回 NaN 表示无效输入
|
||||
const parseNum = (val: string) => {
|
||||
const trimmed = val.trim();
|
||||
// 必须是纯数字
|
||||
if (!/^-?\d+$/.test(trimmed)) return NaN;
|
||||
return parseInt(trimmed);
|
||||
};
|
||||
|
||||
// 定义各字段的有效范围
|
||||
const ranges = {
|
||||
failureThreshold: { min: 1, max: 20 },
|
||||
successThreshold: { min: 1, max: 10 },
|
||||
timeoutSeconds: { min: 10, max: 300 },
|
||||
timeoutSeconds: { min: 0, max: 300 },
|
||||
errorRateThreshold: { min: 0, max: 100 },
|
||||
minRequests: { min: 5, max: 100 },
|
||||
};
|
||||
|
||||
// 解析原始值
|
||||
const raw = {
|
||||
failureThreshold: parseNum(formData.failureThreshold, 5),
|
||||
successThreshold: parseNum(formData.successThreshold, 2),
|
||||
timeoutSeconds: parseNum(formData.timeoutSeconds, 60),
|
||||
errorRateThreshold: parseNum(formData.errorRateThreshold, 50),
|
||||
minRequests: parseNum(formData.minRequests, 10),
|
||||
failureThreshold: parseNum(formData.failureThreshold),
|
||||
successThreshold: parseNum(formData.successThreshold),
|
||||
timeoutSeconds: parseNum(formData.timeoutSeconds),
|
||||
errorRateThreshold: parseNum(formData.errorRateThreshold),
|
||||
minRequests: parseNum(formData.minRequests),
|
||||
};
|
||||
|
||||
// 校验是否超出范围
|
||||
// 校验是否超出范围(NaN 也视为无效)
|
||||
const errors: string[] = [];
|
||||
const checkRange = (
|
||||
value: number,
|
||||
range: { min: number; max: number },
|
||||
label: string,
|
||||
) => {
|
||||
if (value < range.min || value > range.max) {
|
||||
if (isNaN(value) || value < range.min || value > range.max) {
|
||||
errors.push(`${label}: ${range.min}-${range.max}`);
|
||||
}
|
||||
};
|
||||
@@ -183,7 +185,7 @@ export function CircuitBreakerConfigPanel() {
|
||||
<Input
|
||||
id="timeoutSeconds"
|
||||
type="number"
|
||||
min="10"
|
||||
min="0"
|
||||
max="300"
|
||||
value={formData.timeoutSeconds}
|
||||
onChange={(e) =>
|
||||
|
||||
@@ -106,11 +106,13 @@ export function ProxyPanel() {
|
||||
// 校验地址格式(简单的 IP 地址或 localhost 校验)
|
||||
const addressTrimmed = listenAddress.trim();
|
||||
const ipv4Regex = /^(\d{1,3}\.){3}\d{1,3}$/;
|
||||
// 规范化 localhost 为 127.0.0.1
|
||||
const normalizedAddress =
|
||||
addressTrimmed === "localhost" ? "127.0.0.1" : addressTrimmed;
|
||||
const isValidAddress =
|
||||
addressTrimmed === "localhost" ||
|
||||
addressTrimmed === "0.0.0.0" ||
|
||||
(ipv4Regex.test(addressTrimmed) &&
|
||||
addressTrimmed.split(".").every((n) => {
|
||||
normalizedAddress === "0.0.0.0" ||
|
||||
(ipv4Regex.test(normalizedAddress) &&
|
||||
normalizedAddress.split(".").every((n) => {
|
||||
const num = parseInt(n);
|
||||
return num >= 0 && num <= 255;
|
||||
}));
|
||||
@@ -124,7 +126,17 @@ export function ProxyPanel() {
|
||||
return;
|
||||
}
|
||||
|
||||
const port = parseInt(listenPort);
|
||||
// 严格校验端口:必须是纯数字
|
||||
const portTrimmed = listenPort.trim();
|
||||
if (!/^\d+$/.test(portTrimmed)) {
|
||||
toast.error(
|
||||
t("proxy.settings.invalidPort", {
|
||||
defaultValue: "端口无效,请输入 1024-65535 之间的数字",
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
const port = parseInt(portTrimmed);
|
||||
if (isNaN(port) || port < 1024 || port > 65535) {
|
||||
toast.error(
|
||||
t("proxy.settings.invalidPort", {
|
||||
@@ -136,9 +148,11 @@ export function ProxyPanel() {
|
||||
try {
|
||||
await updateGlobalConfig.mutateAsync({
|
||||
...globalConfig,
|
||||
listenAddress: addressTrimmed,
|
||||
listenAddress: normalizedAddress,
|
||||
listenPort: port,
|
||||
});
|
||||
// 同步更新本地状态为规范化后的值
|
||||
setListenAddress(normalizedAddress);
|
||||
toast.success(
|
||||
t("proxy.settings.configSaved", { defaultValue: "代理配置已保存" }),
|
||||
{ closeButton: true },
|
||||
@@ -164,6 +178,13 @@ export function ProxyPanel() {
|
||||
}
|
||||
};
|
||||
|
||||
// 格式化地址用于 URL(IPv6 需要方括号)
|
||||
const formatAddressForUrl = (address: string, port: number): string => {
|
||||
const isIPv6 = address.includes(":");
|
||||
const host = isIPv6 ? `[${address}]` : address;
|
||||
return `http://${host}:${port}`;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<section className="space-y-6">
|
||||
@@ -178,14 +199,14 @@ export function ProxyPanel() {
|
||||
</p>
|
||||
<div className="flex flex-col gap-2 sm:flex-row sm:items-center">
|
||||
<code className="flex-1 text-sm bg-background px-3 py-2 rounded border border-border/60">
|
||||
http://{status.address}:{status.port}
|
||||
{formatAddressForUrl(status.address, status.port)}
|
||||
</code>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
navigator.clipboard.writeText(
|
||||
`http://${status.address}:${status.port}`,
|
||||
formatAddressForUrl(status.address, status.port),
|
||||
);
|
||||
toast.success(
|
||||
t("proxy.panel.addressCopied", {
|
||||
|
||||
Reference in New Issue
Block a user