mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-05-16 01:28:55 +08:00
feat(copilot): add GitHub Copilot reverse proxy support (#930)
* refactor(toolsearch): replace binary patch with ENABLE_TOOL_SEARCH env var toggle
- Remove toolsearch_patch.rs binary patching mechanism (~590 lines)
- Delete `toolsearch_patch.rs` and `commands/toolsearch.rs`
- Remove auto-patch startup logic and command registration from lib.rs
- Remove `tool_search_bypass` field from settings.rs
- Remove frontend settings ToggleRow, useSettings hook sync logic, and API methods
- Clean up zh/en/ja i18n keys (notifications + settings)
- Add ENABLE_TOOL_SEARCH toggle to Claude provider form
- Add checkbox in CommonConfigEditor.tsx (alongside teammates toggle)
- When enabled, writes `"env": { "ENABLE_TOOL_SEARCH": "true" }`
- When disabled, removes the key; takes effect on provider switch
- Add zh/en/ja i18n key: `claudeConfig.enableToolSearch`
Claude Code 2.1.76+ natively supports this env var, eliminating the need for binary patching.
* feat(claude): add effortLevel high toggle to provider form
- Add "high-effort thinking" checkbox to Claude provider config form
- When checked, writes `"effortLevel": "high"`; when unchecked, removes the field
- Add zh/en/ja i18n translations
* refactor(claude): remove deprecated alwaysThinking toggle
- Claude Code now enables extended thinking by default; alwaysThinkingEnabled is a no-op
- Thinking control is now handled via effortLevel (added in prior commit)
- Remove state, switch case, and checkbox UI from CommonConfigEditor
- Clean up alwaysThinking i18n keys across zh/en/ja locales
* feat(opencode): add setCacheKey: true to all provider presets
- Add setCacheKey: true to options in all 33 regular presets
- Add setCacheKey: true to OPENCODE_DEFAULT_CONFIG for custom providers
- Exclude 2 OMO presets (Oh My OpenCode / Slim) which have their own config mechanism
Closes #1523
* fix(codex): resolve 1M context window toggle causing MCP editor flicker
- Add localValueRef to short-circuit duplicate CodeMirror updateListener callbacks,
breaking the React state → CodeMirror → stale onChange → React state feedback loop
- Use localValueRef.current in handleContextWindowToggle and handleCompactLimitChange
to avoid stale closure reads
- Change compact limit input from type="number" to type="text" with inputMode="numeric"
to remove unnecessary spinner buttons
* feat(codex): add 1M context window toggle utilities and i18n keys
- Add extractCodexTopLevelInt, setCodexTopLevelInt, removeCodexTopLevelField
TOML helpers in providerConfigUtils.ts
- Add i18n keys for contextWindow1M, autoCompactLimit in zh/en/ja locales
* feat(claude): collapse model mapping fields by default
- Wrap 5 model mapping inputs in a Collapsible, collapsed by default
- Auto-expand when any model value is present (including preset-filled)
- Show hint text when collapsed explaining most users need no config
- Add zh/en/ja i18n keys for toggle label and collapsed hint
- Use variant={null} to avoid ghost button hover style clash in dark mode
* feat(claude): merge advanced fields into single collapsible section
- Merge API format, auth field, and model mapping into a unified "Advanced Options" collapsible
- Extend smart-expand logic to detect non-default values across all advanced fields
- Preserve model mapping sub-header and hint with a separator line
- Update zh/en/ja i18n keys (advancedOptionsToggle, advancedOptionsHint, modelMappingLabel, modelMappingHint)
* feat(copilot): add GitHub Copilot reverse proxy support
Add GitHub Copilot as a Claude provider variant with OAuth device code
authentication and Anthropic ↔ OpenAI format transformation.
Backend:
- Add CopilotAuthManager for GitHub OAuth device code flow
- Implement Copilot token auto-refresh (60s before expiry)
- Persist GitHub token to ~/.cc-switch/copilot_auth.json
- Add ProviderType::GitHubCopilot and AuthStrategy::GitHubCopilot
- Modify forwarder to use /chat/completions for Copilot
- Add Copilot-specific headers (Editor-Version, Editor-Plugin-Version)
Frontend:
- Add CopilotAuthSection component for OAuth UI
- Add useCopilotAuth hook for OAuth state management
- Auto-copy user code to clipboard and open browser
- Use 8-second polling interval to avoid GitHub rate limits
- Skip API Key validation for Copilot providers
- Add GitHub Copilot preset with claude-sonnet-4 model
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
* fix(copilot): remove is_expired() calls from tests
Remove references to deleted is_expired() method in test code.
Only is_expiring_soon() is needed for token refresh logic.
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
* feat(copilot): add real-time model listing from Copilot API
- Add fetch_models() to CopilotAuthManager calling GET /models endpoint
- Add copilot_get_models Tauri command
- Add copilotGetModels() frontend API wrapper
- Modify ClaudeFormFields to show model dropdown for Copilot providers
- Fetches available models on component mount when isCopilotPreset
- Groups models by vendor (Anthropic, OpenAI, Google, etc.)
- Input + dropdown button combo allows both manual entry and selection
- Non-Copilot providers keep original plain Input behavior
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* feat(copilot): add usage query integration
- Add Copilot usage API integration (fetch_usage method)
- Add copilot_get_usage Tauri command
- Add GitHub Copilot template in usage query modal
- Unify naming: copilot → github_copilot
- Add constants management (TEMPLATE_TYPES, PROVIDER_TYPES)
- Improve error handling with detailed error messages
- Add database migration (v5 → v6) for template type update
- Add i18n translations (zh, en, ja)
- Improve type safety with TemplateType
- Apply code formatting (cargo fmt, prettier)
* 修复github 登录和注销问题 ,模型选择问题
* feat(copilot): add multi-account support for GitHub Copilot
- Add multi-account storage structure with v1 to v2 migration
- Add per-account token caching and auto-refresh
- Add new Tauri commands for account management
- Integrate account selection in Proxy forwarder
- Add account selection UI in CopilotAuthSection
- Save githubAccountId to ProviderMeta
- Add i18n translations for multi-account features (zh/en/ja)
* 修复用量查询Reset字段出现多余字符
* refactor(auth-binding): introduce generic provider auth binding primitives
- add shared authBinding types in Rust and TypeScript while keeping githubAccountId as a compatibility field\n- resolve Copilot token, models, and usage through provider-bound account lookup instead of only the implicit default account\n- fix the Unix build regression in settings.rs by restoring std::io::Write for write_all()\n- remove the accidental .github ignore entry and drop leftover Copilot form debug logs\n- keep the first migration step non-breaking by writing both authBinding and the legacy githubAccountId field from the form
* refactor(auth-service): add managed auth command surface and explicit default account state
- introduce generic managed auth commands and frontend auth API wrappers for provider-scoped login, status, account listing, removal, logout, and default-account selection\n- store an explicit Copilot default_account_id instead of relying on HashMap iteration order, and use it consistently for fallback token/model/usage resolution\n- sort managed accounts deterministically and surface default-account state to the UI\n- refactor the Copilot form hook to wrap a generic useManagedAuth implementation while preserving the existing component contract\n- add default-account controls to the Copilot auth section and extend Copilot auth status serialization/tests for the new state
* feat(auth-center): add a dedicated settings entrypoint for managed OAuth accounts
- add an Auth Center tab to Settings so managed OAuth accounts are no longer hidden inside individual provider forms\n- introduce a first AuthCenterPanel that hosts GitHub Copilot account management as the initial managed auth provider\n- keep the provider form experience intact while establishing a global account-management surface for future providers such as OpenAI\n- validate that the new settings tab works cleanly with the generic managed auth hook and existing Copilot account controls
* feat(add-provider): expose managed OAuth sources alongside universal providers
- add an OAuth tab to the Add Provider flow so managed auth sources sit beside app-specific and universal providers\n- reuse the new Auth Center panel inside the dialog, keeping account management discoverable during provider creation\n- make the dialog footer adapt to the OAuth tab so account setup does not pretend to create a provider directly\n- align the add-provider UX with the new architecture where OAuth accounts are global assets and providers bind to them later
* fix(auth-reliability): harden managed auth persistence and refresh behavior
- replace direct Copilot auth store writes with private temp-file writes and atomic rename semantics, and document the local token storage limitation\n- add per-account refresh locks plus a double-check path so concurrent requests do not stampede GitHub token refresh\n- surface legacy migration failures through auth status, expose them in the UI, and add translated copy for the new account-state labels\n- stop writing the legacy githubAccountId field from the provider form while keeping compatibility reads in place\n- add logout error recovery and Copilot model-load toasts so auth failures are no longer silently swallowed
* refactor(copilot-detection): prefer provider type before URL fallbacks
- update forwarder endpoint rewriting to treat providerType as the primary GitHub Copilot signal\n- keep githubcopilot.com string matching only as a compatibility fallback for older provider records without providerType\n- reduce one more path where Copilot behavior depended purely on URL heuristics
* fix(copilot-auth): add cancel button to error state in CopilotAuthSection
- 错误状态下仅有"重试"按钮,用户无法退出(如不可恢复的 403 未订阅错误)
- 新增"取消"按钮,复用已有的 cancelAuth 逻辑重置为 idle 状态
* 修复打包后github账号头像显示异常
* 修复github copilot 来源的模型测试报错
* feat(copilot-preset): add default model presets for GitHub Copilot
- 补充 Copilot 预设的默认模型配置,用户选完预设即可直接使用
- ANTHROPIC_MODEL: claude-opus-4.6
- ANTHROPIC_DEFAULT_HAIKU_MODEL: claude-haiku-4.5
- ANTHROPIC_DEFAULT_SONNET_MODEL: claude-sonnet-4.6
- ANTHROPIC_DEFAULT_OPUS_MODEL: claude-opus-4.6
---------
Co-authored-by: Jason <farion1231@gmail.com>
Co-authored-by: 周梦泽 <mengze.zhou@dafeng-tech.com>
Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
use tauri::State;
|
||||
|
||||
use crate::commands::copilot::CopilotAuthState;
|
||||
use crate::proxy::providers::copilot_auth::{GitHubAccount, GitHubDeviceCodeResponse};
|
||||
|
||||
const AUTH_PROVIDER_GITHUB_COPILOT: &str = "github_copilot";
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct ManagedAuthAccount {
|
||||
pub id: String,
|
||||
pub provider: String,
|
||||
pub login: String,
|
||||
pub avatar_url: Option<String>,
|
||||
pub authenticated_at: i64,
|
||||
pub is_default: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct ManagedAuthStatus {
|
||||
pub provider: String,
|
||||
pub authenticated: bool,
|
||||
pub default_account_id: Option<String>,
|
||||
pub migration_error: Option<String>,
|
||||
pub accounts: Vec<ManagedAuthAccount>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct ManagedAuthDeviceCodeResponse {
|
||||
pub provider: String,
|
||||
pub device_code: String,
|
||||
pub user_code: String,
|
||||
pub verification_uri: String,
|
||||
pub expires_in: u64,
|
||||
pub interval: u64,
|
||||
}
|
||||
|
||||
fn ensure_auth_provider(auth_provider: &str) -> Result<&str, String> {
|
||||
match auth_provider {
|
||||
AUTH_PROVIDER_GITHUB_COPILOT => Ok(AUTH_PROVIDER_GITHUB_COPILOT),
|
||||
_ => Err(format!("Unsupported auth provider: {auth_provider}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_account(
|
||||
provider: &str,
|
||||
account: GitHubAccount,
|
||||
default_account_id: Option<&str>,
|
||||
) -> ManagedAuthAccount {
|
||||
ManagedAuthAccount {
|
||||
is_default: default_account_id == Some(account.id.as_str()),
|
||||
id: account.id,
|
||||
provider: provider.to_string(),
|
||||
login: account.login,
|
||||
avatar_url: account.avatar_url,
|
||||
authenticated_at: account.authenticated_at,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_device_code_response(
|
||||
provider: &str,
|
||||
response: GitHubDeviceCodeResponse,
|
||||
) -> ManagedAuthDeviceCodeResponse {
|
||||
ManagedAuthDeviceCodeResponse {
|
||||
provider: provider.to_string(),
|
||||
device_code: response.device_code,
|
||||
user_code: response.user_code,
|
||||
verification_uri: response.verification_uri,
|
||||
expires_in: response.expires_in,
|
||||
interval: response.interval,
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_start_login(
|
||||
auth_provider: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<ManagedAuthDeviceCodeResponse, String> {
|
||||
let auth_provider = ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.read().await;
|
||||
let response = auth_manager
|
||||
.start_device_flow()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(map_device_code_response(auth_provider, response))
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_poll_for_account(
|
||||
auth_provider: String,
|
||||
device_code: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Option<ManagedAuthAccount>, String> {
|
||||
let auth_provider = ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.write().await;
|
||||
match auth_manager.poll_for_token(&device_code).await {
|
||||
Ok(account) => {
|
||||
let default_account_id = auth_manager.get_status().await.default_account_id;
|
||||
Ok(account
|
||||
.map(|account| map_account(auth_provider, account, default_account_id.as_deref())))
|
||||
}
|
||||
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
|
||||
Ok(None)
|
||||
}
|
||||
Err(e) => Err(e.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_list_accounts(
|
||||
auth_provider: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Vec<ManagedAuthAccount>, String> {
|
||||
let auth_provider = ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.read().await;
|
||||
let status = auth_manager.get_status().await;
|
||||
let default_account_id = status.default_account_id.clone();
|
||||
Ok(status
|
||||
.accounts
|
||||
.into_iter()
|
||||
.map(|account| map_account(auth_provider, account, default_account_id.as_deref()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_get_status(
|
||||
auth_provider: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<ManagedAuthStatus, String> {
|
||||
let auth_provider = ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.read().await;
|
||||
let status = auth_manager.get_status().await;
|
||||
let default_account_id = status.default_account_id.clone();
|
||||
Ok(ManagedAuthStatus {
|
||||
provider: auth_provider.to_string(),
|
||||
authenticated: status.authenticated,
|
||||
default_account_id: default_account_id.clone(),
|
||||
migration_error: status.migration_error,
|
||||
accounts: status
|
||||
.accounts
|
||||
.into_iter()
|
||||
.map(|account| map_account(auth_provider, account, default_account_id.as_deref()))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_remove_account(
|
||||
auth_provider: String,
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<(), String> {
|
||||
ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager
|
||||
.remove_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_set_default_account(
|
||||
auth_provider: String,
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<(), String> {
|
||||
ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager
|
||||
.set_default_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn auth_logout(
|
||||
auth_provider: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<(), String> {
|
||||
ensure_auth_provider(&auth_provider)?;
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager.clear_auth().await.map_err(|e| e.to_string())
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
//! GitHub Copilot Tauri Commands
|
||||
//!
|
||||
//! 提供 Copilot OAuth 认证相关的 Tauri 命令,支持多账号管理。
|
||||
|
||||
use crate::proxy::providers::copilot_auth::{
|
||||
CopilotAuthManager, CopilotAuthStatus, CopilotModel, CopilotUsageResponse, GitHubAccount,
|
||||
GitHubDeviceCodeResponse,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tauri::State;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Copilot 认证状态
|
||||
pub struct CopilotAuthState(pub Arc<RwLock<CopilotAuthManager>>);
|
||||
|
||||
// ==================== 设备码流程 ====================
|
||||
|
||||
/// 启动设备码流程
|
||||
///
|
||||
/// 返回设备码和用户码,用于 OAuth 认证
|
||||
#[tauri::command]
|
||||
pub async fn copilot_start_device_flow(
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<GitHubDeviceCodeResponse, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager
|
||||
.start_device_flow()
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 轮询 OAuth Token(向后兼容)
|
||||
///
|
||||
/// 使用设备码轮询 GitHub,等待用户完成授权
|
||||
/// 返回 true 表示授权成功,false 表示等待中
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_poll_for_auth(
|
||||
device_code: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<bool, String> {
|
||||
let auth_manager = state.0.write().await;
|
||||
match auth_manager.poll_for_token(&device_code).await {
|
||||
Ok(Some(_account)) => {
|
||||
log::info!("[CopilotAuth] 用户已授权");
|
||||
Ok(true)
|
||||
}
|
||||
Ok(None) => Ok(false),
|
||||
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
|
||||
Ok(false)
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("[CopilotAuth] 轮询失败: {}", e);
|
||||
Err(e.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 轮询 OAuth Token(多账号版本)
|
||||
///
|
||||
/// 返回新添加的账号信息,如果授权成功
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_poll_for_account(
|
||||
device_code: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Option<GitHubAccount>, String> {
|
||||
let auth_manager = state.0.write().await;
|
||||
match auth_manager.poll_for_token(&device_code).await {
|
||||
Ok(account) => Ok(account),
|
||||
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
|
||||
Ok(None)
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("[CopilotAuth] 轮询失败: {}", e);
|
||||
Err(e.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 多账号管理 ====================
|
||||
|
||||
/// 列出所有已认证的账号
|
||||
#[tauri::command]
|
||||
pub async fn copilot_list_accounts(
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Vec<GitHubAccount>, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
Ok(auth_manager.list_accounts().await)
|
||||
}
|
||||
|
||||
/// 移除指定账号
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_remove_account(
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<(), String> {
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager
|
||||
.remove_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 设置默认账号
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_set_default_account(
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<(), String> {
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager
|
||||
.set_default_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ==================== 状态查询 ====================
|
||||
|
||||
/// 获取认证状态(包含所有账号)
|
||||
#[tauri::command]
|
||||
pub async fn copilot_get_auth_status(
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<CopilotAuthStatus, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
Ok(auth_manager.get_status().await)
|
||||
}
|
||||
|
||||
/// 检查是否已认证(有任意账号)
|
||||
#[tauri::command]
|
||||
pub async fn copilot_is_authenticated(state: State<'_, CopilotAuthState>) -> Result<bool, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
Ok(auth_manager.is_authenticated().await)
|
||||
}
|
||||
|
||||
/// 注销所有 Copilot 认证
|
||||
#[tauri::command]
|
||||
pub async fn copilot_logout(state: State<'_, CopilotAuthState>) -> Result<(), String> {
|
||||
let auth_manager = state.0.write().await;
|
||||
auth_manager.clear_auth().await.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ==================== Token 获取 ====================
|
||||
|
||||
/// 获取有效的 Copilot Token(向后兼容:使用第一个账号)
|
||||
///
|
||||
/// 内部使用,用于代理请求
|
||||
#[tauri::command]
|
||||
pub async fn copilot_get_token(state: State<'_, CopilotAuthState>) -> Result<String, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager
|
||||
.get_valid_token()
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 获取指定账号的有效 Copilot Token
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_get_token_for_account(
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<String, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager
|
||||
.get_valid_token_for_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ==================== 模型和使用量 ====================
|
||||
|
||||
/// 获取 Copilot 可用模型列表(向后兼容:使用第一个账号)
|
||||
#[tauri::command]
|
||||
pub async fn copilot_get_models(
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Vec<CopilotModel>, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager.fetch_models().await.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 获取指定账号的 Copilot 可用模型列表
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_get_models_for_account(
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<Vec<CopilotModel>, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager
|
||||
.fetch_models_for_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 获取 Copilot 使用量信息(向后兼容:使用第一个账号)
|
||||
#[tauri::command]
|
||||
pub async fn copilot_get_usage(
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<CopilotUsageResponse, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager.fetch_usage().await.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// 获取指定账号的 Copilot 使用量信息
|
||||
#[tauri::command(rename_all = "camelCase")]
|
||||
pub async fn copilot_get_usage_for_account(
|
||||
account_id: String,
|
||||
state: State<'_, CopilotAuthState>,
|
||||
) -> Result<CopilotUsageResponse, String> {
|
||||
let auth_manager = state.0.read().await;
|
||||
auth_manager
|
||||
.fetch_usage_for_account(&account_id)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
mod auth;
|
||||
mod config;
|
||||
mod copilot;
|
||||
mod deeplink;
|
||||
mod env;
|
||||
mod failover;
|
||||
@@ -19,12 +21,14 @@ mod settings;
|
||||
pub mod skill;
|
||||
mod stream_check;
|
||||
mod sync_support;
|
||||
mod toolsearch;
|
||||
|
||||
mod usage;
|
||||
mod webdav_sync;
|
||||
mod workspace;
|
||||
|
||||
pub use auth::*;
|
||||
pub use config::*;
|
||||
pub use copilot::*;
|
||||
pub use deeplink::*;
|
||||
pub use env::*;
|
||||
pub use failover::*;
|
||||
@@ -42,7 +46,7 @@ pub use session_manager::*;
|
||||
pub use settings::*;
|
||||
pub use skill::*;
|
||||
pub use stream_check::*;
|
||||
pub use toolsearch::*;
|
||||
|
||||
pub use usage::*;
|
||||
pub use webdav_sync::*;
|
||||
pub use workspace::*;
|
||||
|
||||
@@ -2,6 +2,7 @@ use indexmap::IndexMap;
|
||||
use tauri::State;
|
||||
|
||||
use crate::app_config::AppType;
|
||||
use crate::commands::copilot::CopilotAuthState;
|
||||
use crate::error::AppError;
|
||||
use crate::provider::Provider;
|
||||
use crate::services::{
|
||||
@@ -10,6 +11,11 @@ use crate::services::{
|
||||
use crate::store::AppState;
|
||||
use std::str::FromStr;
|
||||
|
||||
// 常量定义
|
||||
const TEMPLATE_TYPE_GITHUB_COPILOT: &str = "github_copilot";
|
||||
const COPILOT_UNIT_PREMIUM: &str = "requests";
|
||||
|
||||
/// 获取所有供应商
|
||||
#[tauri::command]
|
||||
pub fn get_providers(
|
||||
state: State<'_, AppState>,
|
||||
@@ -142,10 +148,65 @@ pub fn import_default_config(state: State<'_, AppState>, app: String) -> Result<
|
||||
#[tauri::command]
|
||||
pub async fn queryProviderUsage(
|
||||
state: State<'_, AppState>,
|
||||
copilot_state: State<'_, CopilotAuthState>,
|
||||
#[allow(non_snake_case)] providerId: String, // 使用 camelCase 匹配前端
|
||||
app: String,
|
||||
) -> Result<crate::provider::UsageResult, String> {
|
||||
let app_type = AppType::from_str(&app).map_err(|e| e.to_string())?;
|
||||
|
||||
// 检查是否为 GitHub Copilot 模板类型,并解析绑定账号
|
||||
let (is_copilot_template, copilot_account_id) = {
|
||||
let providers = state
|
||||
.db
|
||||
.get_all_providers(app_type.as_str())
|
||||
.map_err(|e| format!("Failed to get providers: {}", e))?;
|
||||
|
||||
let provider = providers.get(&providerId);
|
||||
let is_copilot = provider
|
||||
.and_then(|p| p.meta.as_ref())
|
||||
.and_then(|m| m.usage_script.as_ref())
|
||||
.and_then(|s| s.template_type.as_ref())
|
||||
.map(|t| t == TEMPLATE_TYPE_GITHUB_COPILOT)
|
||||
.unwrap_or(false);
|
||||
let account_id = provider
|
||||
.and_then(|p| p.meta.as_ref())
|
||||
.and_then(|m| m.managed_account_id_for(TEMPLATE_TYPE_GITHUB_COPILOT));
|
||||
|
||||
(is_copilot, account_id)
|
||||
};
|
||||
|
||||
if is_copilot_template {
|
||||
// 使用 Copilot 专用 API
|
||||
let auth_manager = copilot_state.0.read().await;
|
||||
let usage = match copilot_account_id.as_deref() {
|
||||
Some(account_id) => auth_manager
|
||||
.fetch_usage_for_account(account_id)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch Copilot usage: {}", e))?,
|
||||
None => auth_manager
|
||||
.fetch_usage()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch Copilot usage: {}", e))?,
|
||||
};
|
||||
let premium = &usage.quota_snapshots.premium_interactions;
|
||||
let used = premium.entitlement - premium.remaining;
|
||||
|
||||
return Ok(crate::provider::UsageResult {
|
||||
success: true,
|
||||
data: Some(vec![crate::provider::UsageData {
|
||||
plan_name: Some(usage.copilot_plan),
|
||||
remaining: Some(premium.remaining as f64),
|
||||
total: Some(premium.entitlement as f64),
|
||||
used: Some(used as f64),
|
||||
unit: Some(COPILOT_UNIT_PREMIUM.to_string()),
|
||||
is_valid: Some(true),
|
||||
invalid_message: None,
|
||||
extra: Some(format!("Reset: {}", usage.quota_reset_date)),
|
||||
}]),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
ProviderService::query_usage(state.inner(), app_type, &providerId)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! 流式健康检查命令
|
||||
|
||||
use crate::app_config::AppType;
|
||||
use crate::commands::copilot::CopilotAuthState;
|
||||
use crate::error::AppError;
|
||||
use crate::services::stream_check::{
|
||||
HealthStatus, StreamCheckConfig, StreamCheckResult, StreamCheckService,
|
||||
@@ -13,6 +14,7 @@ use tauri::State;
|
||||
#[tauri::command]
|
||||
pub async fn stream_check_provider(
|
||||
state: State<'_, AppState>,
|
||||
copilot_state: State<'_, CopilotAuthState>,
|
||||
app_type: AppType,
|
||||
provider_id: String,
|
||||
) -> Result<StreamCheckResult, AppError> {
|
||||
@@ -23,7 +25,9 @@ pub async fn stream_check_provider(
|
||||
.get(&provider_id)
|
||||
.ok_or_else(|| AppError::Message(format!("供应商 {provider_id} 不存在")))?;
|
||||
|
||||
let result = StreamCheckService::check_with_retry(&app_type, provider, &config).await?;
|
||||
let auth_override = resolve_copilot_auth_override(provider, &copilot_state).await?;
|
||||
let result =
|
||||
StreamCheckService::check_with_retry(&app_type, provider, &config, auth_override).await?;
|
||||
|
||||
// 记录日志
|
||||
let _ =
|
||||
@@ -38,6 +42,7 @@ pub async fn stream_check_provider(
|
||||
#[tauri::command]
|
||||
pub async fn stream_check_all_providers(
|
||||
state: State<'_, AppState>,
|
||||
copilot_state: State<'_, CopilotAuthState>,
|
||||
app_type: AppType,
|
||||
proxy_targets_only: bool,
|
||||
) -> Result<Vec<(String, StreamCheckResult)>, AppError> {
|
||||
@@ -67,18 +72,20 @@ pub async fn stream_check_all_providers(
|
||||
}
|
||||
}
|
||||
|
||||
let result = StreamCheckService::check_with_retry(&app_type, &provider, &config)
|
||||
.await
|
||||
.unwrap_or_else(|e| StreamCheckResult {
|
||||
status: HealthStatus::Failed,
|
||||
success: false,
|
||||
message: e.to_string(),
|
||||
response_time_ms: None,
|
||||
http_status: None,
|
||||
model_used: String::new(),
|
||||
tested_at: chrono::Utc::now().timestamp(),
|
||||
retry_count: 0,
|
||||
});
|
||||
let auth_override = resolve_copilot_auth_override(&provider, &copilot_state).await?;
|
||||
let result =
|
||||
StreamCheckService::check_with_retry(&app_type, &provider, &config, auth_override)
|
||||
.await
|
||||
.unwrap_or_else(|e| StreamCheckResult {
|
||||
status: HealthStatus::Failed,
|
||||
success: false,
|
||||
message: e.to_string(),
|
||||
response_time_ms: None,
|
||||
http_status: None,
|
||||
model_used: String::new(),
|
||||
tested_at: chrono::Utc::now().timestamp(),
|
||||
retry_count: 0,
|
||||
});
|
||||
|
||||
let _ = state
|
||||
.db
|
||||
@@ -104,3 +111,46 @@ pub fn save_stream_check_config(
|
||||
) -> Result<(), AppError> {
|
||||
state.db.save_stream_check_config(&config)
|
||||
}
|
||||
|
||||
async fn resolve_copilot_auth_override(
|
||||
provider: &crate::provider::Provider,
|
||||
copilot_state: &State<'_, CopilotAuthState>,
|
||||
) -> Result<Option<crate::proxy::providers::AuthInfo>, AppError> {
|
||||
let is_copilot = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| provider
|
||||
.settings_config
|
||||
.pointer("/env/ANTHROPIC_BASE_URL")
|
||||
.and_then(|value| value.as_str())
|
||||
.map(|url| url.contains("githubcopilot.com"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_copilot {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let auth_manager = copilot_state.0.read().await;
|
||||
let account_id = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|meta| meta.github_account_id.clone());
|
||||
|
||||
let token = match account_id.as_deref() {
|
||||
Some(id) => auth_manager
|
||||
.get_valid_token_for_account(id)
|
||||
.await
|
||||
.map_err(|e| AppError::Message(format!("GitHub Copilot 认证失败: {e}")))?,
|
||||
None => auth_manager
|
||||
.get_valid_token()
|
||||
.await
|
||||
.map_err(|e| AppError::Message(format!("GitHub Copilot 认证失败: {e}")))?,
|
||||
};
|
||||
|
||||
Ok(Some(crate::proxy::providers::AuthInfo::new(
|
||||
token,
|
||||
crate::proxy::providers::AuthStrategy::GitHubCopilot,
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
/// Check Tool Search patch status for the active Claude Code installation
|
||||
#[tauri::command]
|
||||
pub async fn check_toolsearch_status() -> Result<crate::toolsearch_patch::ToolSearchStatus, String>
|
||||
{
|
||||
crate::toolsearch_patch::check_toolsearch_status().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Apply Tool Search patch (bypass domain restriction) to the active installation
|
||||
#[tauri::command]
|
||||
pub async fn apply_toolsearch_patch() -> Result<Vec<crate::toolsearch_patch::PatchResult>, String> {
|
||||
crate::toolsearch_patch::apply_toolsearch_patch().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Restore Tool Search patch (re-enable domain restriction) for the active installation
|
||||
#[tauri::command]
|
||||
pub async fn restore_toolsearch_patch() -> Result<Vec<crate::toolsearch_patch::PatchResult>, String>
|
||||
{
|
||||
crate::toolsearch_patch::restore_toolsearch_patch().map_err(|e| e.to_string())
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
use super::{lock_conn, Database, SCHEMA_VERSION};
|
||||
use crate::error::AppError;
|
||||
use rusqlite::Connection;
|
||||
use rusqlite::{params, Connection};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -389,7 +389,7 @@ impl Database {
|
||||
Self::set_user_version(conn, 5)?;
|
||||
}
|
||||
5 => {
|
||||
log::info!("迁移数据库从 v5 到 v6(使用量聚合表)");
|
||||
log::info!("迁移数据库从 v5 到 v6(使用量聚合表 + Copilot 模板类型统一)");
|
||||
Self::migrate_v5_to_v6(conn)?;
|
||||
Self::set_user_version(conn, 6)?;
|
||||
}
|
||||
@@ -970,8 +970,9 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// v5 -> v6 迁移:添加使用量日聚合表
|
||||
/// v5 -> v6 迁移:添加使用量日聚合表 + 统一 Copilot 模板类型
|
||||
fn migrate_v5_to_v6(conn: &Connection) -> Result<(), AppError> {
|
||||
// 1. 添加使用量日聚合表
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS usage_daily_rollups (
|
||||
date TEXT NOT NULL,
|
||||
@@ -992,7 +993,55 @@ impl Database {
|
||||
)
|
||||
.map_err(|e| AppError::Database(format!("创建 usage_daily_rollups 表失败: {e}")))?;
|
||||
|
||||
log::info!("v5 -> v6 迁移完成:已添加使用量日聚合表");
|
||||
// 2. 统一 Copilot 模板类型为 github_copilot
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT id, app_type, meta FROM providers")
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
let rows = stmt
|
||||
.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
))
|
||||
})
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
let mut updates = Vec::new();
|
||||
for row in rows {
|
||||
let (id, app_type, meta_str) = row.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
|
||||
if let Ok(mut meta) = serde_json::from_str::<serde_json::Value>(&meta_str) {
|
||||
let mut updated = false;
|
||||
|
||||
if let Some(usage_script) = meta.get_mut("usage_script") {
|
||||
if let Some(template_type) = usage_script.get_mut("template_type") {
|
||||
if template_type == "copilot" {
|
||||
*template_type =
|
||||
serde_json::Value::String("github_copilot".to_string());
|
||||
updated = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updated {
|
||||
let new_meta_str = serde_json::to_string(&meta)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
updates.push((id, app_type, new_meta_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (id, app_type, new_meta) in updates {
|
||||
conn.execute(
|
||||
"UPDATE providers SET meta = ?1 WHERE id = ?2 AND app_type = ?3",
|
||||
params![new_meta, id, app_type],
|
||||
)
|
||||
.map_err(|e| AppError::Database(e.to_string()))?;
|
||||
}
|
||||
|
||||
log::info!("v5 -> v6 迁移完成:已添加使用量日聚合表,统一 copilot 模板类型");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
+38
-25
@@ -25,7 +25,7 @@ mod services;
|
||||
mod session_manager;
|
||||
mod settings;
|
||||
mod store;
|
||||
mod toolsearch_patch;
|
||||
|
||||
mod tray;
|
||||
mod usage_script;
|
||||
|
||||
@@ -690,6 +690,18 @@ pub fn run() {
|
||||
let skill_service = SkillService::new();
|
||||
app.manage(commands::skill::SkillServiceState(Arc::new(skill_service)));
|
||||
|
||||
// 初始化 CopilotAuthManager
|
||||
{
|
||||
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
|
||||
use commands::CopilotAuthState;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
let app_config_dir = crate::config::get_app_config_dir();
|
||||
let copilot_auth_manager = CopilotAuthManager::new(app_config_dir);
|
||||
app.manage(CopilotAuthState(Arc::new(RwLock::new(copilot_auth_manager))));
|
||||
log::info!("✓ CopilotAuthManager initialized");
|
||||
}
|
||||
|
||||
// 初始化全局出站代理 HTTP 客户端
|
||||
{
|
||||
let db = &app.state::<AppState>().db;
|
||||
@@ -806,26 +818,6 @@ pub fn run() {
|
||||
}
|
||||
}
|
||||
|
||||
// Tool Search bypass: auto-apply patch on startup if enabled
|
||||
if settings.tool_search_bypass {
|
||||
match crate::toolsearch_patch::apply_toolsearch_patch() {
|
||||
Ok(results) => {
|
||||
let success = results.iter().filter(|r| r.success).count();
|
||||
let total = results.len();
|
||||
if success > 0 {
|
||||
log::info!("✓ Tool Search patch auto-applied ({success}/{total})");
|
||||
}
|
||||
for r in results.iter().filter(|r| !r.success) {
|
||||
if let Some(err) = &r.error {
|
||||
log::warn!("✗ Tool Search patch failed for {}: {err}", r.path);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("✗ Tool Search auto-patch skipped: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
@@ -873,10 +865,6 @@ pub fn run() {
|
||||
commands::is_claude_plugin_applied,
|
||||
commands::apply_claude_onboarding_skip,
|
||||
commands::clear_claude_onboarding_skip,
|
||||
// Tool Search patch
|
||||
commands::check_toolsearch_status,
|
||||
commands::apply_toolsearch_patch,
|
||||
commands::restore_toolsearch_patch,
|
||||
// Claude MCP management
|
||||
commands::get_claude_mcp_status,
|
||||
commands::read_claude_mcp_config,
|
||||
@@ -1056,6 +1044,31 @@ pub fn run() {
|
||||
commands::scan_local_proxies,
|
||||
// Window theme control
|
||||
commands::set_window_theme,
|
||||
// Generic managed auth commands
|
||||
commands::auth_start_login,
|
||||
commands::auth_poll_for_account,
|
||||
commands::auth_list_accounts,
|
||||
commands::auth_get_status,
|
||||
commands::auth_remove_account,
|
||||
commands::auth_set_default_account,
|
||||
commands::auth_logout,
|
||||
// Copilot OAuth commands (multi-account support)
|
||||
commands::copilot_start_device_flow,
|
||||
commands::copilot_poll_for_auth,
|
||||
commands::copilot_poll_for_account,
|
||||
commands::copilot_list_accounts,
|
||||
commands::copilot_remove_account,
|
||||
commands::copilot_set_default_account,
|
||||
commands::copilot_get_auth_status,
|
||||
commands::copilot_logout,
|
||||
commands::copilot_is_authenticated,
|
||||
commands::copilot_get_token,
|
||||
commands::copilot_get_token_for_account,
|
||||
commands::copilot_get_models,
|
||||
commands::copilot_get_models_for_account,
|
||||
commands::copilot_get_usage,
|
||||
commands::copilot_get_usage_for_account,
|
||||
// OMO commands
|
||||
commands::read_omo_local_file,
|
||||
commands::get_current_omo_provider_id,
|
||||
commands::disable_current_omo,
|
||||
|
||||
@@ -191,6 +191,31 @@ pub struct ProviderProxyConfig {
|
||||
pub proxy_password: Option<String>,
|
||||
}
|
||||
|
||||
/// 认证绑定来源
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AuthBindingSource {
|
||||
/// 从 provider 自身配置读取认证信息(默认)
|
||||
#[default]
|
||||
ProviderConfig,
|
||||
/// 使用托管账号认证(如 GitHub Copilot OAuth)
|
||||
ManagedAccount,
|
||||
}
|
||||
|
||||
/// 通用认证绑定
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct AuthBinding {
|
||||
/// 认证来源
|
||||
#[serde(default)]
|
||||
pub source: AuthBindingSource,
|
||||
/// 托管认证供应商标识(如 github_copilot)
|
||||
#[serde(rename = "authProvider", skip_serializing_if = "Option::is_none")]
|
||||
pub auth_provider: Option<String>,
|
||||
/// 托管账号 ID;为空表示跟随该认证供应商的默认账号
|
||||
#[serde(rename = "accountId", skip_serializing_if = "Option::is_none")]
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
/// 供应商元数据
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ProviderMeta {
|
||||
@@ -242,14 +267,49 @@ pub struct ProviderMeta {
|
||||
/// - "openai_responses": OpenAI Responses API 格式,需要转换
|
||||
#[serde(rename = "apiFormat", skip_serializing_if = "Option::is_none")]
|
||||
pub api_format: Option<String>,
|
||||
/// 通用认证绑定(provider_config / managed_account)
|
||||
///
|
||||
/// 新代码应只写入该字段;githubAccountId 仅保留兼容读取。
|
||||
#[serde(rename = "authBinding", skip_serializing_if = "Option::is_none")]
|
||||
pub auth_binding: Option<AuthBinding>,
|
||||
/// Claude 认证字段名("ANTHROPIC_AUTH_TOKEN" 或 "ANTHROPIC_API_KEY")
|
||||
#[serde(rename = "apiKeyField", skip_serializing_if = "Option::is_none")]
|
||||
pub api_key_field: Option<String>,
|
||||
|
||||
/// Prompt cache key for OpenAI-compatible endpoints.
|
||||
/// When set, injected into converted requests to improve cache hit rate.
|
||||
/// If not set, provider ID is used automatically during format conversion.
|
||||
#[serde(rename = "promptCacheKey", skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_cache_key: Option<String>,
|
||||
/// 供应商类型标识(用于特殊供应商检测)
|
||||
/// - "github_copilot": GitHub Copilot 供应商
|
||||
#[serde(rename = "providerType", skip_serializing_if = "Option::is_none")]
|
||||
pub provider_type: Option<String>,
|
||||
/// GitHub Copilot 关联账号 ID(仅 github_copilot 供应商使用)
|
||||
/// 用于多账号支持,关联到特定的 GitHub 账号
|
||||
#[serde(rename = "githubAccountId", skip_serializing_if = "Option::is_none")]
|
||||
pub github_account_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderMeta {
|
||||
/// 解析指定托管认证供应商绑定的账号 ID。
|
||||
///
|
||||
/// 新版优先读取 authBinding,旧版继续兼容 githubAccountId。
|
||||
pub fn managed_account_id_for(&self, auth_provider: &str) -> Option<String> {
|
||||
if let Some(binding) = self.auth_binding.as_ref() {
|
||||
if binding.source == AuthBindingSource::ManagedAccount
|
||||
&& binding.auth_provider.as_deref() == Some(auth_provider)
|
||||
{
|
||||
return binding.account_id.clone();
|
||||
}
|
||||
}
|
||||
|
||||
if auth_provider == "github_copilot" {
|
||||
return self.github_account_id.clone();
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
|
||||
@@ -8,7 +8,7 @@ use super::{
|
||||
failover_switch::FailoverSwitchManager,
|
||||
log_codes::fwd as log_fwd,
|
||||
provider_router::ProviderRouter,
|
||||
providers::{get_adapter, ProviderAdapter, ProviderType},
|
||||
providers::{get_adapter, AuthInfo, AuthStrategy, ProviderAdapter, ProviderType},
|
||||
thinking_budget_rectifier::{rectify_thinking_budget, should_rectify_thinking_budget},
|
||||
thinking_rectifier::{
|
||||
normalize_thinking_type, rectify_anthropic_request, should_rectify_thinking_signature,
|
||||
@@ -16,10 +16,13 @@ use super::{
|
||||
types::{OptimizerConfig, ProxyStatus, RectifierConfig},
|
||||
ProxyError,
|
||||
};
|
||||
use crate::commands::CopilotAuthState;
|
||||
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
|
||||
use crate::{app_config::AppType, provider::Provider};
|
||||
use reqwest::Response;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tauri::Manager;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Headers 黑名单 - 不透传到上游的 Headers
|
||||
@@ -792,14 +795,27 @@ impl RequestForwarder {
|
||||
// 检查是否需要格式转换
|
||||
let needs_transform = adapter.needs_transform(provider);
|
||||
|
||||
// 确定有效端点
|
||||
// GitHub Copilot API 使用 /chat/completions(无 /v1 前缀)
|
||||
let is_copilot = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|m| m.provider_type.as_deref())
|
||||
== Some("github_copilot")
|
||||
|| base_url.contains("githubcopilot.com");
|
||||
let effective_endpoint =
|
||||
if needs_transform && adapter.name() == "Claude" && endpoint == "/v1/messages" {
|
||||
// 根据 api_format 选择目标端点
|
||||
let api_format = super::providers::get_claude_api_format(provider);
|
||||
if api_format == "openai_responses" {
|
||||
"/v1/responses"
|
||||
if is_copilot {
|
||||
// GitHub Copilot uses /chat/completions without /v1 prefix
|
||||
"/chat/completions"
|
||||
} else {
|
||||
"/v1/chat/completions"
|
||||
// 根据 api_format 选择目标端点
|
||||
let api_format = super::providers::get_claude_api_format(provider);
|
||||
if api_format == "openai_responses" {
|
||||
"/v1/responses"
|
||||
} else {
|
||||
"/v1/chat/completions"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
endpoint
|
||||
@@ -892,7 +908,57 @@ impl RequestForwarder {
|
||||
}
|
||||
|
||||
// 使用适配器添加认证头
|
||||
if let Some(auth) = adapter.extract_auth(provider) {
|
||||
if let Some(mut auth) = adapter.extract_auth(provider) {
|
||||
// GitHub Copilot 特殊处理:从 CopilotAuthManager 获取真实 token
|
||||
if auth.strategy == AuthStrategy::GitHubCopilot {
|
||||
if let Some(app_handle) = &self.app_handle {
|
||||
let copilot_state = app_handle.state::<CopilotAuthState>();
|
||||
let copilot_auth: tokio::sync::RwLockReadGuard<'_, CopilotAuthManager> =
|
||||
copilot_state.0.read().await;
|
||||
|
||||
// 从 provider.meta 获取关联的 GitHub 账号 ID(多账号支持)
|
||||
let account_id = provider
|
||||
.meta
|
||||
.as_ref()
|
||||
.and_then(|m| m.managed_account_id_for("github_copilot"));
|
||||
|
||||
// 根据账号 ID 获取对应 token(向后兼容:无账号 ID 时使用第一个账号)
|
||||
let token_result = match &account_id {
|
||||
Some(id) => {
|
||||
log::debug!("[Copilot] 使用指定账号 {id} 获取 token");
|
||||
copilot_auth.get_valid_token_for_account(id).await
|
||||
}
|
||||
None => {
|
||||
log::debug!("[Copilot] 使用默认账号获取 token");
|
||||
copilot_auth.get_valid_token().await
|
||||
}
|
||||
};
|
||||
|
||||
match token_result {
|
||||
Ok(token) => {
|
||||
auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot);
|
||||
log::debug!(
|
||||
"[Copilot] 成功获取 Copilot token (account={})",
|
||||
account_id.as_deref().unwrap_or("default")
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"[Copilot] 获取 Copilot token 失败 (account={}): {e}",
|
||||
account_id.as_deref().unwrap_or("default")
|
||||
);
|
||||
return Err(ProxyError::AuthError(format!(
|
||||
"GitHub Copilot 认证失败: {e}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("[Copilot] AppHandle 不可用");
|
||||
return Err(ProxyError::AuthError(
|
||||
"GitHub Copilot 认证不可用(无 AppHandle)".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
request = adapter.add_auth_headers(request, &auth);
|
||||
}
|
||||
|
||||
|
||||
@@ -112,6 +112,13 @@ pub enum AuthStrategy {
|
||||
///
|
||||
/// 用于 Gemini CLI 等需要 OAuth 的场景
|
||||
GoogleOAuth,
|
||||
|
||||
/// GitHub Copilot 认证方式
|
||||
///
|
||||
/// - Header: `Authorization: Bearer <copilot_token>`
|
||||
///
|
||||
/// 使用动态获取的 Copilot Token(通过 GitHub OAuth 设备码流程获取)
|
||||
GitHubCopilot,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -226,6 +233,7 @@ mod tests {
|
||||
AuthStrategy::Bearer,
|
||||
AuthStrategy::Google,
|
||||
AuthStrategy::GoogleOAuth,
|
||||
AuthStrategy::GitHubCopilot,
|
||||
];
|
||||
|
||||
for (i, s1) in strategies.iter().enumerate() {
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
//! - **Claude**: Anthropic 官方 API (x-api-key + anthropic-version)
|
||||
//! - **ClaudeAuth**: 中转服务 (仅 Bearer 认证,无 x-api-key)
|
||||
//! - **OpenRouter**: 已支持 Claude Code 兼容接口,默认透传
|
||||
//! - **GitHubCopilot**: GitHub Copilot (OAuth + Copilot Token)
|
||||
|
||||
use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType};
|
||||
use crate::provider::Provider;
|
||||
@@ -76,10 +77,16 @@ impl ClaudeAdapter {
|
||||
/// 获取供应商类型
|
||||
///
|
||||
/// 根据 base_url 和 auth_mode 检测具体的供应商类型:
|
||||
/// - GitHubCopilot: meta.provider_type 为 github_copilot 或 base_url 包含 githubcopilot.com
|
||||
/// - OpenRouter: base_url 包含 openrouter.ai
|
||||
/// - ClaudeAuth: auth_mode 为 bearer_only
|
||||
/// - Claude: 默认 Anthropic 官方
|
||||
pub fn provider_type(&self, provider: &Provider) -> ProviderType {
|
||||
// 检测 GitHub Copilot
|
||||
if self.is_github_copilot(provider) {
|
||||
return ProviderType::GitHubCopilot;
|
||||
}
|
||||
|
||||
// 检测 OpenRouter
|
||||
if self.is_openrouter(provider) {
|
||||
return ProviderType::OpenRouter;
|
||||
@@ -93,6 +100,25 @@ impl ClaudeAdapter {
|
||||
ProviderType::Claude
|
||||
}
|
||||
|
||||
/// 检测是否为 GitHub Copilot 供应商
|
||||
fn is_github_copilot(&self, provider: &Provider) -> bool {
|
||||
// 方式1: 检查 meta.provider_type
|
||||
if let Some(meta) = provider.meta.as_ref() {
|
||||
if meta.provider_type.as_deref() == Some("github_copilot") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// 方式2: 检查 base_url(兼容旧数据的 fallback,后续应优先依赖 providerType)
|
||||
if let Ok(base_url) = self.extract_base_url(provider) {
|
||||
if base_url.contains("githubcopilot.com") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// 检测是否使用 OpenRouter
|
||||
fn is_openrouter(&self, provider: &Provider) -> bool {
|
||||
if let Ok(base_url) = self.extract_base_url(provider) {
|
||||
@@ -244,6 +270,17 @@ impl ProviderAdapter for ClaudeAdapter {
|
||||
|
||||
fn extract_auth(&self, provider: &Provider) -> Option<AuthInfo> {
|
||||
let provider_type = self.provider_type(provider);
|
||||
|
||||
// GitHub Copilot 使用特殊的认证策略
|
||||
// 实际的 token 会在代理请求时动态获取
|
||||
if provider_type == ProviderType::GitHubCopilot {
|
||||
// 返回一个占位符,实际 token 由 CopilotAuthManager 动态提供
|
||||
return Some(AuthInfo::new(
|
||||
"copilot_placeholder".to_string(),
|
||||
AuthStrategy::GitHubCopilot,
|
||||
));
|
||||
}
|
||||
|
||||
let strategy = match provider_type {
|
||||
ProviderType::OpenRouter => AuthStrategy::Bearer,
|
||||
ProviderType::ClaudeAuth => AuthStrategy::ClaudeAuth,
|
||||
@@ -273,6 +310,11 @@ impl ProviderAdapter for ClaudeAdapter {
|
||||
base = base.replace("/v1/v1", "/v1");
|
||||
}
|
||||
|
||||
// GitHub Copilot 不需要 ?beta=true 参数
|
||||
if base_url.contains("githubcopilot.com") {
|
||||
return base;
|
||||
}
|
||||
|
||||
// 为 Claude 原生 /v1/messages 端点添加 ?beta=true 参数
|
||||
// 这是某些上游服务(如 DuckCoding)验证请求来源的关键参数
|
||||
// 注意:不要为 OpenAI Chat Completions (/v1/chat/completions) 添加此参数
|
||||
@@ -304,11 +346,22 @@ impl ProviderAdapter for ClaudeAdapter {
|
||||
AuthStrategy::Bearer => {
|
||||
request.header("Authorization", format!("Bearer {}", auth.api_key))
|
||||
}
|
||||
// GitHub Copilot: Bearer + 特定的 Editor headers
|
||||
AuthStrategy::GitHubCopilot => request
|
||||
.header("Authorization", format!("Bearer {}", auth.api_key))
|
||||
.header("Editor-Version", "vscode/1.85.0")
|
||||
.header("Editor-Plugin-Version", "copilot/1.150.0")
|
||||
.header("Copilot-Integration-Id", "vscode-chat"),
|
||||
_ => request,
|
||||
}
|
||||
}
|
||||
|
||||
fn needs_transform(&self, provider: &Provider) -> bool {
|
||||
// GitHub Copilot 总是需要格式转换 (Anthropic → OpenAI)
|
||||
if self.is_github_copilot(provider) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 根据 api_format 配置决定是否需要格式转换
|
||||
// - "anthropic" (默认): 直接透传,无需转换
|
||||
// - "openai_chat": 需要 Anthropic ↔ OpenAI Chat Completions 格式转换
|
||||
@@ -678,4 +731,67 @@ mod tests {
|
||||
);
|
||||
assert!(!adapter.needs_transform(&unknown_format));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_github_copilot_detection_by_url() {
|
||||
let adapter = ClaudeAdapter::new();
|
||||
|
||||
// GitHub Copilot by base_url
|
||||
let copilot = create_provider(json!({
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
|
||||
}
|
||||
}));
|
||||
assert_eq!(adapter.provider_type(&copilot), ProviderType::GitHubCopilot);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_github_copilot_detection_by_meta() {
|
||||
let adapter = ClaudeAdapter::new();
|
||||
|
||||
// GitHub Copilot by meta.provider_type
|
||||
let copilot_meta = create_provider_with_meta(
|
||||
json!({
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "https://api.example.com"
|
||||
}
|
||||
}),
|
||||
ProviderMeta {
|
||||
provider_type: Some("github_copilot".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
assert_eq!(
|
||||
adapter.provider_type(&copilot_meta),
|
||||
ProviderType::GitHubCopilot
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_github_copilot_auth() {
|
||||
let adapter = ClaudeAdapter::new();
|
||||
|
||||
let copilot = create_provider(json!({
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
|
||||
}
|
||||
}));
|
||||
|
||||
let auth = adapter.extract_auth(&copilot).unwrap();
|
||||
assert_eq!(auth.strategy, AuthStrategy::GitHubCopilot);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_github_copilot_needs_transform() {
|
||||
let adapter = ClaudeAdapter::new();
|
||||
|
||||
let copilot = create_provider(json!({
|
||||
"env": {
|
||||
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
|
||||
}
|
||||
}));
|
||||
|
||||
// GitHub Copilot always needs transform
|
||||
assert!(adapter.needs_transform(&copilot));
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,7 @@ mod adapter;
|
||||
mod auth;
|
||||
mod claude;
|
||||
mod codex;
|
||||
pub mod copilot_auth;
|
||||
mod gemini;
|
||||
pub mod models;
|
||||
pub mod streaming;
|
||||
@@ -52,6 +53,8 @@ pub enum ProviderType {
|
||||
GeminiCli,
|
||||
/// OpenRouter(已支持 Claude Code 兼容接口,默认透传;保留旧转换逻辑备用)
|
||||
OpenRouter,
|
||||
/// GitHub Copilot (OAuth + Copilot Token,需要 Anthropic ↔ OpenAI 转换)
|
||||
GitHubCopilot,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
@@ -59,9 +62,11 @@ impl ProviderType {
|
||||
///
|
||||
/// 过去 OpenRouter 需要将 Anthropic 格式转换为 OpenAI 格式;
|
||||
/// 现在默认关闭转换(因为 OpenRouter 已支持 Claude Code 兼容接口)。
|
||||
/// GitHub Copilot 需要转换(Anthropic → OpenAI 格式)。
|
||||
#[allow(dead_code)]
|
||||
pub fn needs_transform(&self) -> bool {
|
||||
match self {
|
||||
ProviderType::GitHubCopilot => true,
|
||||
ProviderType::OpenRouter => false,
|
||||
_ => false,
|
||||
}
|
||||
@@ -77,6 +82,7 @@ impl ProviderType {
|
||||
"https://generativelanguage.googleapis.com"
|
||||
}
|
||||
ProviderType::OpenRouter => "https://openrouter.ai/api",
|
||||
ProviderType::GitHubCopilot => "https://api.githubcopilot.com",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,9 +93,20 @@ impl ProviderType {
|
||||
pub fn from_app_type_and_config(app_type: &AppType, provider: &Provider) -> Self {
|
||||
match app_type {
|
||||
AppType::Claude => {
|
||||
// 检测是否为 OpenRouter
|
||||
// 检测是否为 GitHub Copilot
|
||||
if let Some(meta) = provider.meta.as_ref() {
|
||||
if meta.provider_type.as_deref() == Some("github_copilot") {
|
||||
return ProviderType::GitHubCopilot;
|
||||
}
|
||||
}
|
||||
|
||||
// 检测 base_url 是否为 GitHub Copilot
|
||||
let adapter = ClaudeAdapter::new();
|
||||
if let Ok(base_url) = adapter.extract_base_url(provider) {
|
||||
if base_url.contains("githubcopilot.com") {
|
||||
return ProviderType::GitHubCopilot;
|
||||
}
|
||||
// 检测是否为 OpenRouter
|
||||
if base_url.contains("openrouter.ai") {
|
||||
return ProviderType::OpenRouter;
|
||||
}
|
||||
@@ -154,6 +171,7 @@ impl ProviderType {
|
||||
ProviderType::Gemini => "gemini",
|
||||
ProviderType::GeminiCli => "gemini_cli",
|
||||
ProviderType::OpenRouter => "openrouter",
|
||||
ProviderType::GitHubCopilot => "github_copilot",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,6 +193,9 @@ impl std::str::FromStr for ProviderType {
|
||||
"gemini" => Ok(ProviderType::Gemini),
|
||||
"gemini_cli" | "gemini-cli" => Ok(ProviderType::GeminiCli),
|
||||
"openrouter" => Ok(ProviderType::OpenRouter),
|
||||
"github_copilot" | "github-copilot" | "githubcopilot" => {
|
||||
Ok(ProviderType::GitHubCopilot)
|
||||
}
|
||||
_ => Err(format!("Invalid provider type: {s}")),
|
||||
}
|
||||
}
|
||||
@@ -201,9 +222,10 @@ pub fn get_adapter(app_type: &AppType) -> Box<dyn ProviderAdapter> {
|
||||
#[allow(dead_code)]
|
||||
pub fn get_adapter_for_provider_type(provider_type: &ProviderType) -> Box<dyn ProviderAdapter> {
|
||||
match provider_type {
|
||||
ProviderType::Claude | ProviderType::ClaudeAuth | ProviderType::OpenRouter => {
|
||||
Box::new(ClaudeAdapter::new())
|
||||
}
|
||||
ProviderType::Claude
|
||||
| ProviderType::ClaudeAuth
|
||||
| ProviderType::OpenRouter
|
||||
| ProviderType::GitHubCopilot => Box::new(ClaudeAdapter::new()),
|
||||
ProviderType::Codex => Box::new(CodexAdapter::new()),
|
||||
ProviderType::Gemini | ProviderType::GeminiCli => Box::new(GeminiAdapter::new()),
|
||||
}
|
||||
@@ -239,6 +261,7 @@ mod tests {
|
||||
assert!(!ProviderType::Gemini.needs_transform());
|
||||
assert!(!ProviderType::GeminiCli.needs_transform());
|
||||
assert!(!ProviderType::OpenRouter.needs_transform());
|
||||
assert!(ProviderType::GitHubCopilot.needs_transform());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -267,6 +290,10 @@ mod tests {
|
||||
ProviderType::OpenRouter.default_endpoint(),
|
||||
"https://openrouter.ai/api"
|
||||
);
|
||||
assert_eq!(
|
||||
ProviderType::GitHubCopilot.default_endpoint(),
|
||||
"https://api.githubcopilot.com"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -303,6 +330,18 @@ mod tests {
|
||||
"openrouter".parse::<ProviderType>().unwrap(),
|
||||
ProviderType::OpenRouter
|
||||
);
|
||||
assert_eq!(
|
||||
"github_copilot".parse::<ProviderType>().unwrap(),
|
||||
ProviderType::GitHubCopilot
|
||||
);
|
||||
assert_eq!(
|
||||
"github-copilot".parse::<ProviderType>().unwrap(),
|
||||
ProviderType::GitHubCopilot
|
||||
);
|
||||
assert_eq!(
|
||||
"githubcopilot".parse::<ProviderType>().unwrap(),
|
||||
ProviderType::GitHubCopilot
|
||||
);
|
||||
assert!("invalid".parse::<ProviderType>().is_err());
|
||||
}
|
||||
|
||||
@@ -314,6 +353,7 @@ mod tests {
|
||||
assert_eq!(ProviderType::Gemini.as_str(), "gemini");
|
||||
assert_eq!(ProviderType::GeminiCli.as_str(), "gemini_cli");
|
||||
assert_eq!(ProviderType::OpenRouter.as_str(), "openrouter");
|
||||
assert_eq!(ProviderType::GitHubCopilot.as_str(), "github_copilot");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -434,6 +474,9 @@ mod tests {
|
||||
let adapter = get_adapter_for_provider_type(&ProviderType::OpenRouter);
|
||||
assert_eq!(adapter.name(), "Claude");
|
||||
|
||||
let adapter = get_adapter_for_provider_type(&ProviderType::GitHubCopilot);
|
||||
assert_eq!(adapter.name(), "Claude");
|
||||
|
||||
let adapter = get_adapter_for_provider_type(&ProviderType::Codex);
|
||||
assert_eq!(adapter.name(), "Codex");
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ use std::time::Instant;
|
||||
use crate::app_config::AppType;
|
||||
use crate::error::AppError;
|
||||
use crate::provider::Provider;
|
||||
use crate::proxy::providers::transform::anthropic_to_openai;
|
||||
use crate::proxy::providers::{get_adapter, AuthInfo, AuthStrategy};
|
||||
|
||||
/// 健康状态枚举
|
||||
@@ -84,13 +85,16 @@ impl StreamCheckService {
|
||||
app_type: &AppType,
|
||||
provider: &Provider,
|
||||
config: &StreamCheckConfig,
|
||||
auth_override: Option<AuthInfo>,
|
||||
) -> Result<StreamCheckResult, AppError> {
|
||||
// 合并供应商单独配置和全局配置
|
||||
let effective_config = Self::merge_provider_config(provider, config);
|
||||
let mut last_result = None;
|
||||
|
||||
for attempt in 0..=effective_config.max_retries {
|
||||
let result = Self::check_once(app_type, provider, &effective_config).await;
|
||||
let result =
|
||||
Self::check_once(app_type, provider, &effective_config, auth_override.clone())
|
||||
.await;
|
||||
|
||||
match &result {
|
||||
Ok(r) if r.success => {
|
||||
@@ -178,6 +182,7 @@ impl StreamCheckService {
|
||||
app_type: &AppType,
|
||||
provider: &Provider,
|
||||
config: &StreamCheckConfig,
|
||||
auth_override: Option<AuthInfo>,
|
||||
) -> Result<StreamCheckResult, AppError> {
|
||||
let start = Instant::now();
|
||||
let adapter = get_adapter(app_type);
|
||||
@@ -186,8 +191,8 @@ impl StreamCheckService {
|
||||
.extract_base_url(provider)
|
||||
.map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?;
|
||||
|
||||
let auth = adapter
|
||||
.extract_auth(provider)
|
||||
let auth = auth_override
|
||||
.or_else(|| adapter.extract_auth(provider))
|
||||
.ok_or_else(|| AppError::Message("API Key not found".to_string()))?;
|
||||
|
||||
// 获取 HTTP 客户端:优先使用供应商单独代理配置,否则使用全局客户端
|
||||
@@ -297,6 +302,7 @@ impl StreamCheckService {
|
||||
provider: &Provider,
|
||||
) -> Result<(u16, String), AppError> {
|
||||
let base = base_url.trim_end_matches('/');
|
||||
let is_github_copilot = auth.strategy == AuthStrategy::GitHubCopilot;
|
||||
|
||||
// Detect api_format: meta.api_format > settings_config.api_format > default "anthropic"
|
||||
let api_format = provider
|
||||
@@ -311,10 +317,15 @@ impl StreamCheckService {
|
||||
})
|
||||
.unwrap_or("anthropic");
|
||||
|
||||
let is_openai_chat = api_format == "openai_chat";
|
||||
let is_openai_chat = is_github_copilot || api_format == "openai_chat";
|
||||
|
||||
// URL: /v1/chat/completions for openai_chat, /v1/messages?beta=true for anthropic
|
||||
let url = if is_openai_chat {
|
||||
// URL:
|
||||
// - GitHub Copilot: /chat/completions (no /v1 prefix)
|
||||
// - OpenAI-compatible: /v1/chat/completions
|
||||
// - Anthropic native: /v1/messages?beta=true
|
||||
let url = if is_github_copilot {
|
||||
format!("{base}/chat/completions")
|
||||
} else if is_openai_chat {
|
||||
if base.ends_with("/v1") {
|
||||
format!("{base}/chat/completions")
|
||||
} else {
|
||||
@@ -329,22 +340,38 @@ impl StreamCheckService {
|
||||
}
|
||||
};
|
||||
|
||||
// Body: identical structure for minimal test (both APIs accept messages array)
|
||||
let body = json!({
|
||||
// Build from Anthropic-native shape first, then convert for OpenAI-compatible targets.
|
||||
let anthropic_body = json!({
|
||||
"model": model,
|
||||
"max_tokens": 1,
|
||||
"messages": [{ "role": "user", "content": test_prompt }],
|
||||
"stream": true
|
||||
});
|
||||
let body = if is_openai_chat {
|
||||
anthropic_to_openai(anthropic_body, Some(&provider.id))
|
||||
.map_err(|e| AppError::Message(format!("Failed to build test request: {e}")))?
|
||||
} else {
|
||||
anthropic_body
|
||||
};
|
||||
|
||||
let mut request_builder = client.post(&url);
|
||||
|
||||
if is_openai_chat {
|
||||
if is_github_copilot {
|
||||
request_builder = request_builder
|
||||
.header("authorization", format!("Bearer {}", auth.api_key))
|
||||
.header("content-type", "application/json")
|
||||
.header("accept", "text/event-stream")
|
||||
.header("accept-encoding", "identity")
|
||||
.header("editor-version", "vscode/1.85.0")
|
||||
.header("editor-plugin-version", "copilot/1.150.0")
|
||||
.header("copilot-integration-id", "vscode-chat");
|
||||
} else if is_openai_chat {
|
||||
// OpenAI-compatible: Bearer auth + standard headers only
|
||||
request_builder = request_builder
|
||||
.header("authorization", format!("Bearer {}", auth.api_key))
|
||||
.header("content-type", "application/json")
|
||||
.header("accept", "application/json");
|
||||
.header("accept", "text/event-stream")
|
||||
.header("accept-encoding", "identity");
|
||||
} else {
|
||||
// Anthropic native: full Claude CLI headers
|
||||
let os_name = Self::get_os_name();
|
||||
@@ -692,6 +719,31 @@ impl StreamCheckService {
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn resolve_claude_stream_url(
|
||||
base_url: &str,
|
||||
auth_strategy: AuthStrategy,
|
||||
api_format: &str,
|
||||
) -> String {
|
||||
let base = base_url.trim_end_matches('/');
|
||||
let is_github_copilot = auth_strategy == AuthStrategy::GitHubCopilot;
|
||||
let is_openai_chat = is_github_copilot || api_format == "openai_chat";
|
||||
|
||||
if is_github_copilot {
|
||||
format!("{base}/chat/completions")
|
||||
} else if is_openai_chat {
|
||||
if base.ends_with("/v1") {
|
||||
format!("{base}/chat/completions")
|
||||
} else {
|
||||
format!("{base}/v1/chat/completions")
|
||||
}
|
||||
} else if base.ends_with("/v1") {
|
||||
format!("{base}/messages?beta=true")
|
||||
} else {
|
||||
format!("{base}/v1/messages?beta=true")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -794,4 +846,37 @@ mod tests {
|
||||
assert_eq!(claude_auth, AuthStrategy::ClaudeAuth);
|
||||
assert_eq!(bearer, AuthStrategy::Bearer);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_claude_stream_url_for_github_copilot() {
|
||||
let url = StreamCheckService::resolve_claude_stream_url(
|
||||
"https://api.githubcopilot.com",
|
||||
AuthStrategy::GitHubCopilot,
|
||||
"anthropic",
|
||||
);
|
||||
|
||||
assert_eq!(url, "https://api.githubcopilot.com/chat/completions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_claude_stream_url_for_openai_chat() {
|
||||
let url = StreamCheckService::resolve_claude_stream_url(
|
||||
"https://example.com/v1",
|
||||
AuthStrategy::Bearer,
|
||||
"openai_chat",
|
||||
);
|
||||
|
||||
assert_eq!(url, "https://example.com/v1/chat/completions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_claude_stream_url_for_anthropic() {
|
||||
let url = StreamCheckService::resolve_claude_stream_url(
|
||||
"https://api.anthropic.com",
|
||||
AuthStrategy::Anthropic,
|
||||
"anthropic",
|
||||
);
|
||||
|
||||
assert_eq!(url, "https://api.anthropic.com/v1/messages?beta=true");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +181,6 @@ pub struct AppSettings {
|
||||
/// 是否跳过 Claude Code 初次安装确认
|
||||
#[serde(default)]
|
||||
pub skip_claude_onboarding: bool,
|
||||
/// 是否解除 Tool Search 域名限制
|
||||
#[serde(default)]
|
||||
pub tool_search_bypass: bool,
|
||||
/// 是否开机自启
|
||||
#[serde(default)]
|
||||
pub launch_on_startup: bool,
|
||||
@@ -289,7 +286,6 @@ impl Default for AppSettings {
|
||||
minimize_to_tray_on_close: true,
|
||||
enable_claude_plugin_integration: false,
|
||||
skip_claude_onboarding: false,
|
||||
tool_search_bypass: false,
|
||||
launch_on_startup: false,
|
||||
silent_startup: false,
|
||||
enable_local_proxy: false,
|
||||
|
||||
@@ -1,569 +0,0 @@
|
||||
//! Tool Search domain restriction bypass patch for Claude Code.
|
||||
//!
|
||||
//! Resolves the current active `claude` command from PATH and patches the
|
||||
//! domain whitelist check
|
||||
//! `return["api.anthropic.com"].includes(x)}catch{return!1}`
|
||||
//! to always return true via equal-length byte replacement.
|
||||
|
||||
use regex::bytes::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::error::AppError;
|
||||
|
||||
const BACKUP_SUFFIX: &str = ".toolsearch-bak";
|
||||
|
||||
/// Encode bytes as lowercase hex string (avoids adding `hex` crate dependency).
|
||||
fn to_hex(bytes: &[u8]) -> String {
|
||||
bytes.iter().map(|b| format!("{b:02x}")).collect()
|
||||
}
|
||||
|
||||
// Regex matching the domain whitelist check with any JS identifier as variable name
|
||||
const PATCH_TARGET_PATTERN: &str =
|
||||
r#"return\["api\.anthropic\.com"\]\.includes\([A-Za-z_$][A-Za-z0-9_$]*\)\}catch\{return!1\}"#;
|
||||
|
||||
// Regex matching already-patched code
|
||||
const PATCHED_PATTERN: &str = r#"return!0/\* *\*/\}catch\{return!0\}"#;
|
||||
|
||||
/// Single Claude Code installation info
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClaudeInstallation {
|
||||
pub path: String,
|
||||
pub source: String,
|
||||
pub patched: bool,
|
||||
pub has_backup: bool,
|
||||
}
|
||||
|
||||
/// Result of a patch/restore operation on one installation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PatchResult {
|
||||
pub path: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Overall Tool Search patch status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolSearchStatus {
|
||||
pub installations: Vec<ClaudeInstallation>,
|
||||
pub all_patched: bool,
|
||||
pub any_found: bool,
|
||||
}
|
||||
|
||||
// ── Patch status detection ──────────────────────────────────────────
|
||||
|
||||
fn get_patch_status(data: &[u8]) -> &'static str {
|
||||
let target_re = Regex::new(PATCH_TARGET_PATTERN).unwrap();
|
||||
let patched_re = Regex::new(PATCHED_PATTERN).unwrap();
|
||||
if target_re.is_match(data) {
|
||||
"unpatched"
|
||||
} else if patched_re.is_match(data) {
|
||||
"patched"
|
||||
} else {
|
||||
"unknown"
|
||||
}
|
||||
}
|
||||
|
||||
/// Build equal-length replacement bytes: `return!0/* */}catch{return!0}`
|
||||
fn build_patched_bytes(original_len: usize) -> Result<Vec<u8>, AppError> {
|
||||
let prefix = b"return!0/*";
|
||||
let suffix = b"*/}catch{return!0}";
|
||||
let padding = original_len
|
||||
.checked_sub(prefix.len() + suffix.len())
|
||||
.ok_or_else(|| AppError::Config("Patch template too long for match".into()))?;
|
||||
let mut out = Vec::with_capacity(original_len);
|
||||
out.extend_from_slice(prefix);
|
||||
out.extend(std::iter::repeat_n(b' ', padding));
|
||||
out.extend_from_slice(suffix);
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Apply byte-level patch to file data, returns (patched_data, replacement_count)
|
||||
fn patch_bytes(data: &[u8]) -> Result<(Vec<u8>, usize), AppError> {
|
||||
let re = Regex::new(PATCH_TARGET_PATTERN).unwrap();
|
||||
let mut count = 0usize;
|
||||
let result = re.replace_all(data, |caps: ®ex::bytes::Captures| {
|
||||
count += 1;
|
||||
build_patched_bytes(caps[0].len()).unwrap_or_else(|_| caps[0].to_vec())
|
||||
});
|
||||
Ok((result.into_owned(), count))
|
||||
}
|
||||
|
||||
// ── Installation detection ──────────────────────────────────────────
|
||||
|
||||
/// Run a command and return stdout, or empty string on failure
|
||||
fn run_cmd(cmd: &str, args: &[&str]) -> String {
|
||||
Command::new(cmd)
|
||||
.args(args)
|
||||
.output()
|
||||
.ok()
|
||||
.filter(|o| o.status.success())
|
||||
.and_then(|o| String::from_utf8(o.stdout).ok())
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Search a package directory for JS files containing the domain check
|
||||
fn find_patch_target_in_pkg(pkg_dir: &Path) -> Option<PathBuf> {
|
||||
let marker = b"api.anthropic.com";
|
||||
// Check cli.js first (most common)
|
||||
let cli_js = pkg_dir.join("cli.js");
|
||||
if cli_js.is_file() {
|
||||
if let Ok(data) = std::fs::read(&cli_js) {
|
||||
if data.windows(marker.len()).any(|w| w == marker) {
|
||||
return Some(cli_js);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Search other JS files
|
||||
find_js_with_marker(pkg_dir)
|
||||
}
|
||||
|
||||
fn find_js_with_marker(dir: &Path) -> Option<PathBuf> {
|
||||
let marker = b"api.anthropic.com";
|
||||
let entries = match std::fs::read_dir(dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return None,
|
||||
};
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
if let Some(found) = find_js_with_marker(&path) {
|
||||
return Some(found);
|
||||
}
|
||||
} else if path.extension().and_then(|e| e.to_str()) == Some("js") {
|
||||
if let Ok(meta) = path.metadata() {
|
||||
if meta.len() < 1000 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Ok(data) = std::fs::read(&path) {
|
||||
if data.windows(marker.len()).any(|w| w == marker) {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Resolve symlinks to actual file path
|
||||
fn resolve_target(path: &Path) -> PathBuf {
|
||||
std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
|
||||
}
|
||||
|
||||
fn get_patch_status_for_path(path: &Path) -> Option<&'static str> {
|
||||
let data = std::fs::read(path).ok()?;
|
||||
Some(get_patch_status(&data))
|
||||
}
|
||||
|
||||
fn package_dir_from_ancestors(path: &Path) -> Option<PathBuf> {
|
||||
for ancestor in path.ancestors() {
|
||||
if ancestor.file_name().and_then(|v| v.to_str()) == Some("claude-code")
|
||||
&& ancestor
|
||||
.parent()
|
||||
.and_then(|v| v.file_name())
|
||||
.and_then(|v| v.to_str())
|
||||
== Some("@anthropic-ai")
|
||||
{
|
||||
return Some(ancestor.to_path_buf());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn push_candidate_package_dir(
|
||||
candidates: &mut Vec<PathBuf>,
|
||||
seen: &mut std::collections::HashSet<PathBuf>,
|
||||
path: PathBuf,
|
||||
) {
|
||||
if path.is_dir() && seen.insert(path.clone()) {
|
||||
candidates.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_active_patch_target(command_path: &Path) -> Option<PathBuf> {
|
||||
let resolved_command = resolve_target(command_path);
|
||||
if matches!(
|
||||
get_patch_status_for_path(&resolved_command),
|
||||
Some("patched" | "unpatched")
|
||||
) {
|
||||
return Some(resolved_command);
|
||||
}
|
||||
|
||||
let mut candidates = Vec::new();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
|
||||
for path in [command_path, resolved_command.as_path()] {
|
||||
if let Some(pkg_dir) = package_dir_from_ancestors(path) {
|
||||
push_candidate_package_dir(&mut candidates, &mut seen, pkg_dir);
|
||||
}
|
||||
|
||||
if let Some(bin_dir) = path.parent() {
|
||||
if let Some(prefix) = bin_dir.parent() {
|
||||
push_candidate_package_dir(
|
||||
&mut candidates,
|
||||
&mut seen,
|
||||
prefix
|
||||
.join("lib")
|
||||
.join("node_modules")
|
||||
.join("@anthropic-ai")
|
||||
.join("claude-code"),
|
||||
);
|
||||
push_candidate_package_dir(
|
||||
&mut candidates,
|
||||
&mut seen,
|
||||
prefix
|
||||
.join("node_modules")
|
||||
.join("@anthropic-ai")
|
||||
.join("claude-code"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
.into_iter()
|
||||
.find_map(|pkg_dir| find_patch_target_in_pkg(&pkg_dir).map(|p| resolve_target(&p)))
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn find_active_command_path() -> Option<PathBuf> {
|
||||
run_cmd("where.exe", &["claude"])
|
||||
.lines()
|
||||
.next()
|
||||
.map(PathBuf::from)
|
||||
.filter(|path| path.is_file())
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn find_active_command_path() -> Option<PathBuf> {
|
||||
run_cmd("which", &["claude"])
|
||||
.lines()
|
||||
.next()
|
||||
.map(PathBuf::from)
|
||||
.filter(|path| path.is_file())
|
||||
}
|
||||
|
||||
fn find_active_installation() -> Option<(PathBuf, String)> {
|
||||
let command_path = find_active_command_path()?;
|
||||
let patch_target = resolve_active_patch_target(&command_path)?;
|
||||
Some((
|
||||
patch_target,
|
||||
format!("active claude ({})", command_path.display()),
|
||||
))
|
||||
}
|
||||
|
||||
fn require_active_installation() -> Result<(PathBuf, String), AppError> {
|
||||
find_active_installation()
|
||||
.ok_or_else(|| AppError::Config("No active Claude Code installation found in PATH".into()))
|
||||
}
|
||||
|
||||
// ── macOS codesign ──────────────────────────────────────────────────
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn codesign_adhoc(path: &Path) -> Result<(), AppError> {
|
||||
let output = Command::new("codesign")
|
||||
.args(["--force", "--sign", "-"])
|
||||
.arg(path)
|
||||
.output()
|
||||
.map_err(|e| AppError::IoContext {
|
||||
context: format!("Failed to run codesign for {}", path.display()),
|
||||
source: e,
|
||||
})?;
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(AppError::Config(format!(
|
||||
"codesign failed for {}: {}",
|
||||
path.display(),
|
||||
stderr.trim()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
fn codesign_adhoc(_path: &Path) -> Result<(), AppError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Backup directory helpers ─────────────────────────────────────────
|
||||
|
||||
/// Get the centralized backup directory: `~/.cc-switch/toolsearch-backups/`
|
||||
fn get_backup_dir() -> Result<PathBuf, AppError> {
|
||||
let dir = crate::config::get_app_config_dir().join("toolsearch-backups");
|
||||
if !dir.exists() {
|
||||
std::fs::create_dir_all(&dir).map_err(|e| AppError::io(&dir, e))?;
|
||||
}
|
||||
Ok(dir)
|
||||
}
|
||||
|
||||
/// Derive a stable backup filename from the original path using SHA-256.
|
||||
fn backup_name_for(path: &Path) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(path.to_string_lossy().as_bytes());
|
||||
to_hex(&hasher.finalize())
|
||||
}
|
||||
|
||||
/// Get the backup file path for a given target file.
|
||||
fn get_backup_path(path: &Path) -> Result<PathBuf, AppError> {
|
||||
let dir = get_backup_dir()?;
|
||||
let name = backup_name_for(path);
|
||||
Ok(dir.join(format!("{name}.bak")))
|
||||
}
|
||||
|
||||
/// Get the metadata file path (records original path for debugging).
|
||||
fn get_meta_path(path: &Path) -> Result<PathBuf, AppError> {
|
||||
let dir = get_backup_dir()?;
|
||||
let name = backup_name_for(path);
|
||||
Ok(dir.join(format!("{name}.meta")))
|
||||
}
|
||||
|
||||
// ── Patch / Restore single file ─────────────────────────────────────
|
||||
|
||||
fn patch_single_file(path: &Path) -> Result<(), AppError> {
|
||||
let data = std::fs::read(path).map_err(|e| AppError::io(path, e))?;
|
||||
let status = get_patch_status(&data);
|
||||
|
||||
if status == "patched" {
|
||||
return Ok(()); // Already patched
|
||||
}
|
||||
if status == "unknown" {
|
||||
return Err(AppError::Config(format!(
|
||||
"Target pattern not found in {}, possibly incompatible version",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
let (patched_data, count) = patch_bytes(&data)?;
|
||||
if count == 0 {
|
||||
return Err(AppError::Config(format!(
|
||||
"No replacements made in {}",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
// Create backup in centralized directory
|
||||
let backup_path = get_backup_path(path)?;
|
||||
std::fs::copy(path, &backup_path).map_err(|e| AppError::io(&backup_path, e))?;
|
||||
|
||||
// Write metadata file for debugging
|
||||
let meta_path = get_meta_path(path)?;
|
||||
let _ = std::fs::write(&meta_path, path.to_string_lossy().as_bytes());
|
||||
|
||||
// Write patched data
|
||||
if let Err(e) = std::fs::write(path, &patched_data) {
|
||||
// Try rename trick on Windows
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
if let Ok(()) = write_via_rename(path, &patched_data) {
|
||||
codesign_adhoc(path)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
return Err(AppError::io(path, e));
|
||||
}
|
||||
|
||||
codesign_adhoc(path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn write_via_rename(target: &Path, data: &[u8]) -> Result<(), AppError> {
|
||||
let tmp_path = target.with_extension("tmp");
|
||||
let old_path = target.with_extension("old");
|
||||
|
||||
let _ = std::fs::remove_file(&tmp_path);
|
||||
let _ = std::fs::remove_file(&old_path);
|
||||
|
||||
std::fs::write(&tmp_path, data).map_err(|e| AppError::io(&tmp_path, e))?;
|
||||
std::fs::rename(target, &old_path).map_err(|e| AppError::io(target, e))?;
|
||||
std::fs::rename(&tmp_path, target).map_err(|e| AppError::io(target, e))?;
|
||||
let _ = std::fs::remove_file(&old_path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn restore_single_file(path: &Path) -> Result<(), AppError> {
|
||||
let current_status = get_patch_status_for_path(path);
|
||||
if matches!(current_status, Some("unpatched")) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try centralized backup directory first
|
||||
let backup_path = get_backup_path(path)?;
|
||||
// Fallback: legacy backup path (adjacent `.toolsearch-bak` file)
|
||||
let legacy_backup = PathBuf::from(format!("{}{}", path.display(), BACKUP_SUFFIX));
|
||||
|
||||
let actual_backup = if backup_path.is_file() {
|
||||
&backup_path
|
||||
} else if legacy_backup.is_file() {
|
||||
&legacy_backup
|
||||
} else {
|
||||
return Err(AppError::Config(format!(
|
||||
"No backup found for {}",
|
||||
path.display()
|
||||
)));
|
||||
};
|
||||
|
||||
let backup_data = std::fs::read(actual_backup).map_err(|e| AppError::io(actual_backup, e))?;
|
||||
|
||||
if let Err(e) = std::fs::write(path, &backup_data) {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
if let Ok(()) = write_via_rename(path, &backup_data) {
|
||||
codesign_adhoc(path)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
return Err(AppError::io(path, e));
|
||||
}
|
||||
|
||||
codesign_adhoc(path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Public API ──────────────────────────────────────────────────────
|
||||
|
||||
/// Check Tool Search patch status for the current active Claude Code installation.
|
||||
pub fn check_toolsearch_status() -> Result<ToolSearchStatus, AppError> {
|
||||
let installations = find_active_installation()
|
||||
.into_iter()
|
||||
.map(|(path, source)| {
|
||||
let data = std::fs::read(&path).unwrap_or_default();
|
||||
let status = get_patch_status(&data);
|
||||
// Check centralized backup first, then legacy
|
||||
let has_backup = get_backup_path(&path).map(|p| p.is_file()).unwrap_or(false)
|
||||
|| PathBuf::from(format!("{}{}", path.display(), BACKUP_SUFFIX)).is_file();
|
||||
ClaudeInstallation {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
source: source.clone(),
|
||||
patched: status == "patched",
|
||||
has_backup,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let any_found = !installations.is_empty();
|
||||
let all_patched = any_found && installations.iter().all(|i| i.patched);
|
||||
|
||||
Ok(ToolSearchStatus {
|
||||
installations,
|
||||
all_patched,
|
||||
any_found,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply the Tool Search patch to the current active Claude Code installation.
|
||||
pub fn apply_toolsearch_patch() -> Result<Vec<PatchResult>, AppError> {
|
||||
let (path, _) = require_active_installation()?;
|
||||
Ok(vec![match patch_single_file(&path) {
|
||||
Ok(()) => PatchResult {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
},
|
||||
Err(e) => PatchResult {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
success: false,
|
||||
error: Some(e.to_string()),
|
||||
},
|
||||
}])
|
||||
}
|
||||
|
||||
/// Restore the current active Claude Code installation from backup.
|
||||
pub fn restore_toolsearch_patch() -> Result<Vec<PatchResult>, AppError> {
|
||||
let (path, _) = require_active_installation()?;
|
||||
Ok(vec![match restore_single_file(&path) {
|
||||
Ok(()) => PatchResult {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
},
|
||||
Err(e) => PatchResult {
|
||||
path: path.to_string_lossy().to_string(),
|
||||
success: false,
|
||||
error: Some(e.to_string()),
|
||||
},
|
||||
}])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_patch_bytes_replaces_correctly() {
|
||||
let input = br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#;
|
||||
let (patched, count) = patch_bytes(input).unwrap();
|
||||
assert_eq!(count, 1);
|
||||
assert_eq!(patched.len(), input.len());
|
||||
assert!(patched.starts_with(b"return!0/*"));
|
||||
assert!(patched.ends_with(b"*/}catch{return!0}"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_patch_status_detection() {
|
||||
let unpatched = br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#;
|
||||
assert_eq!(get_patch_status(unpatched), "unpatched");
|
||||
|
||||
let (patched, _) = patch_bytes(unpatched).unwrap();
|
||||
assert_eq!(get_patch_status(&patched), "patched");
|
||||
|
||||
assert_eq!(get_patch_status(b"some random data"), "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_patched_bytes_length() {
|
||||
for len in 50..70 {
|
||||
let result = build_patched_bytes(len).unwrap();
|
||||
assert_eq!(result.len(), len);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_active_patch_target_from_npm_style_bin_path() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let prefix = tmp.path().join("prefix");
|
||||
let bin_dir = prefix.join("bin");
|
||||
let pkg_dir = prefix
|
||||
.join("lib")
|
||||
.join("node_modules")
|
||||
.join("@anthropic-ai")
|
||||
.join("claude-code");
|
||||
std::fs::create_dir_all(&bin_dir).unwrap();
|
||||
std::fs::create_dir_all(&pkg_dir).unwrap();
|
||||
|
||||
let command_path = bin_dir.join("claude");
|
||||
std::fs::write(&command_path, b"#!/usr/bin/env node\n").unwrap();
|
||||
let cli_path = pkg_dir.join("cli.js");
|
||||
std::fs::write(
|
||||
&cli_path,
|
||||
br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
resolve_active_patch_target(&command_path),
|
||||
Some(resolve_target(&cli_path))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_restore_single_file_is_noop_when_current_target_is_unpatched() {
|
||||
let tmp = tempfile::NamedTempFile::new().unwrap();
|
||||
std::fs::write(
|
||||
tmp.path(),
|
||||
br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(restore_single_file(tmp.path()).is_ok());
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,7 @@
|
||||
}
|
||||
],
|
||||
"security": {
|
||||
"csp": "default-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self' ipc: http://ipc.localhost https: http:",
|
||||
"csp": "default-src 'self'; img-src 'self' data: https: http:; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self' ipc: http://ipc.localhost https: http:",
|
||||
"assetProtocol": {
|
||||
"enable": true,
|
||||
"scope": []
|
||||
|
||||
Reference in New Issue
Block a user