diff --git a/src-tauri/src/commands/provider.rs b/src-tauri/src/commands/provider.rs index 03157d0d..64b2e5be 100644 --- a/src-tauri/src/commands/provider.rs +++ b/src-tauri/src/commands/provider.rs @@ -133,6 +133,7 @@ pub async fn testUsageScript( #[allow(non_snake_case)] baseUrl: Option, #[allow(non_snake_case)] accessToken: Option, #[allow(non_snake_case)] userId: Option, + #[allow(non_snake_case)] templateType: Option, ) -> Result { let app_type = AppType::from_str(&app).map_err(|e| e.to_string())?; ProviderService::test_usage_script( @@ -145,6 +146,7 @@ pub async fn testUsageScript( baseUrl.as_deref(), accessToken.as_deref(), userId.as_deref(), + templateType.as_deref(), ) .await .map_err(|e| e.to_string()) diff --git a/src-tauri/src/deeplink/provider.rs b/src-tauri/src/deeplink/provider.rs index e69c7c35..88f5709e 100644 --- a/src-tauri/src/deeplink/provider.rs +++ b/src-tauri/src/deeplink/provider.rs @@ -225,6 +225,7 @@ fn build_provider_meta(request: &DeepLinkImportRequest) -> Result, + /// 模板类型(用于后端判断验证规则) + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "templateType")] + pub template_type: Option, /// 自动查询间隔(单位:分钟,0 表示禁用自动查询) #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "autoQueryInterval")] diff --git a/src-tauri/src/services/provider/mod.rs b/src-tauri/src/services/provider/mod.rs index a7df4073..7b45b4e6 100644 --- a/src-tauri/src/services/provider/mod.rs +++ b/src-tauri/src/services/provider/mod.rs @@ -615,6 +615,7 @@ impl ProviderService { base_url: Option<&str>, access_token: Option<&str>, user_id: Option<&str>, + template_type: Option<&str>, ) -> Result { usage::test_usage_script( state, @@ -626,6 +627,7 @@ impl ProviderService { base_url, access_token, user_id, + template_type, ) .await } diff --git a/src-tauri/src/services/provider/usage.rs b/src-tauri/src/services/provider/usage.rs index 6ff94ca9..5d567dc7 100644 --- a/src-tauri/src/services/provider/usage.rs +++ b/src-tauri/src/services/provider/usage.rs @@ -17,6 +17,7 @@ pub(crate) async fn execute_and_format_usage_result( timeout: u64, access_token: Option<&str>, user_id: Option<&str>, + template_type: Option<&str>, ) -> Result { match usage_script::execute_usage_script( script_code, @@ -25,6 +26,7 @@ pub(crate) async fn execute_and_format_usage_result( timeout, access_token, user_id, + template_type, ) .await { @@ -113,7 +115,7 @@ pub async fn query_usage( app_type: AppType, provider_id: &str, ) -> Result { - let (script_code, timeout, api_key, base_url, access_token, user_id) = { + let (script_code, timeout, api_key, base_url, access_token, user_id, template_type) = { let providers = state.db.get_all_providers(app_type.as_str())?; let provider = providers.get(provider_id).ok_or_else(|| { AppError::localized( @@ -164,6 +166,7 @@ pub async fn query_usage( base_url, usage_script.access_token.clone(), usage_script.user_id.clone(), + usage_script.template_type.clone(), ) }; @@ -174,6 +177,7 @@ pub async fn query_usage( timeout, access_token.as_deref(), user_id.as_deref(), + template_type.as_deref(), ) .await } @@ -190,6 +194,7 @@ pub async fn test_usage_script( base_url: Option<&str>, access_token: Option<&str>, user_id: Option<&str>, + template_type: Option<&str>, ) -> Result { // Use provided credential parameters directly for testing execute_and_format_usage_result( @@ -199,6 +204,7 @@ pub async fn test_usage_script( timeout, access_token, user_id, + template_type, ) .await } diff --git a/src-tauri/src/usage_script.rs b/src-tauri/src/usage_script.rs index 3da7ac8e..f9afdb22 100644 --- a/src-tauri/src/usage_script.rs +++ b/src-tauri/src/usage_script.rs @@ -13,13 +13,21 @@ pub async fn execute_usage_script( timeout_secs: u64, access_token: Option<&str>, user_id: Option<&str>, + template_type: Option<&str>, ) -> Result { + // 检测是否为自定义模板模式 + // 优先使用前端传递的 template_type + let is_custom_template = template_type.map(|t| t == "custom").unwrap_or(false); + // 1. 替换模板变量,避免泄露敏感信息 let script_with_vars = build_script_with_vars(script_code, api_key, base_url, access_token, user_id); - // 2. 验证 base_url 的安全性 - validate_base_url(base_url)?; + // 2. 验证 base_url 的安全性(仅当提供了 base_url 时) + // 自定义模板模式下,用户可能不使用模板变量,而是直接在脚本中写完整 URL + if !base_url.is_empty() { + validate_base_url(base_url)?; + } // 3. 在独立作用域中提取 request 配置(确保 Runtime/Context 在 await 前释放) let request_config = { @@ -97,7 +105,8 @@ pub async fn execute_usage_script( })?; // 5. 验证请求 URL 是否安全(防止 SSRF) - validate_request_url(&request.url, base_url)?; + // 如果提供了 base_url,则验证同源;否则只做基本安全检查 + validate_request_url(&request.url, base_url, is_custom_template)?; // 6. 发送 HTTP 请求 let response_data = send_http_request(&request, timeout_secs).await?; @@ -472,7 +481,11 @@ fn validate_base_url(base_url: &str) -> Result<(), AppError> { } /// 验证请求 URL 是否安全(防止 SSRF) -fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppError> { +fn validate_request_url( + request_url: &str, + base_url: &str, + is_custom_template: bool, +) -> Result<(), AppError> { // 解析请求 URL let parsed_request = Url::parse(request_url).map_err(|e| { AppError::localized( @@ -482,19 +495,11 @@ fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppErro ) })?; - // 解析 base URL - let parsed_base = Url::parse(base_url).map_err(|e| { - AppError::localized( - "usage_script.base_url_invalid", - format!("无效的 base_url: {e}"), - format!("Invalid base_url: {e}"), - ) - })?; - let is_request_loopback = is_loopback_host(&parsed_request); // 必须使用 HTTPS(允许 localhost 用于开发) - if parsed_request.scheme() != "https" && !is_request_loopback { + // 自定义模板模式下,允许用户自行决定是否使用 HTTP(用户需自行承担安全风险) + if !is_custom_template && parsed_request.scheme() != "https" && !is_request_loopback { return Err(AppError::localized( "usage_script.request_https_required", "请求 URL 必须使用 HTTPS 协议(localhost 除外)", @@ -502,60 +507,85 @@ fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppErro )); } - // 核心安全检查:必须与 base_url 同源(相同域名和端口) - if parsed_request.host_str() != parsed_base.host_str() { - return Err(AppError::localized( - "usage_script.request_host_mismatch", - format!( - "请求域名 {} 与 base_url 域名 {} 不匹配(必须是同源请求)", - parsed_request.host_str().unwrap_or("unknown"), - parsed_base.host_str().unwrap_or("unknown") - ), - format!( - "Request host {} must match base_url host {} (same-origin required)", - parsed_request.host_str().unwrap_or("unknown"), - parsed_base.host_str().unwrap_or("unknown") - ), - )); - } + // 如果提供了 base_url(非空),则进行同源检查 + // 🔧 自定义模板模式下,用户可以自由访问任意 HTTPS 域名,跳过同源检查 + if !base_url.is_empty() && !is_custom_template { + // 解析 base URL + let parsed_base = Url::parse(base_url).map_err(|e| { + AppError::localized( + "usage_script.base_url_invalid", + format!("无效的 base_url: {e}"), + format!("Invalid base_url: {e}"), + ) + })?; - // 检查端口是否匹配(考虑默认端口) - // 使用 port_or_known_default() 会自动处理默认端口(http->80, https->443) - match ( - parsed_request.port_or_known_default(), - parsed_base.port_or_known_default(), - ) { - (Some(request_port), Some(base_port)) if request_port == base_port => { - // 端口匹配,继续执行 - } - (Some(request_port), Some(base_port)) => { + // 核心安全检查:必须与 base_url 同源(相同域名和端口) + if parsed_request.host_str() != parsed_base.host_str() { return Err(AppError::localized( - "usage_script.request_port_mismatch", - format!("请求端口 {request_port} 必须与 base_url 端口 {base_port} 匹配"), - format!("Request port {request_port} must match base_url port {base_port}"), + "usage_script.request_host_mismatch", + format!( + "请求域名 {} 与 base_url 域名 {} 不匹配(必须是同源请求)", + parsed_request.host_str().unwrap_or("unknown"), + parsed_base.host_str().unwrap_or("unknown") + ), + format!( + "Request host {} must match base_url host {} (same-origin required)", + parsed_request.host_str().unwrap_or("unknown"), + parsed_base.host_str().unwrap_or("unknown") + ), )); } - _ => { - // 理论上不会发生,因为 port_or_known_default() 应该总是返回 Some - return Err(AppError::localized( - "usage_script.request_port_unknown", - "无法确定端口号", - "Unable to determine port number", - )); + + // 检查端口是否匹配(考虑默认端口) + // 使用 port_or_known_default() 会自动处理默认端口(http->80, https->443) + match ( + parsed_request.port_or_known_default(), + parsed_base.port_or_known_default(), + ) { + (Some(request_port), Some(base_port)) if request_port == base_port => { + // 端口匹配,继续执行 + } + (Some(request_port), Some(base_port)) => { + return Err(AppError::localized( + "usage_script.request_port_mismatch", + format!("请求端口 {request_port} 必须与 base_url 端口 {base_port} 匹配"), + format!("Request port {request_port} must match base_url port {base_port}"), + )); + } + _ => { + // 理论上不会发生,因为 port_or_known_default() 应该总是返回 Some + return Err(AppError::localized( + "usage_script.request_port_unknown", + "无法确定端口号", + "Unable to determine port number", + )); + } } - } - // 禁止私有 IP 地址访问(除非 base_url 本身就是私有地址,用于开发环境) - if let Some(host) = parsed_request.host_str() { - let base_host = parsed_base.host_str().unwrap_or(""); + // 禁止私有 IP 地址访问(除非 base_url 本身就是私有地址,用于开发环境) + if let Some(host) = parsed_request.host_str() { + let base_host = parsed_base.host_str().unwrap_or(""); - // 如果 base_url 不是私有地址,则禁止访问私有IP - if !is_private_ip(base_host) && is_private_ip(host) { - return Err(AppError::localized( - "usage_script.private_ip_blocked", - "禁止访问私有 IP 地址", - "Access to private IP addresses is blocked", - )); + // 如果 base_url 不是私有地址,则禁止访问私有IP + if !is_private_ip(base_host) && is_private_ip(host) { + return Err(AppError::localized( + "usage_script.private_ip_blocked", + "禁止访问私有 IP 地址", + "Access to private IP addresses is blocked", + )); + } + } + } else { + // 自定义模板模式:没有 base_url,需要额外的安全检查 + // 禁止访问私有 IP 地址(SSRF 防护) + if let Some(host) = parsed_request.host_str() { + if is_private_ip(host) && !is_request_loopback { + return Err(AppError::localized( + "usage_script.private_ip_blocked", + "禁止访问私有 IP 地址(localhost 除外)", + "Access to private IP addresses is blocked (localhost allowed)", + )); + } } } @@ -843,7 +873,7 @@ mod tests { ]; for (base_url, request_url, should_match) in test_cases { - let result = validate_request_url(request_url, base_url); + let result = validate_request_url(request_url, base_url, false); if should_match { assert!( diff --git a/src/components/UsageScriptModal.tsx b/src/components/UsageScriptModal.tsx index c23dbb55..6bbbb90d 100644 --- a/src/components/UsageScriptModal.tsx +++ b/src/components/UsageScriptModal.tsx @@ -2,8 +2,10 @@ import React, { useState } from "react"; import { Play, Wand2, Eye, EyeOff, Save } from "lucide-react"; import { toast } from "sonner"; import { useTranslation } from "react-i18next"; +import { useQueryClient } from "@tanstack/react-query"; import { Provider, UsageScript, UsageData } from "@/types"; import { usageApi, type AppId } from "@/lib/api"; +import { extractCodexBaseUrl } from "@/utils/providerConfigUtils"; import JsonEditor from "./JsonEditor"; import * as prettier from "prettier/standalone"; import * as parserBabel from "prettier/parser-babel"; @@ -109,19 +111,67 @@ const UsageScriptModal: React.FC = ({ onSave, }) => { const { t } = useTranslation(); + const queryClient = useQueryClient(); // 生成带国际化的预设模板 const PRESET_TEMPLATES = generatePresetTemplates(t); - const [script, setScript] = useState(() => { - return ( - provider.meta?.usage_script || { - enabled: false, - language: "javascript", - code: PRESET_TEMPLATES[TEMPLATE_KEYS.GENERAL], - timeout: 10, + // 从 provider 的 settingsConfig 中提取 API Key 和 Base URL + const getProviderCredentials = (): { + apiKey: string | undefined; + baseUrl: string | undefined; + } => { + try { + const config = provider.settingsConfig; + if (!config) return { apiKey: undefined, baseUrl: undefined }; + + // 处理不同应用的配置格式 + if (appId === "claude") { + // Claude: { env: { ANTHROPIC_AUTH_TOKEN | ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL } } + const env = (config as any).env || {}; + return { + apiKey: env.ANTHROPIC_AUTH_TOKEN || env.ANTHROPIC_API_KEY, + baseUrl: env.ANTHROPIC_BASE_URL, + }; + } else if (appId === "codex") { + // Codex: { auth: { OPENAI_API_KEY }, config: TOML string with base_url } + const auth = (config as any).auth || {}; + const configToml = (config as any).config || ""; + return { + apiKey: auth.OPENAI_API_KEY, + baseUrl: extractCodexBaseUrl(configToml), + }; + } else if (appId === "gemini") { + // Gemini: { env: { GEMINI_API_KEY, GOOGLE_GEMINI_BASE_URL } } + const env = (config as any).env || {}; + return { + apiKey: env.GEMINI_API_KEY, + baseUrl: env.GOOGLE_GEMINI_BASE_URL, + }; } - ); + return { apiKey: undefined, baseUrl: undefined }; + } catch (error) { + console.error("Failed to extract provider credentials:", error); + return { apiKey: undefined, baseUrl: undefined }; + } + }; + + const providerCredentials = getProviderCredentials(); + + const [script, setScript] = useState(() => { + const savedScript = provider.meta?.usage_script; + const defaultScript = { + enabled: false, + language: "javascript" as const, + code: PRESET_TEMPLATES[TEMPLATE_KEYS.GENERAL], + timeout: 10, + }; + + if (!savedScript) { + return defaultScript; + } + + return savedScript; }); const [testing, setTesting] = useState(false); @@ -176,6 +226,11 @@ const UsageScriptModal: React.FC = ({ const [selectedTemplate, setSelectedTemplate] = useState( () => { const existingScript = provider.meta?.usage_script; + // 优先使用保存的 templateType + if (existingScript?.templateType) { + return existingScript.templateType; + } + // 向后兼容:根据字段推断模板类型 // 检测 NEW_API 模板(有 accessToken 或 userId) if (existingScript?.accessToken || existingScript?.userId) { return TEMPLATE_KEYS.NEW_API; @@ -201,7 +256,16 @@ const UsageScriptModal: React.FC = ({ toast.error(t("usageScript.mustHaveReturn"), { duration: 5000 }); return; } - onSave(script); + // 保存时记录当前选择的模板类型 + const scriptWithTemplate = { + ...script, + templateType: selectedTemplate as + | "custom" + | "general" + | "newapi" + | undefined, + }; + onSave(scriptWithTemplate); onClose(); }; @@ -217,6 +281,7 @@ const UsageScriptModal: React.FC = ({ script.baseUrl, script.accessToken, script.userId, + selectedTemplate as "custom" | "general" | "newapi" | undefined, ); if (result.success && result.data && result.data.length > 0) { const summary = result.data @@ -229,6 +294,9 @@ const UsageScriptModal: React.FC = ({ duration: 3000, closeButton: true, }); + + // 🔧 测试成功后,更新主界面列表的用量查询缓存 + queryClient.setQueryData(["usage", provider.id, appId], result); } else { toast.error( `${t("usageScript.testFailed")}: ${result.error || t("endpointTest.noResult")}`, @@ -278,9 +346,13 @@ const UsageScriptModal: React.FC = ({ const preset = PRESET_TEMPLATES[presetName]; if (preset) { if (presetName === TEMPLATE_KEYS.CUSTOM) { + // 🔧 自定义模式:用户应该在脚本中直接写完整 URL 和凭证,而不是依赖变量替换 + // 这样可以避免同源检查导致的问题 + // 如果用户想使用变量,需要手动在配置中设置 baseUrl/apiKey setScript({ ...script, code: preset, + // 清除凭证,用户可选择手动输入或保持空 apiKey: undefined, baseUrl: undefined, accessToken: undefined, @@ -401,6 +473,74 @@ const UsageScriptModal: React.FC = ({ })} + {/* 自定义模式:变量提示和具体值 */} + {selectedTemplate === TEMPLATE_KEYS.CUSTOM && ( +
+

+ {t("usageScript.supportedVariables")} +

+
+ {/* baseUrl */} +
+ + {"{{baseUrl}}"} + + = + {providerCredentials.baseUrl ? ( + + {providerCredentials.baseUrl} + + ) : ( + + {t("common.notSet") || "未设置"} + + )} +
+ + {/* apiKey */} +
+ + {"{{apiKey}}"} + + = + {providerCredentials.apiKey ? ( + <> + {showApiKey ? ( + + {providerCredentials.apiKey} + + ) : ( + + •••••••• + + )} + + + ) : ( + + {t("common.notSet") || "未设置"} + + )} +
+
+
+ )} + {/* 凭证配置 */} {shouldShowCredentialsConfig && (
@@ -601,11 +741,13 @@ const UsageScriptModal: React.FC = ({ type="number" min={0} max={1440} - value={script.autoIntervalMinutes ?? 0} + value={ + script.autoQueryInterval ?? script.autoIntervalMinutes ?? 0 + } onChange={(e) => setScript({ ...script, - autoIntervalMinutes: validateAndClampInterval( + autoQueryInterval: validateAndClampInterval( e.target.value, ), }) @@ -613,7 +755,7 @@ const UsageScriptModal: React.FC = ({ onBlur={(e) => setScript({ ...script, - autoIntervalMinutes: validateAndClampInterval( + autoQueryInterval: validateAndClampInterval( e.target.value, ), }) diff --git a/src/hooks/useProviderActions.ts b/src/hooks/useProviderActions.ts index 18739ca5..a8322b4c 100644 --- a/src/hooks/useProviderActions.ts +++ b/src/hooks/useProviderActions.ts @@ -115,6 +115,11 @@ export function useProviderActions(activeApp: AppId) { await queryClient.invalidateQueries({ queryKey: ["providers", activeApp], }); + // 🔧 保存用量脚本后,也应该失效该 provider 的用量查询缓存 + // 这样主页列表会使用新配置重新查询,而不是使用测试时的缓存 + await queryClient.invalidateQueries({ + queryKey: ["usage", provider.id, activeApp], + }); toast.success( t("provider.usageSaved", { defaultValue: "用量查询配置已保存", diff --git a/src/i18n/locales/en.json b/src/i18n/locales/en.json index 017d34fa..d8cc43e0 100644 --- a/src/i18n/locales/en.json +++ b/src/i18n/locales/en.json @@ -584,6 +584,7 @@ "testFailed": "Test failed", "formatSuccess": "Format successful", "formatFailed": "Format failed", + "supportedVariables": "Supported Variables", "variablesHint": "Supported variables: {{apiKey}}, {{baseUrl}} | extractor function receives API response JSON object", "scriptConfig": "Request configuration", "extractorCode": "Extractor code", diff --git a/src/i18n/locales/ja.json b/src/i18n/locales/ja.json index ba3549f1..9b726f5a 100644 --- a/src/i18n/locales/ja.json +++ b/src/i18n/locales/ja.json @@ -584,6 +584,7 @@ "testFailed": "テストに失敗しました", "formatSuccess": "整形に成功しました", "formatFailed": "整形に失敗しました", + "supportedVariables": "使用可能な変数", "variablesHint": "使用可能な変数: {{apiKey}}, {{baseUrl}} | extractor 関数には API 応答の JSON オブジェクトが渡されます", "scriptConfig": "リクエスト設定", "extractorCode": "抽出コード", diff --git a/src/i18n/locales/zh.json b/src/i18n/locales/zh.json index ab68a021..21d4d580 100644 --- a/src/i18n/locales/zh.json +++ b/src/i18n/locales/zh.json @@ -584,6 +584,7 @@ "testFailed": "测试失败", "formatSuccess": "格式化成功", "formatFailed": "格式化失败", + "supportedVariables": "支持的变量", "variablesHint": "支持变量: {{apiKey}}, {{baseUrl}} | extractor 函数接收 API 响应的 JSON 对象", "scriptConfig": "请求配置", "extractorCode": "提取器代码", diff --git a/src/lib/api/usage.ts b/src/lib/api/usage.ts index 4224973f..5b94269a 100644 --- a/src/lib/api/usage.ts +++ b/src/lib/api/usage.ts @@ -28,6 +28,7 @@ export const usageApi = { baseUrl?: string, accessToken?: string, userId?: string, + templateType?: "custom" | "general" | "newapi", ): Promise => { return invoke("testUsageScript", { providerId, @@ -38,6 +39,7 @@ export const usageApi = { baseUrl, accessToken, userId, + templateType, }); }, diff --git a/src/types.ts b/src/types.ts index c234963c..a9234d0d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -52,6 +52,7 @@ export interface UsageScript { language: "javascript"; // 脚本语言 code: string; // 脚本代码(JSON 格式配置) timeout?: number; // 超时时间(秒,默认 10) + templateType?: "custom" | "general" | "newapi"; // 模板类型(用于后端判断验证规则) apiKey?: string; // 用量查询专用的 API Key(通用模板使用) baseUrl?: string; // 用量查询专用的 Base URL(通用和 NewAPI 模板使用) accessToken?: string; // 访问令牌(NewAPI 模板使用)