feat(usage): improve custom template system with variable hints and validation fixes (#628)

* feat(usage): improve custom template with variables display and explicit type detection

Combine two feature improvements:
1. Display supported variables ({{baseUrl}}, {{apiKey}}) with actual values in custom template mode
2. Add explicit templateType field for accurate template mode detection

## Changes

### Frontend
- Display template variables with actual values extracted from provider settings
- Add templateType field to UsageScript for explicit mode detection
- Support template mode persistence across sessions

### Backend
- Add template_type field to UsageScript struct
- Improve validation logic based on explicit template type
- Maintain backward compatibility with type inference

### I18n
- Add "Supported Variables" section translation (zh/en/ja)

### Benefits
- More accurate template mode detection (no more guessing)
- Better user experience with variable hints
- Clearer validation rules per template type

* fix(usage): resolve custom template cache and validation issues

Combine three bug fixes to make custom template mode work correctly:

1. **Update cache after test**: Testing usage script successfully now updates the main list cache immediately
2. **Fix same-origin check**: Custom template mode can now access different domains (SSRF protection still active)
3. **Fix field naming**: Unified to use autoQueryInterval consistently between frontend and backend

## Problems Solved

- Main provider list showing "Query failed" after successful test
- Custom templates blocked by overly strict same-origin validation
- Auto-query intervals not saved correctly due to inconsistent naming

## Changes

### Frontend (UsageScriptModal)
- Import useQueryClient and update cache after successful test
- Invalidate usage cache when saving script configuration
- Use standardized autoQueryInterval field name

### Backend (usage_script.rs)
- Allow custom template mode to bypass same-origin checks
- Maintain SSRF protection for all modes

### Hooks (useProviderActions)
- Invalidate usage query cache when saving script

## Impact

Users can now use custom templates freely while security validations remain intact for general templates.

* fix(usage): correct provider credential field names

- Claude: support both ANTHROPIC_API_KEY and ANTHROPIC_AUTH_TOKEN
- Gemini: use GEMINI_API_KEY instead of GOOGLE_GEMINI_API_KEY
- Codex: use OPENAI_API_KEY and parse base_url from TOML config string

Addresses review feedback from PR #628

* style: format code

---------

Co-authored-by: Jason <farion1231@gmail.com>
This commit is contained in:
杨永安
2026-01-14 15:42:05 +08:00
committed by GitHub
parent f3343992f2
commit 07d022ba9f
13 changed files with 273 additions and 75 deletions

View File

@@ -133,6 +133,7 @@ pub async fn testUsageScript(
#[allow(non_snake_case)] baseUrl: Option<String>,
#[allow(non_snake_case)] accessToken: Option<String>,
#[allow(non_snake_case)] userId: Option<String>,
#[allow(non_snake_case)] templateType: Option<String>,
) -> Result<crate::provider::UsageResult, String> {
let app_type = AppType::from_str(&app).map_err(|e| e.to_string())?;
ProviderService::test_usage_script(
@@ -145,6 +146,7 @@ pub async fn testUsageScript(
baseUrl.as_deref(),
accessToken.as_deref(),
userId.as_deref(),
templateType.as_deref(),
)
.await
.map_err(|e| e.to_string())

View File

@@ -225,6 +225,7 @@ fn build_provider_meta(request: &DeepLinkImportRequest) -> Result<Option<Provide
}),
access_token: request.usage_access_token.clone(),
user_id: request.usage_user_id.clone(),
template_type: None, // Deeplink providers don't specify template type (will use backward compatibility logic)
auto_query_interval: request.usage_auto_interval,
};

View File

@@ -98,6 +98,10 @@ pub struct UsageScript {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "userId")]
pub user_id: Option<String>,
/// 模板类型(用于后端判断验证规则)
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "templateType")]
pub template_type: Option<String>,
/// 自动查询间隔单位分钟0 表示禁用自动查询)
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "autoQueryInterval")]

View File

