feat: add official balance query for DeepSeek, StepFun, SiliconFlow, OpenRouter, Novita AI

Add a new "Official" (官方) template type in the usage query panel that
queries account balance via each provider's native API endpoint.
Follows the same zero-script pattern as Token Plan — Rust handles the
HTTP call, frontend auto-detects the provider from base URL.

Supported providers and endpoints:
- DeepSeek: GET /user/balance
- StepFun: GET /v1/accounts
- SiliconFlow: GET /v1/user/info (cn + com)
- OpenRouter: GET /api/v1/credits
- Novita AI: GET /v3/user/balance
This commit is contained in:
Jason
2026-04-05 16:54:51 +08:00
parent 24555275bb
commit f1fb3351c1
12 changed files with 561 additions and 2 deletions
+6
View File
@@ -0,0 +1,6 @@
use crate::provider::UsageResult;
#[tauri::command]
pub async fn get_balance(base_url: String, api_key: String) -> Result<UsageResult, String> {
crate::services::balance::get_balance(&base_url, &api_key).await
}
+2
View File
@@ -1,6 +1,7 @@
#![allow(non_snake_case)]
mod auth;
mod balance;
mod coding_plan;
mod config;
mod copilot;
@@ -31,6 +32,7 @@ mod webdav_sync;
mod workspace;
pub use auth::*;
pub use balance::*;
pub use coding_plan::*;
pub use config::*;
pub use copilot::*;
+25
View File
@@ -14,6 +14,7 @@ use std::str::FromStr;
// 常量定义
const TEMPLATE_TYPE_GITHUB_COPILOT: &str = "github_copilot";
const TEMPLATE_TYPE_TOKEN_PLAN: &str = "token_plan";
const TEMPLATE_TYPE_BALANCE: &str = "balance";
const COPILOT_UNIT_PREMIUM: &str = "requests";
/// 获取所有供应商
@@ -268,6 +269,30 @@ pub async fn queryProviderUsage(
});
}
// ── 官方余额查询路径 ──
if template_type == TEMPLATE_TYPE_BALANCE {
let settings_config = provider
.map(|p| &p.settings_config)
.cloned()
.unwrap_or_default();
let env = settings_config.get("env");
let base_url = env
.and_then(|e| e.get("ANTHROPIC_BASE_URL"))
.and_then(|v| v.as_str())
.unwrap_or("");
let api_key = env
.and_then(|e| {
e.get("ANTHROPIC_AUTH_TOKEN")
.or_else(|| e.get("ANTHROPIC_API_KEY"))
})
.and_then(|v| v.as_str())
.unwrap_or("");
return crate::services::balance::get_balance(base_url, api_key)
.await
.map_err(|e| format!("Failed to query balance: {e}"));
}
// ── 通用 JS 脚本路径 ──
ProviderService::query_usage(state.inner(), app_type, &providerId)
.await
+1
View File
@@ -893,6 +893,7 @@ pub fn run() {
// subscription quota
commands::get_subscription_quota,
commands::get_coding_plan_quota,
commands::get_balance,
// New MCP via config.json (SSOT)
commands::get_mcp_config,
commands::upsert_mcp_server_in_config,
+411
View File
@@ -0,0 +1,411 @@
//! 供应商余额查询服务
//!
//! 支持 DeepSeek、StepFun、SiliconFlow、OpenRouter、Novita AI 的账户余额查询。
//! 返回 UsageResult 格式,与现有用量系统无缝对接。
use crate::provider::{UsageData, UsageResult};
use std::time::Duration;
// ── 供应商检测 ──────────────────────────────────────────────
enum BalanceProvider {
DeepSeek,
StepFun,
SiliconFlow,
SiliconFlowEn,
OpenRouter,
NovitaAI,
}
fn detect_provider(base_url: &str) -> Option<BalanceProvider> {
let url = base_url.to_lowercase();
if url.contains("api.deepseek.com") {
Some(BalanceProvider::DeepSeek)
} else if url.contains("api.stepfun.ai") || url.contains("api.stepfun.com") {
Some(BalanceProvider::StepFun)
} else if url.contains("api.siliconflow.cn") {
Some(BalanceProvider::SiliconFlow)
} else if url.contains("api.siliconflow.com") {
Some(BalanceProvider::SiliconFlowEn)
} else if url.contains("openrouter.ai") {
Some(BalanceProvider::OpenRouter)
} else if url.contains("api.novita.ai") {
Some(BalanceProvider::NovitaAI)
} else {
None
}
}
fn make_error(msg: String) -> UsageResult {
UsageResult {
success: false,
data: None,
error: Some(msg),
}
}
fn make_auth_error(status: reqwest::StatusCode) -> UsageResult {
UsageResult {
success: false,
data: Some(vec![UsageData {
plan_name: None,
remaining: None,
total: None,
used: None,
unit: None,
is_valid: Some(false),
invalid_message: Some(format!("Authentication failed (HTTP {status})")),
extra: None,
}]),
error: Some(format!("Authentication failed (HTTP {status})")),
}
}
// ── DeepSeek ────────────────────────────────────────────────
// GET https://api.deepseek.com/user/balance
// Response: { balance_infos: [{ currency, total_balance, granted_balance, topped_up_balance }], is_available }
async fn query_deepseek(api_key: &str) -> UsageResult {
let client = crate::proxy::http_client::get();
let resp = client
.get("https://api.deepseek.com/user/balance")
.header("Authorization", format!("Bearer {api_key}"))
.header("Accept", "application/json")
.timeout(Duration::from_secs(10))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => return make_error(format!("Network error: {e}")),
};
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return make_auth_error(status);
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return make_error(format!("API error (HTTP {status}): {body}"));
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => return make_error(format!("Failed to parse response: {e}")),
};
let is_available = body
.get("is_available")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let mut data = Vec::new();
if let Some(infos) = body.get("balance_infos").and_then(|v| v.as_array()) {
for info in infos {
let currency = info
.get("currency")
.and_then(|v| v.as_str())
.unwrap_or("CNY");
let total = parse_f64_field(info, "total_balance");
data.push(UsageData {
plan_name: Some(currency.to_string()),
remaining: total,
total: None,
used: None,
unit: Some(currency.to_string()),
is_valid: Some(is_available),
invalid_message: if !is_available {
Some("Insufficient balance".to_string())
} else {
None
},
extra: None,
});
}
}
UsageResult {
success: true,
data: if data.is_empty() { None } else { Some(data) },
error: None,
}
}
// ── StepFun ─────────────────────────────────────────────────
// GET https://api.stepfun.com/v1/accounts
// Response: { object, type, balance, total_cash_balance, total_voucher_balance }
async fn query_stepfun(api_key: &str) -> UsageResult {
let client = crate::proxy::http_client::get();
let resp = client
.get("https://api.stepfun.com/v1/accounts")
.header("Authorization", format!("Bearer {api_key}"))
.header("Accept", "application/json")
.timeout(Duration::from_secs(10))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => return make_error(format!("Network error: {e}")),
};
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return make_auth_error(status);
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return make_error(format!("API error (HTTP {status}): {body}"));
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => return make_error(format!("Failed to parse response: {e}")),
};
let balance = parse_f64_field(&body, "balance").unwrap_or(0.0);
UsageResult {
success: true,
data: Some(vec![UsageData {
plan_name: Some("StepFun".to_string()),
remaining: Some(balance),
total: None,
used: None,
unit: Some("CNY".to_string()),
is_valid: Some(true),
invalid_message: None,
extra: None,
}]),
error: None,
}
}
// ── SiliconFlow ─────────────────────────────────────────────
// GET https://api.siliconflow.cn/v1/user/info (or .com for EN)
// Response: { code, data: { balance, chargeBalance, totalBalance, status } }
async fn query_siliconflow(api_key: &str, is_cn: bool) -> UsageResult {
let client = crate::proxy::http_client::get();
let domain = if is_cn {
"api.siliconflow.cn"
} else {
"api.siliconflow.com"
};
let url = format!("https://{domain}/v1/user/info");
let resp = client
.get(&url)
.header("Authorization", format!("Bearer {api_key}"))
.header("Accept", "application/json")
.timeout(Duration::from_secs(10))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => return make_error(format!("Network error: {e}")),
};
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return make_auth_error(status);
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return make_error(format!("API error (HTTP {status}): {body}"));
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => return make_error(format!("Failed to parse response: {e}")),
};
let data = match body.get("data") {
Some(d) => d,
None => return make_error("Missing 'data' field in response".to_string()),
};
let total_balance = parse_f64_field(data, "totalBalance").unwrap_or(0.0);
UsageResult {
success: true,
data: Some(vec![UsageData {
plan_name: Some("SiliconFlow".to_string()),
remaining: Some(total_balance),
total: None,
used: None,
unit: Some("CNY".to_string()),
is_valid: Some(true),
invalid_message: None,
extra: None,
}]),
error: None,
}
}
// ── OpenRouter ──────────────────────────────────────────────
// GET https://openrouter.ai/api/v1/credits
// Response: { data: { total_credits, total_usage } }
async fn query_openrouter(api_key: &str) -> UsageResult {
let client = crate::proxy::http_client::get();
let resp = client
.get("https://openrouter.ai/api/v1/credits")
.header("Authorization", format!("Bearer {api_key}"))
.header("Accept", "application/json")
.timeout(Duration::from_secs(10))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => return make_error(format!("Network error: {e}")),
};
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return make_auth_error(status);
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return make_error(format!("API error (HTTP {status}): {body}"));
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => return make_error(format!("Failed to parse response: {e}")),
};
let data = body.get("data").unwrap_or(&body);
let total_credits = parse_f64_field(data, "total_credits").unwrap_or(0.0);
let total_usage = parse_f64_field(data, "total_usage").unwrap_or(0.0);
let remaining = total_credits - total_usage;
UsageResult {
success: true,
data: Some(vec![UsageData {
plan_name: Some("OpenRouter".to_string()),
remaining: Some(remaining),
total: Some(total_credits),
used: Some(total_usage),
unit: Some("USD".to_string()),
is_valid: Some(remaining > 0.0),
invalid_message: if remaining <= 0.0 {
Some("No credits remaining".to_string())
} else {
None
},
extra: None,
}]),
error: None,
}
}
// ── Novita AI ───────────────────────────────────────────────
// GET https://api.novita.ai/v3/user/balance
// Response: { availableBalance, cashBalance, creditLimit, outstandingInvoices }
// 金额单位:0.0001 USD
async fn query_novita(api_key: &str) -> UsageResult {
let client = crate::proxy::http_client::get();
let resp = client
.get("https://api.novita.ai/v3/user/balance")
.header("Authorization", format!("Bearer {api_key}"))
.header("Accept", "application/json")
.timeout(Duration::from_secs(10))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => return make_error(format!("Network error: {e}")),
};
let status = resp.status();
if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
return make_auth_error(status);
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return make_error(format!("API error (HTTP {status}): {body}"));
}
let body: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => return make_error(format!("Failed to parse response: {e}")),
};
// Novita 金额单位为 0.0001 USD,需除以 10000 转为 USD
let available = parse_f64_field(&body, "availableBalance").unwrap_or(0.0) / 10000.0;
UsageResult {
success: true,
data: Some(vec![UsageData {
plan_name: Some("Novita AI".to_string()),
remaining: Some(available),
total: None,
used: None,
unit: Some("USD".to_string()),
is_valid: Some(available > 0.0),
invalid_message: if available <= 0.0 {
Some("No balance remaining".to_string())
} else {
None
},
extra: None,
}]),
error: None,
}
}
// ── 工具函数 ────────────────────────────────────────────────
/// 解析 JSON 字段为 f64,兼容数字和字符串格式
fn parse_f64_field(obj: &serde_json::Value, field: &str) -> Option<f64> {
obj.get(field).and_then(|v| {
v.as_f64()
.or_else(|| v.as_str().and_then(|s| s.parse().ok()))
})
}
// ── 公开入口 ────────────────────────────────────────────────
pub async fn get_balance(base_url: &str, api_key: &str) -> Result<UsageResult, String> {
if api_key.trim().is_empty() {
return Ok(UsageResult {
success: false,
data: None,
error: Some("API key is empty".to_string()),
});
}
let provider = match detect_provider(base_url) {
Some(p) => p,
None => {
return Ok(UsageResult {
success: false,
data: None,
error: Some("Unknown balance provider".to_string()),
})
}
};
let result = match provider {
BalanceProvider::DeepSeek => query_deepseek(api_key).await,
BalanceProvider::StepFun => query_stepfun(api_key).await,
BalanceProvider::SiliconFlow => query_siliconflow(api_key, true).await,
BalanceProvider::SiliconFlowEn => query_siliconflow(api_key, false).await,
BalanceProvider::OpenRouter => query_openrouter(api_key).await,
BalanceProvider::NovitaAI => query_novita(api_key).await,
};
Ok(result)
}
+1
View File
@@ -1,3 +1,4 @@
pub mod balance;
pub mod coding_plan;
pub mod config;
pub mod env_checker;
+103 -2
View File
@@ -98,6 +98,9 @@ const generatePresetTemplates = (
// Coding Plan 模板不需要脚本,使用专用 Rust 查询
[TEMPLATE_TYPES.TOKEN_PLAN]: "",
// 官方余额查询模板不需要脚本,使用专用 Rust 查询
[TEMPLATE_TYPES.BALANCE]: "",
});
// 模板名称国际化键映射
@@ -107,6 +110,7 @@ const TEMPLATE_NAME_KEYS: Record<string, string> = {
[TEMPLATE_TYPES.NEW_API]: "usageScript.templateNewAPI",
[TEMPLATE_TYPES.GITHUB_COPILOT]: "usageScript.templateCopilot",
[TEMPLATE_TYPES.TOKEN_PLAN]: "usageScript.templateTokenPlan",
[TEMPLATE_TYPES.BALANCE]: "usageScript.templateBalance",
};
/** Coding Plan 供应商选项 */
@@ -124,6 +128,25 @@ const TOKEN_PLAN_PROVIDERS = [
},
] as const;
/** 官方余额查询供应商检测 */
const BALANCE_PROVIDERS = [
{ id: "deepseek", label: "DeepSeek", pattern: /api\.deepseek\.com/i },
{ id: "stepfun", label: "StepFun", pattern: /api\.stepfun\.(ai|com)/i },
{
id: "siliconflow",
label: "SiliconFlow",
pattern: /api\.siliconflow\.(cn|com)/i,
},
{ id: "openrouter", label: "OpenRouter", pattern: /openrouter\.ai/i },
{ id: "novita", label: "Novita AI", pattern: /api\.novita\.ai/i },
] as const;
/** 根据 Base URL 自动检测余额查询供应商 */
function detectBalanceProvider(baseUrl: string | undefined): boolean {
if (!baseUrl) return false;
return BALANCE_PROVIDERS.some((bp) => bp.pattern.test(baseUrl));
}
/** 根据 Base URL 自动检测 Coding Plan 供应商 */
function detectTokenPlanProvider(baseUrl: string | undefined): string | null {
if (!baseUrl) return null;
@@ -219,6 +242,16 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
};
}
// 新配置:如果 URL 匹配官方余额查询供应商,自动初始化
if (detectBalanceProvider(providerCredentials.baseUrl)) {
return {
enabled: false,
language: "javascript" as const,
code: "",
timeout: 10,
};
}
return {
enabled: false,
language: "javascript" as const,
@@ -300,6 +333,10 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
if (detectTokenPlanProvider(providerCredentials.baseUrl)) {
return TEMPLATE_TYPES.TOKEN_PLAN;
}
// 新配置:如果 URL 匹配官方余额查询供应商,自动选择 Balance 模板
if (detectBalanceProvider(providerCredentials.baseUrl)) {
return TEMPLATE_TYPES.BALANCE;
}
// 默认使用 GENERAL(与默认代码模板一致)
return TEMPLATE_TYPES.GENERAL;
},
@@ -331,10 +368,11 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
};
const handleSave = () => {
// CopilotCoding Plan 模板不需要脚本验证
// CopilotCoding Plan、Balance 模板不需要脚本验证
if (
selectedTemplate !== TEMPLATE_TYPES.GITHUB_COPILOT &&
selectedTemplate !== TEMPLATE_TYPES.TOKEN_PLAN
selectedTemplate !== TEMPLATE_TYPES.TOKEN_PLAN &&
selectedTemplate !== TEMPLATE_TYPES.BALANCE
) {
if (script.enabled && !script.code.trim()) {
toast.error(t("usageScript.scriptEmpty"));
@@ -354,6 +392,7 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
| "newapi"
| "github_copilot"
| "token_plan"
| "balance"
| undefined,
};
onSave(scriptWithTemplate);
@@ -363,6 +402,37 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
const handleTest = async () => {
setTesting(true);
try {
// 官方余额查询模板使用专用 API
if (selectedTemplate === TEMPLATE_TYPES.BALANCE) {
const config = provider.settingsConfig as Record<string, any>;
const baseUrl: string = config?.env?.ANTHROPIC_BASE_URL ?? "";
const apiKey: string =
config?.env?.ANTHROPIC_AUTH_TOKEN ??
config?.env?.ANTHROPIC_API_KEY ??
"";
const { subscriptionApi } = await import("@/lib/api/subscription");
const result = await subscriptionApi.getBalance(baseUrl, apiKey);
if (result.success && result.data && result.data.length > 0) {
const summary = result.data
.map((d) => {
const name = d.planName ? `[${d.planName}] ` : "";
return `${name}${t("usage.remaining")} ${d.remaining?.toFixed(2)} ${d.unit || ""}`;
})
.join(", ");
toast.success(`${t("usageScript.testSuccess")}${summary}`, {
duration: 3000,
closeButton: true,
});
queryClient.setQueryData(["usage", provider.id, appId], result);
} else {
toast.error(
`${t("usageScript.testFailed")}: ${result.error || t("endpointTest.noResult")}`,
{ duration: 5000 },
);
}
return;
}
// Coding Plan 模板使用专用 API
if (selectedTemplate === TEMPLATE_TYPES.TOKEN_PLAN) {
const config = provider.settingsConfig as Record<string, any>;
@@ -558,6 +628,16 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
codingPlanProvider:
script.codingPlanProvider || autoDetected || "kimi",
});
} else if (presetName === TEMPLATE_TYPES.BALANCE) {
// 官方余额查询模板不需要脚本,使用 Rust 原生查询
setScript({
...script,
code: "",
apiKey: undefined,
baseUrl: undefined,
accessToken: undefined,
userId: undefined,
});
}
setSelectedTemplate(presetName);
}
@@ -746,6 +826,27 @@ const UsageScriptModal: React.FC<UsageScriptModalProps> = ({
</div>
)}
{/* 官方余额查询模式:自动提示 */}
{selectedTemplate === TEMPLATE_TYPES.BALANCE && (
<div className="space-y-3 border-t border-white/10 pt-3">
<p className="text-sm text-muted-foreground">
{t("usageScript.balanceHint")}
</p>
<div className="flex gap-2 flex-wrap">
{BALANCE_PROVIDERS.filter((bp) =>
bp.pattern.test(providerCredentials.baseUrl || ""),
).map((bp) => (
<span
key={bp.id}
className="inline-flex items-center px-2.5 py-1 rounded-md bg-primary/10 text-primary text-xs font-medium"
>
{bp.label}
</span>
))}
</div>
</div>
)}
{/* Coding Plan 模式:供应商选择 */}
{selectedTemplate === TEMPLATE_TYPES.TOKEN_PLAN && (
<div className="space-y-3 border-t border-white/10 pt-3">
+1
View File
@@ -10,6 +10,7 @@ export const TEMPLATE_TYPES = {
NEW_API: "newapi",
GITHUB_COPILOT: "github_copilot",
TOKEN_PLAN: "token_plan",
BALANCE: "balance",
} as const;
export type TemplateType = (typeof TEMPLATE_TYPES)[keyof typeof TEMPLATE_TYPES];
+2
View File
@@ -1091,8 +1091,10 @@
"templateNewAPI": "NewAPI",
"templateCopilot": "GitHub Copilot",
"templateTokenPlan": "Token Plan",
"templateBalance": "Official",
"copilotAutoAuth": "Auto OAuth authentication, no manual credentials needed",
"tokenPlanHint": "Automatically uses the provider's API Key and Base URL to query Token Plan quota",
"balanceHint": "Automatically uses the provider's API Key to query account balance",
"resetDate": "Reset date",
"premiumRequests": "Premium Requests",
"credentialsConfig": "Credentials",
+2
View File
@@ -1091,8 +1091,10 @@
"templateNewAPI": "NewAPI",
"templateCopilot": "GitHub Copilot",
"templateTokenPlan": "Token Plan",
"templateBalance": "公式",
"copilotAutoAuth": "OAuth 認証を自動使用、手動設定不要",
"tokenPlanHint": "プロバイダーのAPI KeyとBase URLを使用してToken Planクォータを自動クエリ",
"balanceHint": "プロバイダーのAPI Keyを使用してアカウント残高を自動クエリ",
"resetDate": "リセット日",
"premiumRequests": "Premium リクエスト",
"credentialsConfig": "認証情報",
+2
View File
@@ -1091,8 +1091,10 @@
"templateNewAPI": "NewAPI",
"templateCopilot": "GitHub Copilot",
"templateTokenPlan": "Token Plan",
"templateBalance": "官方",
"copilotAutoAuth": "自动使用 OAuth 认证,无需手动配置凭证",
"tokenPlanHint": "自动使用供应商的 API Key 和 Base URL 查询 Token Plan 额度",
"balanceHint": "自动使用供应商的 API Key 查询账户余额",
"resetDate": "重置日期",
"premiumRequests": "Premium 请求",
"credentialsConfig": "凭证配置",
+5
View File
@@ -9,4 +9,9 @@ export const subscriptionApi = {
apiKey: string,
): Promise<SubscriptionQuota> =>
invoke("get_coding_plan_quota", { baseUrl, apiKey }),
getBalance: (
baseUrl: string,
apiKey: string,
): Promise<import("@/types").UsageResult> =>
invoke("get_balance", { baseUrl, apiKey }),
};