@@ -615,6 +615,7 @@ impl ProviderService {
base_url: Option<&str>,
access_token: Option<&str>,
user_id: Option<&str>,
template_type: Option<&str>,
) -> Result<UsageResult, AppError> {
usage::test_usage_script(
state,
@@ -626,6 +627,7 @@ impl ProviderService {
base_url,
access_token,
user_id,
template_type,
)
.await
}

View File

@@ -17,6 +17,7 @@ pub(crate) async fn execute_and_format_usage_result(
timeout: u64,
access_token: Option<&str>,
user_id: Option<&str>,
template_type: Option<&str>,
) -> Result<UsageResult, AppError> {
match usage_script::execute_usage_script(
script_code,
@@ -25,6 +26,7 @@ pub(crate) async fn execute_and_format_usage_result(
timeout,
access_token,
user_id,
template_type,
)
.await
{
@@ -113,7 +115,7 @@ pub async fn query_usage(
app_type: AppType,
provider_id: &str,
) -> Result<UsageResult, AppError> {
let (script_code, timeout, api_key, base_url, access_token, user_id) = {
let (script_code, timeout, api_key, base_url, access_token, user_id, template_type) = {
let providers = state.db.get_all_providers(app_type.as_str())?;
let provider = providers.get(provider_id).ok_or_else(|| {
AppError::localized(
@@ -164,6 +166,7 @@ pub async fn query_usage(
base_url,
usage_script.access_token.clone(),
usage_script.user_id.clone(),
usage_script.template_type.clone(),
)
};
@@ -174,6 +177,7 @@ pub async fn query_usage(
timeout,
access_token.as_deref(),
user_id.as_deref(),
template_type.as_deref(),
)
.await
}
@@ -190,6 +194,7 @@ pub async fn test_usage_script(
base_url: Option<&str>,
access_token: Option<&str>,
user_id: Option<&str>,
template_type: Option<&str>,
) -> Result<UsageResult, AppError> {
// Use provided credential parameters directly for testing
execute_and_format_usage_result(
@@ -199,6 +204,7 @@ pub async fn test_usage_script(
timeout,
access_token,
user_id,
template_type,
)
.await
}

View File

@@ -13,13 +13,21 @@ pub async fn execute_usage_script(
timeout_secs: u64,
access_token: Option<&str>,
user_id: Option<&str>,
template_type: Option<&str>,
) -> Result<Value, AppError> {
// 检测是否为自定义模板模式
// 优先使用前端传递的 template_type
let is_custom_template = template_type.map(|t| t == "custom").unwrap_or(false);
// 1. 替换模板变量,避免泄露敏感信息
let script_with_vars =
build_script_with_vars(script_code, api_key, base_url, access_token, user_id);
// 2. 验证 base_url 的安全性
validate_base_url(base_url)?;
// 2. 验证 base_url 的安全性(仅当提供了 base_url 时)
// 自定义模板模式下,用户可能不使用模板变量,而是直接在脚本中写完整 URL
if !base_url.is_empty() {
validate_base_url(base_url)?;
}
// 3. 在独立作用域中提取 request 配置(确保 Runtime/Context 在 await 前释放)
let request_config = {
@@ -97,7 +105,8 @@ pub async fn execute_usage_script(
})?;
// 5. 验证请求 URL 是否安全(防止 SSRF
validate_request_url(&request.url, base_url)?;
// 如果提供了 base_url则验证同源否则只做基本安全检查
validate_request_url(&request.url, base_url, is_custom_template)?;
// 6. 发送 HTTP 请求
let response_data = send_http_request(&request, timeout_secs).await?;
@@ -472,7 +481,11 @@ fn validate_base_url(base_url: &str) -> Result<(), AppError> {
}
/// 验证请求 URL 是否安全(防止 SSRF
fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppError> {
fn validate_request_url(
request_url: &str,
base_url: &str,
is_custom_template: bool,
) -> Result<(), AppError> {
// 解析请求 URL
let parsed_request = Url::parse(request_url).map_err(|e| {
AppError::localized(
@@ -482,19 +495,11 @@ fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppErro
)
})?;
// 解析 base URL
let parsed_base = Url::parse(base_url).map_err(|e| {
AppError::localized(
"usage_script.base_url_invalid",
format!("无效的 base_url: {e}"),
format!("Invalid base_url: {e}"),
)
})?;
let is_request_loopback = is_loopback_host(&parsed_request);
// 必须使用 HTTPS允许 localhost 用于开发)
if parsed_request.scheme() != "https" && !is_request_loopback {
// 自定义模板模式下,允许用户自行决定是否使用 HTTP用户需自行承担安全风险
if !is_custom_template && parsed_request.scheme() != "https" && !is_request_loopback {
return Err(AppError::localized(
"usage_script.request_https_required",
"请求 URL 必须使用 HTTPS 协议localhost 除外)",
@@ -502,60 +507,85 @@ fn validate_request_url(request_url: &str, base_url: &str) -> Result<(), AppErro
));
}
// 核心安全检查:必须与 base_url 同源(相同域名和端口)
if parsed_request.host_str() != parsed_base.host_str() {
return Err(AppError::localized(
"usage_script.request_host_mismatch",
format!(
"请求域名 {} 与 base_url 域名 {} 不匹配(必须是同源请求)",
parsed_request.host_str().unwrap_or("unknown"),
parsed_base.host_str().unwrap_or("unknown")
),
format!(
"Request host {} must match base_url host {} (same-origin required)",
parsed_request.host_str().unwrap_or("unknown"),
parsed_base.host_str().unwrap_or("unknown")
),
));
}
// 如果提供了 base_url非空则进行同源检查
// 🔧 自定义模板模式下,用户可以自由访问任意 HTTPS 域名,跳过同源检查
if !base_url.is_empty() && !is_custom_template {
// 解析 base URL
let parsed_base = Url::parse(base_url).map_err(|e| {
AppError::localized(
"usage_script.base_url_invalid",
format!("无效的 base_url: {e}"),
format!("Invalid base_url: {e}"),
)
})?;
// 检查端口是否匹配(考虑默认端口)
// 使用 port_or_known_default() 会自动处理默认端口http->80, https->443
match (
parsed_request.port_or_known_default(),
parsed_base.port_or_known_default(),
) {
(Some(request_port), Some(base_port)) if request_port == base_port => {
// 端口匹配,继续执行
}
(Some(request_port), Some(base_port)) => {
// 核心安全检查:必须与 base_url 同源(相同域名和端口)
if parsed_request.host_str() != parsed_base.host_str() {
return Err(AppError::localized(
"usage_script.request_port_mismatch",
format!("请求端口 {request_port} 必须与 base_url 端口 {base_port} 匹配"),
format!("Request port {request_port} must match base_url port {base_port}"),
"usage_script.request_host_mismatch",
format!(
"请求域名 {} 与 base_url 域名 {} 不匹配(必须是同源请求)",
parsed_request.host_str().unwrap_or("unknown"),
parsed_base.host_str().unwrap_or("unknown")
),
format!(
"Request host {} must match base_url host {} (same-origin required)",
parsed_request.host_str().unwrap_or("unknown"),
parsed_base.host_str().unwrap_or("unknown")
),
));
}
_ => {
// 理论上不会发生,因为 port_or_known_default() 应该总是返回 Some
return Err(AppError::localized(
"usage_script.request_port_unknown",
"无法确定端口号",
"Unable to determine port number",
));
// 检查端口是否匹配(考虑默认端口)
// 使用 port_or_known_default() 会自动处理默认端口http->80, https->443
match (
parsed_request.port_or_known_default(),
parsed_base.port_or_known_default(),
) {
(Some(request_port), Some(base_port)) if request_port == base_port => {
// 端口匹配,继续执行
}
(Some(request_port), Some(base_port)) => {
return Err(AppError::localized(
"usage_script.request_port_mismatch",
format!("请求端口 {request_port} 必须与 base_url 端口 {base_port} 匹配"),
format!("Request port {request_port} must match base_url port {base_port}"),
));
}
_ => {
// 理论上不会发生,因为 port_or_known_default() 应该总是返回 Some
return Err(AppError::localized(
"usage_script.request_port_unknown",
"无法确定端口号",
"Unable to determine port number",
));
}
}
}
// 禁止私有 IP 地址访问(除非 base_url 本身就是私有地址,用于开发环境)
if let Some(host) = parsed_request.host_str() {
let base_host = parsed_base.host_str().unwrap_or("");
// 禁止私有 IP 地址访问(除非 base_url 本身就是私有地址,用于开发环境)
if let Some(host) = parsed_request.host_str() {
let base_host = parsed_base.host_str().unwrap_or("");
// 如果 base_url 不是私有地址则禁止访问私有IP
if !is_private_ip(base_host) && is_private_ip(host) {
return Err(AppError::localized(
"usage_script.private_ip_blocked",
"禁止访问私有 IP 地址",
"Access to private IP addresses is blocked",
));
// 如果 base_url 不是私有地址则禁止访问私有IP
if !is_private_ip(base_host) && is_private_ip(host) {
return Err(AppError::localized(
"usage_script.private_ip_blocked",
"禁止访问私有 IP 地址",
"Access to private IP addresses is blocked",
));
}
}
} else {
// 自定义模板模式:没有 base_url需要额外的安全检查
// 禁止访问私有 IP 地址SSRF 防护)
if let Some(host) = parsed_request.host_str() {
if is_private_ip(host) && !is_request_loopback {
return Err(AppError::localized(
"usage_script.private_ip_blocked",
"禁止访问私有 IP 地址localhost 除外)",
"Access to private IP addresses is blocked (localhost allowed)",
));
}
}
}
@@ -843,7 +873,7 @@ mod tests {
];
for (base_url, request_url, should_match) in test_cases {
let result = validate_request_url(request_url, base_url);
let result = validate_request_url(request_url, base_url, false);
if should_match {
assert!(

View File

@@ -2,8 +2,10 @@ import React, { useState } from "react";
import { Play, Wand2, Eye, EyeOff, Save } from "lucide-react";
import { toast } from "sonner";
import { useTranslation } from "react-i18next";
import { useQueryClient } from "@tanstack/react-query";
import { Provider, UsageScript, UsageData } from "@/types";
import { usageApi, type AppId } from "@/lib/api";
import { extractCodexBaseUrl } from "@/utils/providerConfigUtils";
import JsonEditor from "./JsonEditor";
import * as prettier from "prettier/standalone";
import * as parserBabel from "prettier/parser-babel";
@@ -109,19 +111,67 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
onSave,
}) => {
const { t } = useTranslation();
const queryClient = useQueryClient();
// 生成带国际化的预设模板
const PRESET_TEMPLATES = generatePresetTemplates(t);
const [script, setScript] = useState<UsageScript>(() => {
return (
provider.meta?.usage_script || {
enabled: false,
language: "javascript",
code: PRESET_TEMPLATES[TEMPLATE_KEYS.GENERAL],
timeout: 10,
// 从 provider 的 settingsConfig 中提取 API Key 和 Base URL
const getProviderCredentials = (): {
apiKey: string | undefined;
baseUrl: string | undefined;
} => {
try {
const config = provider.settingsConfig;
if (!config) return { apiKey: undefined, baseUrl: undefined };
// 处理不同应用的配置格式
if (appId === "claude") {
// Claude: { env: { ANTHROPIC_AUTH_TOKEN | ANTHROPIC_API_KEY, ANTHROPIC_BASE_URL } }
const env = (config as any).env || {};
return {
apiKey: env.ANTHROPIC_AUTH_TOKEN || env.ANTHROPIC_API_KEY,
baseUrl: env.ANTHROPIC_BASE_URL,
};
} else if (appId === "codex") {
// Codex: { auth: { OPENAI_API_KEY }, config: TOML string with base_url }
const auth = (config as any).auth || {};
const configToml = (config as any).config || "";
return {
apiKey: auth.OPENAI_API_KEY,
baseUrl: extractCodexBaseUrl(configToml),
};
} else if (appId === "gemini") {
// Gemini: { env: { GEMINI_API_KEY, GOOGLE_GEMINI_BASE_URL } }
const env = (config as any).env || {};
return {
apiKey: env.GEMINI_API_KEY,
baseUrl: env.GOOGLE_GEMINI_BASE_URL,
};
}
);
return { apiKey: undefined, baseUrl: undefined };
} catch (error) {
console.error("Failed to extract provider credentials:", error);
return { apiKey: undefined, baseUrl: undefined };
}
};
const providerCredentials = getProviderCredentials();
const [script, setScript] = useState<UsageScript>(() => {
const savedScript = provider.meta?.usage_script;
const defaultScript = {
enabled: false,
language: "javascript" as const,
code: PRESET_TEMPLATES[TEMPLATE_KEYS.GENERAL],
timeout: 10,
};
if (!savedScript) {
return defaultScript;
}
return savedScript;
});
const [testing, setTesting] = useState(false);
@@ -176,6 +226,11 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
const [selectedTemplate, setSelectedTemplate] = useState<string | null>(
() => {
const existingScript = provider.meta?.usage_script;
// 优先使用保存的 templateType
if (existingScript?.templateType) {
return existingScript.templateType;
}
// 向后兼容:根据字段推断模板类型
// 检测 NEW_API 模板(有 accessToken 或 userId
if (existingScript?.accessToken || existingScript?.userId) {
return TEMPLATE_KEYS.NEW_API;
@@ -201,7 +256,16 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
toast.error(t("usageScript.mustHaveReturn"), { duration: 5000 });
return;
}
onSave(script);
// 保存时记录当前选择的模板类型
const scriptWithTemplate = {
...script,
templateType: selectedTemplate as
| "custom"
| "general"
| "newapi"
| undefined,
};
onSave(scriptWithTemplate);
onClose();
};
@@ -217,6 +281,7 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
script.baseUrl,
script.accessToken,
script.userId,
selectedTemplate as "custom" | "general" | "newapi" | undefined,
);
if (result.success && result.data && result.data.length > 0) {
const summary = result.data
@@ -229,6 +294,9 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
duration: 3000,
closeButton: true,
});
// 🔧 测试成功后,更新主界面列表的用量查询缓存
queryClient.setQueryData(["usage", provider.id, appId], result);
} else {
toast.error(
`${t("usageScript.testFailed")}: ${result.error || t("endpointTest.noResult")}`,
@@ -278,9 +346,13 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
const preset = PRESET_TEMPLATES[presetName];
if (preset) {
if (presetName === TEMPLATE_KEYS.CUSTOM) {
// 🔧 自定义模式:用户应该在脚本中直接写完整 URL 和凭证,而不是依赖变量替换
// 这样可以避免同源检查导致的问题
// 如果用户想使用变量,需要手动在配置中设置 baseUrl/apiKey
setScript({
...script,
code: preset,
// 清除凭证,用户可选择手动输入或保持空
apiKey: undefined,
baseUrl: undefined,
accessToken: undefined,
@@ -401,6 +473,74 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
})}
</div>
{/* 自定义模式:变量提示和具体值 */}
{selectedTemplate === TEMPLATE_KEYS.CUSTOM && (
<div className="space-y-2 border-t border-white/10 pt-3">
<h4 className="text-sm font-medium text-foreground">
{t("usageScript.supportedVariables")}
</h4>
<div className="space-y-1 text-xs">
{/* baseUrl */}
<div className="flex items-center gap-2 py-1">
<code className="text-emerald-500 dark:text-emerald-400 font-mono shrink-0">
{"{{baseUrl}}"}
</code>
<span className="text-muted-foreground/50">=</span>
{providerCredentials.baseUrl ? (
<code className="text-foreground/70 break-all font-mono">
{providerCredentials.baseUrl}
</code>
) : (
<span className="text-muted-foreground/50 italic">
{t("common.notSet") || "未设置"}
</span>
)}
</div>
{/* apiKey */}
<div className="flex items-center gap-2 py-1">
<code className="text-emerald-500 dark:text-emerald-400 font-mono shrink-0">
{"{{apiKey}}"}
</code>
<span className="text-muted-foreground/50">=</span>
{providerCredentials.apiKey ? (
<>
{showApiKey ? (
<code className="text-foreground/70 break-all font-mono">
{providerCredentials.apiKey}
</code>
) : (
<code className="text-foreground/70 font-mono">
</code>
)}
<button
type="button"
onClick={() => setShowApiKey(!showApiKey)}
className="text-muted-foreground hover:text-foreground transition-colors ml-1"
aria-label={
showApiKey
? t("apiKeyInput.hide")
: t("apiKeyInput.show")
}
>
{showApiKey ? (
<EyeOff size={12} />
) : (
<Eye size={12} />
)}
</button>
</>
) : (
<span className="text-muted-foreground/50 italic">
{t("common.notSet") || "未设置"}
</span>
)}
</div>
</div>
</div>
)}
{/* 凭证配置 */}
{shouldShowCredentialsConfig && (
<div className="space-y-4">
@@ -601,11 +741,13 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
type="number"
min={0}
max={1440}
value={script.autoIntervalMinutes ?? 0}
value={
script.autoQueryInterval ?? script.autoIntervalMinutes ?? 0
}
onChange={(e) =>
setScript({
...script,
autoIntervalMinutes: validateAndClampInterval(
autoQueryInterval: validateAndClampInterval(
e.target.value,
),
})
@@ -613,7 +755,7 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
onBlur={(e) =>
setScript({
...script,
autoIntervalMinutes: validateAndClampInterval(
autoQueryInterval: validateAndClampInterval(
e.target.value,
),
})

View File

@@ -115,6 +115,11 @@ export function useProviderActions(activeApp: AppId) {
await queryClient.invalidateQueries({
queryKey: ["providers", activeApp],
});
// 🔧 保存用量脚本后,也应该失效该 provider 的用量查询缓存
// 这样主页列表会使用新配置重新查询,而不是使用测试时的缓存
await queryClient.invalidateQueries({
queryKey: ["usage", provider.id, activeApp],
});
toast.success(
t("provider.usageSaved", {
defaultValue: "用量查询配置已保存",

View File

@@ -584,6 +584,7 @@
"testFailed": "Test failed",
"formatSuccess": "Format successful",
"formatFailed": "Format failed",
"supportedVariables": "Supported Variables",
"variablesHint": "Supported variables: {{apiKey}}, {{baseUrl}} | extractor function receives API response JSON object",
"scriptConfig": "Request configuration",
"extractorCode": "Extractor code",

View File

@@ -584,6 +584,7 @@
"testFailed": "テストに失敗しました",
"formatSuccess": "整形に成功しました",
"formatFailed": "整形に失敗しました",
"supportedVariables": "使用可能な変数",
"variablesHint": "使用可能な変数: {{apiKey}}, {{baseUrl}} | extractor 関数には API 応答の JSON オブジェクトが渡されます",
"scriptConfig": "リクエスト設定",
"extractorCode": "抽出コード",

View File

@@ -584,6 +584,7 @@
"testFailed": "测试失败",
"formatSuccess": "格式化成功",
"formatFailed": "格式化失败",
"supportedVariables": "支持的变量",
"variablesHint": "支持变量: {{apiKey}}, {{baseUrl}} | extractor 函数接收 API 响应的 JSON 对象",
"scriptConfig": "请求配置",
"extractorCode": "提取器代码",

View File

@@ -28,6 +28,7 @@ export const usageApi = {
baseUrl?: string,
accessToken?: string,
userId?: string,
templateType?: "custom" | "general" | "newapi",
): Promise<UsageResult> => {
return invoke("testUsageScript", {
providerId,
@@ -38,6 +39,7 @@ export const usageApi = {
baseUrl,
accessToken,
userId,
templateType,
});
},

View File

@@ -52,6 +52,7 @@ export interface UsageScript {
language: "javascript"; // 脚本语言
code: string; // 脚本代码JSON 格式配置)
timeout?: number; // 超时时间(秒,默认 10
templateType?: "custom" | "general" | "newapi"; // 模板类型(用于后端判断验证规则)
apiKey?: string; // 用量查询专用的 API Key通用模板使用
baseUrl?: string; // 用量查询专用的 Base URL通用和 NewAPI 模板使用)
accessToken?: string; // 访问令牌NewAPI 模板使用)