feat(copilot): add GitHub Copilot reverse proxy support (#930)

* refactor(toolsearch): replace binary patch with ENABLE_TOOL_SEARCH env var toggle

- Remove toolsearch_patch.rs binary patching mechanism (~590 lines)
  - Delete `toolsearch_patch.rs` and `commands/toolsearch.rs`
  - Remove auto-patch startup logic and command registration from lib.rs
  - Remove `tool_search_bypass` field from settings.rs
  - Remove frontend settings ToggleRow, useSettings hook sync logic, and API methods
  - Clean up zh/en/ja i18n keys (notifications + settings)

- Add ENABLE_TOOL_SEARCH toggle to Claude provider form
  - Add checkbox in CommonConfigEditor.tsx (alongside teammates toggle)
  - When enabled, writes `"env": { "ENABLE_TOOL_SEARCH": "true" }`
  - When disabled, removes the key; takes effect on provider switch
  - Add zh/en/ja i18n key: `claudeConfig.enableToolSearch`

Claude Code 2.1.76+ natively supports this env var, eliminating the need for binary patching.

* feat(claude): add effortLevel high toggle to provider form

- Add "high-effort thinking" checkbox to Claude provider config form
- When checked, writes `"effortLevel": "high"`; when unchecked, removes the field
- Add zh/en/ja i18n translations

* refactor(claude): remove deprecated alwaysThinking toggle

- Claude Code now enables extended thinking by default; alwaysThinkingEnabled is a no-op
- Thinking control is now handled via effortLevel (added in prior commit)
- Remove state, switch case, and checkbox UI from CommonConfigEditor
- Clean up alwaysThinking i18n keys across zh/en/ja locales

* feat(opencode): add setCacheKey: true to all provider presets

- Add setCacheKey: true to options in all 33 regular presets
- Add setCacheKey: true to OPENCODE_DEFAULT_CONFIG for custom providers
- Exclude 2 OMO presets (Oh My OpenCode / Slim) which have their own config mechanism

Closes #1523

* fix(codex): resolve 1M context window toggle causing MCP editor flicker

- Add localValueRef to short-circuit duplicate CodeMirror updateListener callbacks,
  breaking the React state → CodeMirror → stale onChange → React state feedback loop
- Use localValueRef.current in handleContextWindowToggle and handleCompactLimitChange
  to avoid stale closure reads
- Change compact limit input from type="number" to type="text" with inputMode="numeric"
  to remove unnecessary spinner buttons

* feat(codex): add 1M context window toggle utilities and i18n keys

- Add extractCodexTopLevelInt, setCodexTopLevelInt, removeCodexTopLevelField
  TOML helpers in providerConfigUtils.ts
- Add i18n keys for contextWindow1M, autoCompactLimit in zh/en/ja locales

* feat(claude): collapse model mapping fields by default

- Wrap 5 model mapping inputs in a Collapsible, collapsed by default
- Auto-expand when any model value is present (including preset-filled)
- Show hint text when collapsed explaining most users need no config
- Add zh/en/ja i18n keys for toggle label and collapsed hint
- Use variant={null} to avoid ghost button hover style clash in dark mode

* feat(claude): merge advanced fields into single collapsible section

- Merge API format, auth field, and model mapping into a unified "Advanced Options" collapsible
- Extend smart-expand logic to detect non-default values across all advanced fields
- Preserve model mapping sub-header and hint with a separator line
- Update zh/en/ja i18n keys (advancedOptionsToggle, advancedOptionsHint, modelMappingLabel, modelMappingHint)

* feat(copilot): add GitHub Copilot reverse proxy support

Add GitHub Copilot as a Claude provider variant with OAuth device code
authentication and Anthropic ↔ OpenAI format transformation.

Backend:
- Add CopilotAuthManager for GitHub OAuth device code flow
- Implement Copilot token auto-refresh (60s before expiry)
- Persist GitHub token to ~/.cc-switch/copilot_auth.json
- Add ProviderType::GitHubCopilot and AuthStrategy::GitHubCopilot
- Modify forwarder to use /chat/completions for Copilot
- Add Copilot-specific headers (Editor-Version, Editor-Plugin-Version)

Frontend:
- Add CopilotAuthSection component for OAuth UI
- Add useCopilotAuth hook for OAuth state management
- Auto-copy user code to clipboard and open browser
- Use 8-second polling interval to avoid GitHub rate limits
- Skip API Key validation for Copilot providers
- Add GitHub Copilot preset with claude-sonnet-4 model

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>

* fix(copilot): remove is_expired() calls from tests

Remove references to deleted is_expired() method in test code.
Only is_expiring_soon() is needed for token refresh logic.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>

* feat(copilot): add real-time model listing from Copilot API

- Add fetch_models() to CopilotAuthManager calling GET /models endpoint
- Add copilot_get_models Tauri command
- Add copilotGetModels() frontend API wrapper
- Modify ClaudeFormFields to show model dropdown for Copilot providers
  - Fetches available models on component mount when isCopilotPreset
  - Groups models by vendor (Anthropic, OpenAI, Google, etc.)
  - Input + dropdown button combo allows both manual entry and selection
  - Non-Copilot providers keep original plain Input behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(copilot): add usage query integration

- Add Copilot usage API integration (fetch_usage method)
- Add copilot_get_usage Tauri command
- Add GitHub Copilot template in usage query modal
- Unify naming: copilot → github_copilot
- Add constants management (TEMPLATE_TYPES, PROVIDER_TYPES)
- Improve error handling with detailed error messages
- Add database migration (v5 → v6) for template type update
- Add i18n translations (zh, en, ja)
- Improve type safety with TemplateType
- Apply code formatting (cargo fmt, prettier)

* 修复github 登录和注销问题 ,模型选择问题

* feat(copilot): add multi-account support for GitHub Copilot

- Add multi-account storage structure with v1 to v2 migration
- Add per-account token caching and auto-refresh
- Add new Tauri commands for account management
- Integrate account selection in Proxy forwarder
- Add account selection UI in CopilotAuthSection
- Save githubAccountId to ProviderMeta
- Add i18n translations for multi-account features (zh/en/ja)

* 修复用量查询Reset字段出现多余字符

* refactor(auth-binding): introduce generic provider auth binding primitives

- add shared authBinding types in Rust and TypeScript while keeping githubAccountId as a compatibility field\n- resolve Copilot token, models, and usage through provider-bound account lookup instead of only the implicit default account\n- fix the Unix build regression in settings.rs by restoring std::io::Write for write_all()\n- remove the accidental .github ignore entry and drop leftover Copilot form debug logs\n- keep the first migration step non-breaking by writing both authBinding and the legacy githubAccountId field from the form

* refactor(auth-service): add managed auth command surface and explicit default account state

- introduce generic managed auth commands and frontend auth API wrappers for provider-scoped login, status, account listing, removal, logout, and default-account selection\n- store an explicit Copilot default_account_id instead of relying on HashMap iteration order, and use it consistently for fallback token/model/usage resolution\n- sort managed accounts deterministically and surface default-account state to the UI\n- refactor the Copilot form hook to wrap a generic useManagedAuth implementation while preserving the existing component contract\n- add default-account controls to the Copilot auth section and extend Copilot auth status serialization/tests for the new state

* feat(auth-center): add a dedicated settings entrypoint for managed OAuth accounts

- add an Auth Center tab to Settings so managed OAuth accounts are no longer hidden inside individual provider forms\n- introduce a first AuthCenterPanel that hosts GitHub Copilot account management as the initial managed auth provider\n- keep the provider form experience intact while establishing a global account-management surface for future providers such as OpenAI\n- validate that the new settings tab works cleanly with the generic managed auth hook and existing Copilot account controls

* feat(add-provider): expose managed OAuth sources alongside universal providers

- add an OAuth tab to the Add Provider flow so managed auth sources sit beside app-specific and universal providers\n- reuse the new Auth Center panel inside the dialog, keeping account management discoverable during provider creation\n- make the dialog footer adapt to the OAuth tab so account setup does not pretend to create a provider directly\n- align the add-provider UX with the new architecture where OAuth accounts are global assets and providers bind to them later

* fix(auth-reliability): harden managed auth persistence and refresh behavior

- replace direct Copilot auth store writes with private temp-file writes and atomic rename semantics, and document the local token storage limitation\n- add per-account refresh locks plus a double-check path so concurrent requests do not stampede GitHub token refresh\n- surface legacy migration failures through auth status, expose them in the UI, and add translated copy for the new account-state labels\n- stop writing the legacy githubAccountId field from the provider form while keeping compatibility reads in place\n- add logout error recovery and Copilot model-load toasts so auth failures are no longer silently swallowed

* refactor(copilot-detection): prefer provider type before URL fallbacks

- update forwarder endpoint rewriting to treat providerType as the primary GitHub Copilot signal\n- keep githubcopilot.com string matching only as a compatibility fallback for older provider records without providerType\n- reduce one more path where Copilot behavior depended purely on URL heuristics

* fix(copilot-auth): add cancel button to error state in CopilotAuthSection

- 错误状态下仅有"重试"按钮,用户无法退出(如不可恢复的 403 未订阅错误)
- 新增"取消"按钮,复用已有的 cancelAuth 逻辑重置为 idle 状态

* 修复打包后github账号头像显示异常

* 修复github copilot 来源的模型测试报错

* feat(copilot-preset): add default model presets for GitHub Copilot

- 补充 Copilot 预设的默认模型配置,用户选完预设即可直接使用
- ANTHROPIC_MODEL: claude-opus-4.6
- ANTHROPIC_DEFAULT_HAIKU_MODEL: claude-haiku-4.5
- ANTHROPIC_DEFAULT_SONNET_MODEL: claude-sonnet-4.6
- ANTHROPIC_DEFAULT_OPUS_MODEL: claude-opus-4.6

---------

Co-authored-by: Jason <farion1231@gmail.com>
Co-authored-by: 周梦泽 <mengze.zhou@dafeng-tech.com>
Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
Zhou Mengze
2026-03-17 23:57:58 +08:00
committed by GitHub
parent 36bbdc36f5
commit 8ccfbd36d6
50 changed files with 4555 additions and 1062 deletions
+182
View File
@@ -0,0 +1,182 @@
use tauri::State;
use crate::commands::copilot::CopilotAuthState;
use crate::proxy::providers::copilot_auth::{GitHubAccount, GitHubDeviceCodeResponse};
const AUTH_PROVIDER_GITHUB_COPILOT: &str = "github_copilot";
#[derive(Debug, Clone, serde::Serialize)]
pub struct ManagedAuthAccount {
pub id: String,
pub provider: String,
pub login: String,
pub avatar_url: Option<String>,
pub authenticated_at: i64,
pub is_default: bool,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ManagedAuthStatus {
pub provider: String,
pub authenticated: bool,
pub default_account_id: Option<String>,
pub migration_error: Option<String>,
pub accounts: Vec<ManagedAuthAccount>,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ManagedAuthDeviceCodeResponse {
pub provider: String,
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub expires_in: u64,
pub interval: u64,
}
fn ensure_auth_provider(auth_provider: &str) -> Result<&str, String> {
match auth_provider {
AUTH_PROVIDER_GITHUB_COPILOT => Ok(AUTH_PROVIDER_GITHUB_COPILOT),
_ => Err(format!("Unsupported auth provider: {auth_provider}")),
}
}
fn map_account(
provider: &str,
account: GitHubAccount,
default_account_id: Option<&str>,
) -> ManagedAuthAccount {
ManagedAuthAccount {
is_default: default_account_id == Some(account.id.as_str()),
id: account.id,
provider: provider.to_string(),
login: account.login,
avatar_url: account.avatar_url,
authenticated_at: account.authenticated_at,
}
}
fn map_device_code_response(
provider: &str,
response: GitHubDeviceCodeResponse,
) -> ManagedAuthDeviceCodeResponse {
ManagedAuthDeviceCodeResponse {
provider: provider.to_string(),
device_code: response.device_code,
user_code: response.user_code,
verification_uri: response.verification_uri,
expires_in: response.expires_in,
interval: response.interval,
}
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_start_login(
auth_provider: String,
state: State<'_, CopilotAuthState>,
) -> Result<ManagedAuthDeviceCodeResponse, String> {
let auth_provider = ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.read().await;
let response = auth_manager
.start_device_flow()
.await
.map_err(|e| e.to_string())?;
Ok(map_device_code_response(auth_provider, response))
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_poll_for_account(
auth_provider: String,
device_code: String,
state: State<'_, CopilotAuthState>,
) -> Result<Option<ManagedAuthAccount>, String> {
let auth_provider = ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
Ok(account) => {
let default_account_id = auth_manager.get_status().await.default_account_id;
Ok(account
.map(|account| map_account(auth_provider, account, default_account_id.as_deref())))
}
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
Ok(None)
}
Err(e) => Err(e.to_string()),
}
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_list_accounts(
auth_provider: String,
state: State<'_, CopilotAuthState>,
) -> Result<Vec<ManagedAuthAccount>, String> {
let auth_provider = ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.read().await;
let status = auth_manager.get_status().await;
let default_account_id = status.default_account_id.clone();
Ok(status
.accounts
.into_iter()
.map(|account| map_account(auth_provider, account, default_account_id.as_deref()))
.collect())
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_get_status(
auth_provider: String,
state: State<'_, CopilotAuthState>,
) -> Result<ManagedAuthStatus, String> {
let auth_provider = ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.read().await;
let status = auth_manager.get_status().await;
let default_account_id = status.default_account_id.clone();
Ok(ManagedAuthStatus {
provider: auth_provider.to_string(),
authenticated: status.authenticated,
default_account_id: default_account_id.clone(),
migration_error: status.migration_error,
accounts: status
.accounts
.into_iter()
.map(|account| map_account(auth_provider, account, default_account_id.as_deref()))
.collect(),
})
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_remove_account(
auth_provider: String,
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<(), String> {
ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.write().await;
auth_manager
.remove_account(&account_id)
.await
.map_err(|e| e.to_string())
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_set_default_account(
auth_provider: String,
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<(), String> {
ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.write().await;
auth_manager
.set_default_account(&account_id)
.await
.map_err(|e| e.to_string())
}
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_logout(
auth_provider: String,
state: State<'_, CopilotAuthState>,
) -> Result<(), String> {
ensure_auth_provider(&auth_provider)?;
let auth_manager = state.0.write().await;
auth_manager.clear_auth().await.map_err(|e| e.to_string())
}
+212
View File
@@ -0,0 +1,212 @@
//! GitHub Copilot Tauri Commands
//!
//! 提供 Copilot OAuth 认证相关的 Tauri 命令,支持多账号管理。
use crate::proxy::providers::copilot_auth::{
CopilotAuthManager, CopilotAuthStatus, CopilotModel, CopilotUsageResponse, GitHubAccount,
GitHubDeviceCodeResponse,
};
use std::sync::Arc;
use tauri::State;
use tokio::sync::RwLock;
/// Copilot 认证状态
pub struct CopilotAuthState(pub Arc<RwLock<CopilotAuthManager>>);
// ==================== 设备码流程 ====================
/// 启动设备码流程
///
/// 返回设备码和用户码,用于 OAuth 认证
#[tauri::command]
pub async fn copilot_start_device_flow(
state: State<'_, CopilotAuthState>,
) -> Result<GitHubDeviceCodeResponse, String> {
let auth_manager = state.0.read().await;
auth_manager
.start_device_flow()
.await
.map_err(|e| e.to_string())
}
/// 轮询 OAuth Token(向后兼容)
///
/// 使用设备码轮询 GitHub,等待用户完成授权
/// 返回 true 表示授权成功,false 表示等待中
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_poll_for_auth(
device_code: String,
state: State<'_, CopilotAuthState>,
) -> Result<bool, String> {
let auth_manager = state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
Ok(Some(_account)) => {
log::info!("[CopilotAuth] 用户已授权");
Ok(true)
}
Ok(None) => Ok(false),
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
Ok(false)
}
Err(e) => {
log::error!("[CopilotAuth] 轮询失败: {}", e);
Err(e.to_string())
}
}
}
/// 轮询 OAuth Token(多账号版本)
///
/// 返回新添加的账号信息,如果授权成功
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_poll_for_account(
device_code: String,
state: State<'_, CopilotAuthState>,
) -> Result<Option<GitHubAccount>, String> {
let auth_manager = state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
Ok(account) => Ok(account),
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
Ok(None)
}
Err(e) => {
log::error!("[CopilotAuth] 轮询失败: {}", e);
Err(e.to_string())
}
}
}
// ==================== 多账号管理 ====================
/// 列出所有已认证的账号
#[tauri::command]
pub async fn copilot_list_accounts(
state: State<'_, CopilotAuthState>,
) -> Result<Vec<GitHubAccount>, String> {
let auth_manager = state.0.read().await;
Ok(auth_manager.list_accounts().await)
}
/// 移除指定账号
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_remove_account(
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<(), String> {
let auth_manager = state.0.write().await;
auth_manager
.remove_account(&account_id)
.await
.map_err(|e| e.to_string())
}
/// 设置默认账号
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_set_default_account(
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<(), String> {
let auth_manager = state.0.write().await;
auth_manager
.set_default_account(&account_id)
.await
.map_err(|e| e.to_string())
}
// ==================== 状态查询 ====================
/// 获取认证状态(包含所有账号)
#[tauri::command]
pub async fn copilot_get_auth_status(
state: State<'_, CopilotAuthState>,
) -> Result<CopilotAuthStatus, String> {
let auth_manager = state.0.read().await;
Ok(auth_manager.get_status().await)
}
/// 检查是否已认证(有任意账号)
#[tauri::command]
pub async fn copilot_is_authenticated(state: State<'_, CopilotAuthState>) -> Result<bool, String> {
let auth_manager = state.0.read().await;
Ok(auth_manager.is_authenticated().await)
}
/// 注销所有 Copilot 认证
#[tauri::command]
pub async fn copilot_logout(state: State<'_, CopilotAuthState>) -> Result<(), String> {
let auth_manager = state.0.write().await;
auth_manager.clear_auth().await.map_err(|e| e.to_string())
}
// ==================== Token 获取 ====================
/// 获取有效的 Copilot Token(向后兼容:使用第一个账号)
///
/// 内部使用,用于代理请求
#[tauri::command]
pub async fn copilot_get_token(state: State<'_, CopilotAuthState>) -> Result<String, String> {
let auth_manager = state.0.read().await;
auth_manager
.get_valid_token()
.await
.map_err(|e| e.to_string())
}
/// 获取指定账号的有效 Copilot Token
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_get_token_for_account(
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<String, String> {
let auth_manager = state.0.read().await;
auth_manager
.get_valid_token_for_account(&account_id)
.await
.map_err(|e| e.to_string())
}
// ==================== 模型和使用量 ====================
/// 获取 Copilot 可用模型列表(向后兼容:使用第一个账号)
#[tauri::command]
pub async fn copilot_get_models(
state: State<'_, CopilotAuthState>,
) -> Result<Vec<CopilotModel>, String> {
let auth_manager = state.0.read().await;
auth_manager.fetch_models().await.map_err(|e| e.to_string())
}
/// 获取指定账号的 Copilot 可用模型列表
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_get_models_for_account(
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<Vec<CopilotModel>, String> {
let auth_manager = state.0.read().await;
auth_manager
.fetch_models_for_account(&account_id)
.await
.map_err(|e| e.to_string())
}
/// 获取 Copilot 使用量信息(向后兼容:使用第一个账号)
#[tauri::command]
pub async fn copilot_get_usage(
state: State<'_, CopilotAuthState>,
) -> Result<CopilotUsageResponse, String> {
let auth_manager = state.0.read().await;
auth_manager.fetch_usage().await.map_err(|e| e.to_string())
}
/// 获取指定账号的 Copilot 使用量信息
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_get_usage_for_account(
account_id: String,
state: State<'_, CopilotAuthState>,
) -> Result<CopilotUsageResponse, String> {
let auth_manager = state.0.read().await;
auth_manager
.fetch_usage_for_account(&account_id)
.await
.map_err(|e| e.to_string())
}
+6 -2
View File
@@ -1,6 +1,8 @@
#![allow(non_snake_case)]
mod auth;
mod config;
mod copilot;
mod deeplink;
mod env;
mod failover;
@@ -19,12 +21,14 @@ mod settings;
pub mod skill;
mod stream_check;
mod sync_support;
mod toolsearch;
mod usage;
mod webdav_sync;
mod workspace;
pub use auth::*;
pub use config::*;
pub use copilot::*;
pub use deeplink::*;
pub use env::*;
pub use failover::*;
@@ -42,7 +46,7 @@ pub use session_manager::*;
pub use settings::*;
pub use skill::*;
pub use stream_check::*;
pub use toolsearch::*;
pub use usage::*;
pub use webdav_sync::*;
pub use workspace::*;
+61
View File
@@ -2,6 +2,7 @@ use indexmap::IndexMap;
use tauri::State;
use crate::app_config::AppType;
use crate::commands::copilot::CopilotAuthState;
use crate::error::AppError;
use crate::provider::Provider;
use crate::services::{
@@ -10,6 +11,11 @@ use crate::services::{
use crate::store::AppState;
use std::str::FromStr;
// 常量定义
const TEMPLATE_TYPE_GITHUB_COPILOT: &str = "github_copilot";
const COPILOT_UNIT_PREMIUM: &str = "requests";
/// 获取所有供应商
#[tauri::command]
pub fn get_providers(
state: State<'_, AppState>,
@@ -142,10 +148,65 @@ pub fn import_default_config(state: State<'_, AppState>, app: String) -> Result<
#[tauri::command]
pub async fn queryProviderUsage(
state: State<'_, AppState>,
copilot_state: State<'_, CopilotAuthState>,
#[allow(non_snake_case)] providerId: String, // 使用 camelCase 匹配前端
app: String,
) -> Result<crate::provider::UsageResult, String> {
let app_type = AppType::from_str(&app).map_err(|e| e.to_string())?;
// 检查是否为 GitHub Copilot 模板类型,并解析绑定账号
let (is_copilot_template, copilot_account_id) = {
let providers = state
.db
.get_all_providers(app_type.as_str())
.map_err(|e| format!("Failed to get providers: {}", e))?;
let provider = providers.get(&providerId);
let is_copilot = provider
.and_then(|p| p.meta.as_ref())
.and_then(|m| m.usage_script.as_ref())
.and_then(|s| s.template_type.as_ref())
.map(|t| t == TEMPLATE_TYPE_GITHUB_COPILOT)
.unwrap_or(false);
let account_id = provider
.and_then(|p| p.meta.as_ref())
.and_then(|m| m.managed_account_id_for(TEMPLATE_TYPE_GITHUB_COPILOT));
(is_copilot, account_id)
};
if is_copilot_template {
// 使用 Copilot 专用 API
let auth_manager = copilot_state.0.read().await;
let usage = match copilot_account_id.as_deref() {
Some(account_id) => auth_manager
.fetch_usage_for_account(account_id)
.await
.map_err(|e| format!("Failed to fetch Copilot usage: {}", e))?,
None => auth_manager
.fetch_usage()
.await
.map_err(|e| format!("Failed to fetch Copilot usage: {}", e))?,
};
let premium = &usage.quota_snapshots.premium_interactions;
let used = premium.entitlement - premium.remaining;
return Ok(crate::provider::UsageResult {
success: true,
data: Some(vec![crate::provider::UsageData {
plan_name: Some(usage.copilot_plan),
remaining: Some(premium.remaining as f64),
total: Some(premium.entitlement as f64),
used: Some(used as f64),
unit: Some(COPILOT_UNIT_PREMIUM.to_string()),
is_valid: Some(true),
invalid_message: None,
extra: Some(format!("Reset: {}", usage.quota_reset_date)),
}]),
error: None,
});
}
ProviderService::query_usage(state.inner(), app_type, &providerId)
.await
.map_err(|e| e.to_string())
+63 -13
View File
@@ -1,6 +1,7 @@
//! 流式健康检查命令
use crate::app_config::AppType;
use crate::commands::copilot::CopilotAuthState;
use crate::error::AppError;
use crate::services::stream_check::{
HealthStatus, StreamCheckConfig, StreamCheckResult, StreamCheckService,
@@ -13,6 +14,7 @@ use tauri::State;
#[tauri::command]
pub async fn stream_check_provider(
state: State<'_, AppState>,
copilot_state: State<'_, CopilotAuthState>,
app_type: AppType,
provider_id: String,
) -> Result<StreamCheckResult, AppError> {
@@ -23,7 +25,9 @@ pub async fn stream_check_provider(
.get(&provider_id)
.ok_or_else(|| AppError::Message(format!("供应商 {provider_id} 不存在")))?;
let result = StreamCheckService::check_with_retry(&app_type, provider, &config).await?;
let auth_override = resolve_copilot_auth_override(provider, &copilot_state).await?;
let result =
StreamCheckService::check_with_retry(&app_type, provider, &config, auth_override).await?;
// 记录日志
let _ =
@@ -38,6 +42,7 @@ pub async fn stream_check_provider(
#[tauri::command]
pub async fn stream_check_all_providers(
state: State<'_, AppState>,
copilot_state: State<'_, CopilotAuthState>,
app_type: AppType,
proxy_targets_only: bool,
) -> Result<Vec<(String, StreamCheckResult)>, AppError> {
@@ -67,18 +72,20 @@ pub async fn stream_check_all_providers(
}
}
let result = StreamCheckService::check_with_retry(&app_type, &provider, &config)
.await
.unwrap_or_else(|e| StreamCheckResult {
status: HealthStatus::Failed,
success: false,
message: e.to_string(),
response_time_ms: None,
http_status: None,
model_used: String::new(),
tested_at: chrono::Utc::now().timestamp(),
retry_count: 0,
});
let auth_override = resolve_copilot_auth_override(&provider, &copilot_state).await?;
let result =
StreamCheckService::check_with_retry(&app_type, &provider, &config, auth_override)
.await
.unwrap_or_else(|e| StreamCheckResult {
status: HealthStatus::Failed,
success: false,
message: e.to_string(),
response_time_ms: None,
http_status: None,
model_used: String::new(),
tested_at: chrono::Utc::now().timestamp(),
retry_count: 0,
});
let _ = state
.db
@@ -104,3 +111,46 @@ pub fn save_stream_check_config(
) -> Result<(), AppError> {
state.db.save_stream_check_config(&config)
}
async fn resolve_copilot_auth_override(
provider: &crate::provider::Provider,
copilot_state: &State<'_, CopilotAuthState>,
) -> Result<Option<crate::proxy::providers::AuthInfo>, AppError> {
let is_copilot = provider
.meta
.as_ref()
.and_then(|meta| meta.provider_type.as_deref())
== Some("github_copilot")
|| provider
.settings_config
.pointer("/env/ANTHROPIC_BASE_URL")
.and_then(|value| value.as_str())
.map(|url| url.contains("githubcopilot.com"))
.unwrap_or(false);
if !is_copilot {
return Ok(None);
}
let auth_manager = copilot_state.0.read().await;
let account_id = provider
.meta
.as_ref()
.and_then(|meta| meta.github_account_id.clone());
let token = match account_id.as_deref() {
Some(id) => auth_manager
.get_valid_token_for_account(id)
.await
.map_err(|e| AppError::Message(format!("GitHub Copilot 认证失败: {e}")))?,
None => auth_manager
.get_valid_token()
.await
.map_err(|e| AppError::Message(format!("GitHub Copilot 认证失败: {e}")))?,
};
Ok(Some(crate::proxy::providers::AuthInfo::new(
token,
crate::proxy::providers::AuthStrategy::GitHubCopilot,
)))
}
-21
View File
@@ -1,21 +0,0 @@
#![allow(non_snake_case)]
/// Check Tool Search patch status for the active Claude Code installation
#[tauri::command]
pub async fn check_toolsearch_status() -> Result<crate::toolsearch_patch::ToolSearchStatus, String>
{
crate::toolsearch_patch::check_toolsearch_status().map_err(|e| e.to_string())
}
/// Apply Tool Search patch (bypass domain restriction) to the active installation
#[tauri::command]
pub async fn apply_toolsearch_patch() -> Result<Vec<crate::toolsearch_patch::PatchResult>, String> {
crate::toolsearch_patch::apply_toolsearch_patch().map_err(|e| e.to_string())
}
/// Restore Tool Search patch (re-enable domain restriction) for the active installation
#[tauri::command]
pub async fn restore_toolsearch_patch() -> Result<Vec<crate::toolsearch_patch::PatchResult>, String>
{
crate::toolsearch_patch::restore_toolsearch_patch().map_err(|e| e.to_string())
}
+53 -4
View File
@@ -4,7 +4,7 @@
use super::{lock_conn, Database, SCHEMA_VERSION};
use crate::error::AppError;
use rusqlite::Connection;
use rusqlite::{params, Connection};
use serde::Serialize;
#[derive(Serialize)]
@@ -389,7 +389,7 @@ impl Database {
Self::set_user_version(conn, 5)?;
}
5 => {
log::info!("迁移数据库从 v5 到 v6(使用量聚合表)");
log::info!("迁移数据库从 v5 到 v6(使用量聚合表 + Copilot 模板类型统一");
Self::migrate_v5_to_v6(conn)?;
Self::set_user_version(conn, 6)?;
}
@@ -970,8 +970,9 @@ impl Database {
Ok(())
}
/// v5 -> v6 迁移:添加使用量日聚合表
/// v5 -> v6 迁移:添加使用量日聚合表 + 统一 Copilot 模板类型
fn migrate_v5_to_v6(conn: &Connection) -> Result<(), AppError> {
// 1. 添加使用量日聚合表
conn.execute(
"CREATE TABLE IF NOT EXISTS usage_daily_rollups (
date TEXT NOT NULL,
@@ -992,7 +993,55 @@ impl Database {
)
.map_err(|e| AppError::Database(format!("创建 usage_daily_rollups 表失败: {e}")))?;
log::info!("v5 -> v6 迁移完成:已添加使用量日聚合表");
// 2. 统一 Copilot 模板类型为 github_copilot
let mut stmt = conn
.prepare("SELECT id, app_type, meta FROM providers")
.map_err(|e| AppError::Database(e.to_string()))?;
let rows = stmt
.query_map([], |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
))
})
.map_err(|e| AppError::Database(e.to_string()))?;
let mut updates = Vec::new();
for row in rows {
let (id, app_type, meta_str) = row.map_err(|e| AppError::Database(e.to_string()))?;
if let Ok(mut meta) = serde_json::from_str::<serde_json::Value>(&meta_str) {
let mut updated = false;
if let Some(usage_script) = meta.get_mut("usage_script") {
if let Some(template_type) = usage_script.get_mut("template_type") {
if template_type == "copilot" {
*template_type =
serde_json::Value::String("github_copilot".to_string());
updated = true;
}
}
}
if updated {
let new_meta_str = serde_json::to_string(&meta)
.map_err(|e| AppError::Database(e.to_string()))?;
updates.push((id, app_type, new_meta_str));
}
}
}
for (id, app_type, new_meta) in updates {
conn.execute(
"UPDATE providers SET meta = ?1 WHERE id = ?2 AND app_type = ?3",
params![new_meta, id, app_type],
)
.map_err(|e| AppError::Database(e.to_string()))?;
}
log::info!("v5 -> v6 迁移完成:已添加使用量日聚合表,统一 copilot 模板类型");
Ok(())
}
+38 -25
View File
@@ -25,7 +25,7 @@ mod services;
mod session_manager;
mod settings;
mod store;
mod toolsearch_patch;
mod tray;
mod usage_script;
@@ -690,6 +690,18 @@ pub fn run() {
let skill_service = SkillService::new();
app.manage(commands::skill::SkillServiceState(Arc::new(skill_service)));
// 初始化 CopilotAuthManager
{
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
use commands::CopilotAuthState;
use tokio::sync::RwLock;
let app_config_dir = crate::config::get_app_config_dir();
let copilot_auth_manager = CopilotAuthManager::new(app_config_dir);
app.manage(CopilotAuthState(Arc::new(RwLock::new(copilot_auth_manager))));
log::info!("✓ CopilotAuthManager initialized");
}
// 初始化全局出站代理 HTTP 客户端
{
let db = &app.state::<AppState>().db;
@@ -806,26 +818,6 @@ pub fn run() {
}
}
// Tool Search bypass: auto-apply patch on startup if enabled
if settings.tool_search_bypass {
match crate::toolsearch_patch::apply_toolsearch_patch() {
Ok(results) => {
let success = results.iter().filter(|r| r.success).count();
let total = results.len();
if success > 0 {
log::info!("✓ Tool Search patch auto-applied ({success}/{total})");
}
for r in results.iter().filter(|r| !r.success) {
if let Some(err) = &r.error {
log::warn!("✗ Tool Search patch failed for {}: {err}", r.path);
}
}
}
Err(e) => {
log::warn!("✗ Tool Search auto-patch skipped: {e}");
}
}
}
Ok(())
})
@@ -873,10 +865,6 @@ pub fn run() {
commands::is_claude_plugin_applied,
commands::apply_claude_onboarding_skip,
commands::clear_claude_onboarding_skip,
// Tool Search patch
commands::check_toolsearch_status,
commands::apply_toolsearch_patch,
commands::restore_toolsearch_patch,
// Claude MCP management
commands::get_claude_mcp_status,
commands::read_claude_mcp_config,
@@ -1056,6 +1044,31 @@ pub fn run() {
commands::scan_local_proxies,
// Window theme control
commands::set_window_theme,
// Generic managed auth commands
commands::auth_start_login,
commands::auth_poll_for_account,
commands::auth_list_accounts,
commands::auth_get_status,
commands::auth_remove_account,
commands::auth_set_default_account,
commands::auth_logout,
// Copilot OAuth commands (multi-account support)
commands::copilot_start_device_flow,
commands::copilot_poll_for_auth,
commands::copilot_poll_for_account,
commands::copilot_list_accounts,
commands::copilot_remove_account,
commands::copilot_set_default_account,
commands::copilot_get_auth_status,
commands::copilot_logout,
commands::copilot_is_authenticated,
commands::copilot_get_token,
commands::copilot_get_token_for_account,
commands::copilot_get_models,
commands::copilot_get_models_for_account,
commands::copilot_get_usage,
commands::copilot_get_usage_for_account,
// OMO commands
commands::read_omo_local_file,
commands::get_current_omo_provider_id,
commands::disable_current_omo,
+60
View File
@@ -191,6 +191,31 @@ pub struct ProviderProxyConfig {
pub proxy_password: Option<String>,
}
/// 认证绑定来源
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum AuthBindingSource {
/// 从 provider 自身配置读取认证信息(默认)
#[default]
ProviderConfig,
/// 使用托管账号认证(如 GitHub Copilot OAuth
ManagedAccount,
}
/// 通用认证绑定
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthBinding {
/// 认证来源
#[serde(default)]
pub source: AuthBindingSource,
/// 托管认证供应商标识(如 github_copilot
#[serde(rename = "authProvider", skip_serializing_if = "Option::is_none")]
pub auth_provider: Option<String>,
/// 托管账号 ID;为空表示跟随该认证供应商的默认账号
#[serde(rename = "accountId", skip_serializing_if = "Option::is_none")]
pub account_id: Option<String>,
}
/// 供应商元数据
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProviderMeta {
@@ -242,14 +267,49 @@ pub struct ProviderMeta {
/// - "openai_responses": OpenAI Responses API 格式,需要转换
#[serde(rename = "apiFormat", skip_serializing_if = "Option::is_none")]
pub api_format: Option<String>,
/// 通用认证绑定(provider_config / managed_account
///
/// 新代码应只写入该字段;githubAccountId 仅保留兼容读取。
#[serde(rename = "authBinding", skip_serializing_if = "Option::is_none")]
pub auth_binding: Option<AuthBinding>,
/// Claude 认证字段名("ANTHROPIC_AUTH_TOKEN" 或 "ANTHROPIC_API_KEY"
#[serde(rename = "apiKeyField", skip_serializing_if = "Option::is_none")]
pub api_key_field: Option<String>,
/// Prompt cache key for OpenAI-compatible endpoints.
/// When set, injected into converted requests to improve cache hit rate.
/// If not set, provider ID is used automatically during format conversion.
#[serde(rename = "promptCacheKey", skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
/// 供应商类型标识(用于特殊供应商检测)
/// - "github_copilot": GitHub Copilot 供应商
#[serde(rename = "providerType", skip_serializing_if = "Option::is_none")]
pub provider_type: Option<String>,
/// GitHub Copilot 关联账号 ID(仅 github_copilot 供应商使用)
/// 用于多账号支持,关联到特定的 GitHub 账号
#[serde(rename = "githubAccountId", skip_serializing_if = "Option::is_none")]
pub github_account_id: Option<String>,
}
impl ProviderMeta {
/// 解析指定托管认证供应商绑定的账号 ID。
///
/// 新版优先读取 authBinding,旧版继续兼容 githubAccountId。
pub fn managed_account_id_for(&self, auth_provider: &str) -> Option<String> {
if let Some(binding) = self.auth_binding.as_ref() {
if binding.source == AuthBindingSource::ManagedAccount
&& binding.auth_provider.as_deref() == Some(auth_provider)
{
return binding.account_id.clone();
}
}
if auth_provider == "github_copilot" {
return self.github_account_id.clone();
}
None
}
}
impl ProviderManager {
+73 -7
View File
@@ -8,7 +8,7 @@ use super::{
failover_switch::FailoverSwitchManager,
log_codes::fwd as log_fwd,
provider_router::ProviderRouter,
providers::{get_adapter, ProviderAdapter, ProviderType},
providers::{get_adapter, AuthInfo, AuthStrategy, ProviderAdapter, ProviderType},
thinking_budget_rectifier::{rectify_thinking_budget, should_rectify_thinking_budget},
thinking_rectifier::{
normalize_thinking_type, rectify_anthropic_request, should_rectify_thinking_signature,
@@ -16,10 +16,13 @@ use super::{
types::{OptimizerConfig, ProxyStatus, RectifierConfig},
ProxyError,
};
use crate::commands::CopilotAuthState;
use crate::proxy::providers::copilot_auth::CopilotAuthManager;
use crate::{app_config::AppType, provider::Provider};
use reqwest::Response;
use serde_json::Value;
use std::sync::Arc;
use tauri::Manager;
use tokio::sync::RwLock;
/// Headers 黑名单 - 不透传到上游的 Headers
@@ -792,14 +795,27 @@ impl RequestForwarder {
// 检查是否需要格式转换
let needs_transform = adapter.needs_transform(provider);
// 确定有效端点
// GitHub Copilot API 使用 /chat/completions(无 /v1 前缀)
let is_copilot = provider
.meta
.as_ref()
.and_then(|m| m.provider_type.as_deref())
== Some("github_copilot")
|| base_url.contains("githubcopilot.com");
let effective_endpoint =
if needs_transform && adapter.name() == "Claude" && endpoint == "/v1/messages" {
// 根据 api_format 选择目标端点
let api_format = super::providers::get_claude_api_format(provider);
if api_format == "openai_responses" {
"/v1/responses"
if is_copilot {
// GitHub Copilot uses /chat/completions without /v1 prefix
"/chat/completions"
} else {
"/v1/chat/completions"
// 根据 api_format 选择目标端点
let api_format = super::providers::get_claude_api_format(provider);
if api_format == "openai_responses" {
"/v1/responses"
} else {
"/v1/chat/completions"
}
}
} else {
endpoint
@@ -892,7 +908,57 @@ impl RequestForwarder {
}
// 使用适配器添加认证头
if let Some(auth) = adapter.extract_auth(provider) {
if let Some(mut auth) = adapter.extract_auth(provider) {
// GitHub Copilot 特殊处理:从 CopilotAuthManager 获取真实 token
if auth.strategy == AuthStrategy::GitHubCopilot {
if let Some(app_handle) = &self.app_handle {
let copilot_state = app_handle.state::<CopilotAuthState>();
let copilot_auth: tokio::sync::RwLockReadGuard<'_, CopilotAuthManager> =
copilot_state.0.read().await;
// 从 provider.meta 获取关联的 GitHub 账号 ID(多账号支持)
let account_id = provider
.meta
.as_ref()
.and_then(|m| m.managed_account_id_for("github_copilot"));
// 根据账号 ID 获取对应 token(向后兼容:无账号 ID 时使用第一个账号)
let token_result = match &account_id {
Some(id) => {
log::debug!("[Copilot] 使用指定账号 {id} 获取 token");
copilot_auth.get_valid_token_for_account(id).await
}
None => {
log::debug!("[Copilot] 使用默认账号获取 token");
copilot_auth.get_valid_token().await
}
};
match token_result {
Ok(token) => {
auth = AuthInfo::new(token, AuthStrategy::GitHubCopilot);
log::debug!(
"[Copilot] 成功获取 Copilot token (account={})",
account_id.as_deref().unwrap_or("default")
);
}
Err(e) => {
log::error!(
"[Copilot] 获取 Copilot token 失败 (account={}): {e}",
account_id.as_deref().unwrap_or("default")
);
return Err(ProxyError::AuthError(format!(
"GitHub Copilot 认证失败: {e}"
)));
}
}
} else {
log::error!("[Copilot] AppHandle 不可用");
return Err(ProxyError::AuthError(
"GitHub Copilot 认证不可用(无 AppHandle".to_string(),
));
}
}
request = adapter.add_auth_headers(request, &auth);
}
+8
View File
@@ -112,6 +112,13 @@ pub enum AuthStrategy {
///
/// 用于 Gemini CLI 等需要 OAuth 的场景
GoogleOAuth,
/// GitHub Copilot 认证方式
///
/// - Header: `Authorization: Bearer <copilot_token>`
///
/// 使用动态获取的 Copilot Token(通过 GitHub OAuth 设备码流程获取)
GitHubCopilot,
}
#[cfg(test)]
@@ -226,6 +233,7 @@ mod tests {
AuthStrategy::Bearer,
AuthStrategy::Google,
AuthStrategy::GoogleOAuth,
AuthStrategy::GitHubCopilot,
];
for (i, s1) in strategies.iter().enumerate() {
+116
View File
@@ -11,6 +11,7 @@
//! - **Claude**: Anthropic 官方 API (x-api-key + anthropic-version)
//! - **ClaudeAuth**: 中转服务 (仅 Bearer 认证,无 x-api-key)
//! - **OpenRouter**: 已支持 Claude Code 兼容接口,默认透传
//! - **GitHubCopilot**: GitHub Copilot (OAuth + Copilot Token)
use super::{AuthInfo, AuthStrategy, ProviderAdapter, ProviderType};
use crate::provider::Provider;
@@ -76,10 +77,16 @@ impl ClaudeAdapter {
/// 获取供应商类型
///
/// 根据 base_url 和 auth_mode 检测具体的供应商类型:
/// - GitHubCopilot: meta.provider_type 为 github_copilot 或 base_url 包含 githubcopilot.com
/// - OpenRouter: base_url 包含 openrouter.ai
/// - ClaudeAuth: auth_mode 为 bearer_only
/// - Claude: 默认 Anthropic 官方
pub fn provider_type(&self, provider: &Provider) -> ProviderType {
// 检测 GitHub Copilot
if self.is_github_copilot(provider) {
return ProviderType::GitHubCopilot;
}
// 检测 OpenRouter
if self.is_openrouter(provider) {
return ProviderType::OpenRouter;
@@ -93,6 +100,25 @@ impl ClaudeAdapter {
ProviderType::Claude
}
/// 检测是否为 GitHub Copilot 供应商
fn is_github_copilot(&self, provider: &Provider) -> bool {
// 方式1: 检查 meta.provider_type
if let Some(meta) = provider.meta.as_ref() {
if meta.provider_type.as_deref() == Some("github_copilot") {
return true;
}
}
// 方式2: 检查 base_url(兼容旧数据的 fallback,后续应优先依赖 providerType
if let Ok(base_url) = self.extract_base_url(provider) {
if base_url.contains("githubcopilot.com") {
return true;
}
}
false
}
/// 检测是否使用 OpenRouter
fn is_openrouter(&self, provider: &Provider) -> bool {
if let Ok(base_url) = self.extract_base_url(provider) {
@@ -244,6 +270,17 @@ impl ProviderAdapter for ClaudeAdapter {
fn extract_auth(&self, provider: &Provider) -> Option<AuthInfo> {
let provider_type = self.provider_type(provider);
// GitHub Copilot 使用特殊的认证策略
// 实际的 token 会在代理请求时动态获取
if provider_type == ProviderType::GitHubCopilot {
// 返回一个占位符,实际 token 由 CopilotAuthManager 动态提供
return Some(AuthInfo::new(
"copilot_placeholder".to_string(),
AuthStrategy::GitHubCopilot,
));
}
let strategy = match provider_type {
ProviderType::OpenRouter => AuthStrategy::Bearer,
ProviderType::ClaudeAuth => AuthStrategy::ClaudeAuth,
@@ -273,6 +310,11 @@ impl ProviderAdapter for ClaudeAdapter {
base = base.replace("/v1/v1", "/v1");
}
// GitHub Copilot 不需要 ?beta=true 参数
if base_url.contains("githubcopilot.com") {
return base;
}
// 为 Claude 原生 /v1/messages 端点添加 ?beta=true 参数
// 这是某些上游服务(如 DuckCoding)验证请求来源的关键参数
// 注意:不要为 OpenAI Chat Completions (/v1/chat/completions) 添加此参数
@@ -304,11 +346,22 @@ impl ProviderAdapter for ClaudeAdapter {
AuthStrategy::Bearer => {
request.header("Authorization", format!("Bearer {}", auth.api_key))
}
// GitHub Copilot: Bearer + 特定的 Editor headers
AuthStrategy::GitHubCopilot => request
.header("Authorization", format!("Bearer {}", auth.api_key))
.header("Editor-Version", "vscode/1.85.0")
.header("Editor-Plugin-Version", "copilot/1.150.0")
.header("Copilot-Integration-Id", "vscode-chat"),
_ => request,
}
}
fn needs_transform(&self, provider: &Provider) -> bool {
// GitHub Copilot 总是需要格式转换 (Anthropic → OpenAI)
if self.is_github_copilot(provider) {
return true;
}
// 根据 api_format 配置决定是否需要格式转换
// - "anthropic" (默认): 直接透传,无需转换
// - "openai_chat": 需要 Anthropic ↔ OpenAI Chat Completions 格式转换
@@ -678,4 +731,67 @@ mod tests {
);
assert!(!adapter.needs_transform(&unknown_format));
}
#[test]
fn test_github_copilot_detection_by_url() {
let adapter = ClaudeAdapter::new();
// GitHub Copilot by base_url
let copilot = create_provider(json!({
"env": {
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
}
}));
assert_eq!(adapter.provider_type(&copilot), ProviderType::GitHubCopilot);
}
#[test]
fn test_github_copilot_detection_by_meta() {
let adapter = ClaudeAdapter::new();
// GitHub Copilot by meta.provider_type
let copilot_meta = create_provider_with_meta(
json!({
"env": {
"ANTHROPIC_BASE_URL": "https://api.example.com"
}
}),
ProviderMeta {
provider_type: Some("github_copilot".to_string()),
..Default::default()
},
);
assert_eq!(
adapter.provider_type(&copilot_meta),
ProviderType::GitHubCopilot
);
}
#[test]
fn test_github_copilot_auth() {
let adapter = ClaudeAdapter::new();
let copilot = create_provider(json!({
"env": {
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
}
}));
let auth = adapter.extract_auth(&copilot).unwrap();
assert_eq!(auth.strategy, AuthStrategy::GitHubCopilot);
}
#[test]
fn test_github_copilot_needs_transform() {
let adapter = ClaudeAdapter::new();
let copilot = create_provider(json!({
"env": {
"ANTHROPIC_BASE_URL": "https://api.githubcopilot.com"
}
}));
// GitHub Copilot always needs transform
assert!(adapter.needs_transform(&copilot));
}
}
File diff suppressed because it is too large Load Diff
+47 -4
View File
@@ -15,6 +15,7 @@ mod adapter;
mod auth;
mod claude;
mod codex;
pub mod copilot_auth;
mod gemini;
pub mod models;
pub mod streaming;
@@ -52,6 +53,8 @@ pub enum ProviderType {
GeminiCli,
/// OpenRouter(已支持 Claude Code 兼容接口,默认透传;保留旧转换逻辑备用)
OpenRouter,
/// GitHub Copilot (OAuth + Copilot Token,需要 Anthropic ↔ OpenAI 转换)
GitHubCopilot,
}
impl ProviderType {
@@ -59,9 +62,11 @@ impl ProviderType {
///
/// 过去 OpenRouter 需要将 Anthropic 格式转换为 OpenAI 格式;
/// 现在默认关闭转换(因为 OpenRouter 已支持 Claude Code 兼容接口)。
/// GitHub Copilot 需要转换(Anthropic → OpenAI 格式)。
#[allow(dead_code)]
pub fn needs_transform(&self) -> bool {
match self {
ProviderType::GitHubCopilot => true,
ProviderType::OpenRouter => false,
_ => false,
}
@@ -77,6 +82,7 @@ impl ProviderType {
"https://generativelanguage.googleapis.com"
}
ProviderType::OpenRouter => "https://openrouter.ai/api",
ProviderType::GitHubCopilot => "https://api.githubcopilot.com",
}
}
@@ -87,9 +93,20 @@ impl ProviderType {
pub fn from_app_type_and_config(app_type: &AppType, provider: &Provider) -> Self {
match app_type {
AppType::Claude => {
// 检测是否为 OpenRouter
// 检测是否为 GitHub Copilot
if let Some(meta) = provider.meta.as_ref() {
if meta.provider_type.as_deref() == Some("github_copilot") {
return ProviderType::GitHubCopilot;
}
}
// 检测 base_url 是否为 GitHub Copilot
let adapter = ClaudeAdapter::new();
if let Ok(base_url) = adapter.extract_base_url(provider) {
if base_url.contains("githubcopilot.com") {
return ProviderType::GitHubCopilot;
}
// 检测是否为 OpenRouter
if base_url.contains("openrouter.ai") {
return ProviderType::OpenRouter;
}
@@ -154,6 +171,7 @@ impl ProviderType {
ProviderType::Gemini => "gemini",
ProviderType::GeminiCli => "gemini_cli",
ProviderType::OpenRouter => "openrouter",
ProviderType::GitHubCopilot => "github_copilot",
}
}
}
@@ -175,6 +193,9 @@ impl std::str::FromStr for ProviderType {
"gemini" => Ok(ProviderType::Gemini),
"gemini_cli" | "gemini-cli" => Ok(ProviderType::GeminiCli),
"openrouter" => Ok(ProviderType::OpenRouter),
"github_copilot" | "github-copilot" | "githubcopilot" => {
Ok(ProviderType::GitHubCopilot)
}
_ => Err(format!("Invalid provider type: {s}")),
}
}
@@ -201,9 +222,10 @@ pub fn get_adapter(app_type: &AppType) -> Box<dyn ProviderAdapter> {
#[allow(dead_code)]
pub fn get_adapter_for_provider_type(provider_type: &ProviderType) -> Box<dyn ProviderAdapter> {
match provider_type {
ProviderType::Claude | ProviderType::ClaudeAuth | ProviderType::OpenRouter => {
Box::new(ClaudeAdapter::new())
}
ProviderType::Claude
| ProviderType::ClaudeAuth
| ProviderType::OpenRouter
| ProviderType::GitHubCopilot => Box::new(ClaudeAdapter::new()),
ProviderType::Codex => Box::new(CodexAdapter::new()),
ProviderType::Gemini | ProviderType::GeminiCli => Box::new(GeminiAdapter::new()),
}
@@ -239,6 +261,7 @@ mod tests {
assert!(!ProviderType::Gemini.needs_transform());
assert!(!ProviderType::GeminiCli.needs_transform());
assert!(!ProviderType::OpenRouter.needs_transform());
assert!(ProviderType::GitHubCopilot.needs_transform());
}
#[test]
@@ -267,6 +290,10 @@ mod tests {
ProviderType::OpenRouter.default_endpoint(),
"https://openrouter.ai/api"
);
assert_eq!(
ProviderType::GitHubCopilot.default_endpoint(),
"https://api.githubcopilot.com"
);
}
#[test]
@@ -303,6 +330,18 @@ mod tests {
"openrouter".parse::<ProviderType>().unwrap(),
ProviderType::OpenRouter
);
assert_eq!(
"github_copilot".parse::<ProviderType>().unwrap(),
ProviderType::GitHubCopilot
);
assert_eq!(
"github-copilot".parse::<ProviderType>().unwrap(),
ProviderType::GitHubCopilot
);
assert_eq!(
"githubcopilot".parse::<ProviderType>().unwrap(),
ProviderType::GitHubCopilot
);
assert!("invalid".parse::<ProviderType>().is_err());
}
@@ -314,6 +353,7 @@ mod tests {
assert_eq!(ProviderType::Gemini.as_str(), "gemini");
assert_eq!(ProviderType::GeminiCli.as_str(), "gemini_cli");
assert_eq!(ProviderType::OpenRouter.as_str(), "openrouter");
assert_eq!(ProviderType::GitHubCopilot.as_str(), "github_copilot");
}
#[test]
@@ -434,6 +474,9 @@ mod tests {
let adapter = get_adapter_for_provider_type(&ProviderType::OpenRouter);
assert_eq!(adapter.name(), "Claude");
let adapter = get_adapter_for_provider_type(&ProviderType::GitHubCopilot);
assert_eq!(adapter.name(), "Claude");
let adapter = get_adapter_for_provider_type(&ProviderType::Codex);
assert_eq!(adapter.name(), "Codex");
+95 -10
View File
@@ -12,6 +12,7 @@ use std::time::Instant;
use crate::app_config::AppType;
use crate::error::AppError;
use crate::provider::Provider;
use crate::proxy::providers::transform::anthropic_to_openai;
use crate::proxy::providers::{get_adapter, AuthInfo, AuthStrategy};
/// 健康状态枚举
@@ -84,13 +85,16 @@ impl StreamCheckService {
app_type: &AppType,
provider: &Provider,
config: &StreamCheckConfig,
auth_override: Option<AuthInfo>,
) -> Result<StreamCheckResult, AppError> {
// 合并供应商单独配置和全局配置
let effective_config = Self::merge_provider_config(provider, config);
let mut last_result = None;
for attempt in 0..=effective_config.max_retries {
let result = Self::check_once(app_type, provider, &effective_config).await;
let result =
Self::check_once(app_type, provider, &effective_config, auth_override.clone())
.await;
match &result {
Ok(r) if r.success => {
@@ -178,6 +182,7 @@ impl StreamCheckService {
app_type: &AppType,
provider: &Provider,
config: &StreamCheckConfig,
auth_override: Option<AuthInfo>,
) -> Result<StreamCheckResult, AppError> {
let start = Instant::now();
let adapter = get_adapter(app_type);
@@ -186,8 +191,8 @@ impl StreamCheckService {
.extract_base_url(provider)
.map_err(|e| AppError::Message(format!("Failed to extract base_url: {e}")))?;
let auth = adapter
.extract_auth(provider)
let auth = auth_override
.or_else(|| adapter.extract_auth(provider))
.ok_or_else(|| AppError::Message("API Key not found".to_string()))?;
// 获取 HTTP 客户端:优先使用供应商单独代理配置,否则使用全局客户端
@@ -297,6 +302,7 @@ impl StreamCheckService {
provider: &Provider,
) -> Result<(u16, String), AppError> {
let base = base_url.trim_end_matches('/');
let is_github_copilot = auth.strategy == AuthStrategy::GitHubCopilot;
// Detect api_format: meta.api_format > settings_config.api_format > default "anthropic"
let api_format = provider
@@ -311,10 +317,15 @@ impl StreamCheckService {
})
.unwrap_or("anthropic");
let is_openai_chat = api_format == "openai_chat";
let is_openai_chat = is_github_copilot || api_format == "openai_chat";
// URL: /v1/chat/completions for openai_chat, /v1/messages?beta=true for anthropic
let url = if is_openai_chat {
// URL:
// - GitHub Copilot: /chat/completions (no /v1 prefix)
// - OpenAI-compatible: /v1/chat/completions
// - Anthropic native: /v1/messages?beta=true
let url = if is_github_copilot {
format!("{base}/chat/completions")
} else if is_openai_chat {
if base.ends_with("/v1") {
format!("{base}/chat/completions")
} else {
@@ -329,22 +340,38 @@ impl StreamCheckService {
}
};
// Body: identical structure for minimal test (both APIs accept messages array)
let body = json!({
// Build from Anthropic-native shape first, then convert for OpenAI-compatible targets.
let anthropic_body = json!({
"model": model,
"max_tokens": 1,
"messages": [{ "role": "user", "content": test_prompt }],
"stream": true
});
let body = if is_openai_chat {
anthropic_to_openai(anthropic_body, Some(&provider.id))
.map_err(|e| AppError::Message(format!("Failed to build test request: {e}")))?
} else {
anthropic_body
};
let mut request_builder = client.post(&url);
if is_openai_chat {
if is_github_copilot {
request_builder = request_builder
.header("authorization", format!("Bearer {}", auth.api_key))
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.header("accept-encoding", "identity")
.header("editor-version", "vscode/1.85.0")
.header("editor-plugin-version", "copilot/1.150.0")
.header("copilot-integration-id", "vscode-chat");
} else if is_openai_chat {
// OpenAI-compatible: Bearer auth + standard headers only
request_builder = request_builder
.header("authorization", format!("Bearer {}", auth.api_key))
.header("content-type", "application/json")
.header("accept", "application/json");
.header("accept", "text/event-stream")
.header("accept-encoding", "identity");
} else {
// Anthropic native: full Claude CLI headers
let os_name = Self::get_os_name();
@@ -692,6 +719,31 @@ impl StreamCheckService {
other => other,
}
}
#[cfg(test)]
fn resolve_claude_stream_url(
base_url: &str,
auth_strategy: AuthStrategy,
api_format: &str,
) -> String {
let base = base_url.trim_end_matches('/');
let is_github_copilot = auth_strategy == AuthStrategy::GitHubCopilot;
let is_openai_chat = is_github_copilot || api_format == "openai_chat";
if is_github_copilot {
format!("{base}/chat/completions")
} else if is_openai_chat {
if base.ends_with("/v1") {
format!("{base}/chat/completions")
} else {
format!("{base}/v1/chat/completions")
}
} else if base.ends_with("/v1") {
format!("{base}/messages?beta=true")
} else {
format!("{base}/v1/messages?beta=true")
}
}
}
#[cfg(test)]
@@ -794,4 +846,37 @@ mod tests {
assert_eq!(claude_auth, AuthStrategy::ClaudeAuth);
assert_eq!(bearer, AuthStrategy::Bearer);
}
#[test]
fn test_resolve_claude_stream_url_for_github_copilot() {
let url = StreamCheckService::resolve_claude_stream_url(
"https://api.githubcopilot.com",
AuthStrategy::GitHubCopilot,
"anthropic",
);
assert_eq!(url, "https://api.githubcopilot.com/chat/completions");
}
#[test]
fn test_resolve_claude_stream_url_for_openai_chat() {
let url = StreamCheckService::resolve_claude_stream_url(
"https://example.com/v1",
AuthStrategy::Bearer,
"openai_chat",
);
assert_eq!(url, "https://example.com/v1/chat/completions");
}
#[test]
fn test_resolve_claude_stream_url_for_anthropic() {
let url = StreamCheckService::resolve_claude_stream_url(
"https://api.anthropic.com",
AuthStrategy::Anthropic,
"anthropic",
);
assert_eq!(url, "https://api.anthropic.com/v1/messages?beta=true");
}
}
-4
View File
@@ -181,9 +181,6 @@ pub struct AppSettings {
/// 是否跳过 Claude Code 初次安装确认
#[serde(default)]
pub skip_claude_onboarding: bool,
/// 是否解除 Tool Search 域名限制
#[serde(default)]
pub tool_search_bypass: bool,
/// 是否开机自启
#[serde(default)]
pub launch_on_startup: bool,
@@ -289,7 +286,6 @@ impl Default for AppSettings {
minimize_to_tray_on_close: true,
enable_claude_plugin_integration: false,
skip_claude_onboarding: false,
tool_search_bypass: false,
launch_on_startup: false,
silent_startup: false,
enable_local_proxy: false,
-569
View File
@@ -1,569 +0,0 @@
//! Tool Search domain restriction bypass patch for Claude Code.
//!
//! Resolves the current active `claude` command from PATH and patches the
//! domain whitelist check
//! `return["api.anthropic.com"].includes(x)}catch{return!1}`
//! to always return true via equal-length byte replacement.
use regex::bytes::Regex;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::process::Command;
use sha2::{Digest, Sha256};
use crate::error::AppError;
const BACKUP_SUFFIX: &str = ".toolsearch-bak";
/// Encode bytes as lowercase hex string (avoids adding `hex` crate dependency).
fn to_hex(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{b:02x}")).collect()
}
// Regex matching the domain whitelist check with any JS identifier as variable name
const PATCH_TARGET_PATTERN: &str =
r#"return\["api\.anthropic\.com"\]\.includes\([A-Za-z_$][A-Za-z0-9_$]*\)\}catch\{return!1\}"#;
// Regex matching already-patched code
const PATCHED_PATTERN: &str = r#"return!0/\* *\*/\}catch\{return!0\}"#;
/// Single Claude Code installation info
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClaudeInstallation {
pub path: String,
pub source: String,
pub patched: bool,
pub has_backup: bool,
}
/// Result of a patch/restore operation on one installation
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PatchResult {
pub path: String,
pub success: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// Overall Tool Search patch status
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolSearchStatus {
pub installations: Vec<ClaudeInstallation>,
pub all_patched: bool,
pub any_found: bool,
}
// ── Patch status detection ──────────────────────────────────────────
fn get_patch_status(data: &[u8]) -> &'static str {
let target_re = Regex::new(PATCH_TARGET_PATTERN).unwrap();
let patched_re = Regex::new(PATCHED_PATTERN).unwrap();
if target_re.is_match(data) {
"unpatched"
} else if patched_re.is_match(data) {
"patched"
} else {
"unknown"
}
}
/// Build equal-length replacement bytes: `return!0/* */}catch{return!0}`
fn build_patched_bytes(original_len: usize) -> Result<Vec<u8>, AppError> {
let prefix = b"return!0/*";
let suffix = b"*/}catch{return!0}";
let padding = original_len
.checked_sub(prefix.len() + suffix.len())
.ok_or_else(|| AppError::Config("Patch template too long for match".into()))?;
let mut out = Vec::with_capacity(original_len);
out.extend_from_slice(prefix);
out.extend(std::iter::repeat_n(b' ', padding));
out.extend_from_slice(suffix);
Ok(out)
}
/// Apply byte-level patch to file data, returns (patched_data, replacement_count)
fn patch_bytes(data: &[u8]) -> Result<(Vec<u8>, usize), AppError> {
let re = Regex::new(PATCH_TARGET_PATTERN).unwrap();
let mut count = 0usize;
let result = re.replace_all(data, |caps: &regex::bytes::Captures| {
count += 1;
build_patched_bytes(caps[0].len()).unwrap_or_else(|_| caps[0].to_vec())
});
Ok((result.into_owned(), count))
}
// ── Installation detection ──────────────────────────────────────────
/// Run a command and return stdout, or empty string on failure
fn run_cmd(cmd: &str, args: &[&str]) -> String {
Command::new(cmd)
.args(args)
.output()
.ok()
.filter(|o| o.status.success())
.and_then(|o| String::from_utf8(o.stdout).ok())
.map(|s| s.trim().to_string())
.unwrap_or_default()
}
/// Search a package directory for JS files containing the domain check
fn find_patch_target_in_pkg(pkg_dir: &Path) -> Option<PathBuf> {
let marker = b"api.anthropic.com";
// Check cli.js first (most common)
let cli_js = pkg_dir.join("cli.js");
if cli_js.is_file() {
if let Ok(data) = std::fs::read(&cli_js) {
if data.windows(marker.len()).any(|w| w == marker) {
return Some(cli_js);
}
}
}
// Search other JS files
find_js_with_marker(pkg_dir)
}
fn find_js_with_marker(dir: &Path) -> Option<PathBuf> {
let marker = b"api.anthropic.com";
let entries = match std::fs::read_dir(dir) {
Ok(e) => e,
Err(_) => return None,
};
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
if let Some(found) = find_js_with_marker(&path) {
return Some(found);
}
} else if path.extension().and_then(|e| e.to_str()) == Some("js") {
if let Ok(meta) = path.metadata() {
if meta.len() < 1000 {
continue;
}
}
if let Ok(data) = std::fs::read(&path) {
if data.windows(marker.len()).any(|w| w == marker) {
return Some(path);
}
}
}
}
None
}
/// Resolve symlinks to actual file path
fn resolve_target(path: &Path) -> PathBuf {
std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
}
fn get_patch_status_for_path(path: &Path) -> Option<&'static str> {
let data = std::fs::read(path).ok()?;
Some(get_patch_status(&data))
}
fn package_dir_from_ancestors(path: &Path) -> Option<PathBuf> {
for ancestor in path.ancestors() {
if ancestor.file_name().and_then(|v| v.to_str()) == Some("claude-code")
&& ancestor
.parent()
.and_then(|v| v.file_name())
.and_then(|v| v.to_str())
== Some("@anthropic-ai")
{
return Some(ancestor.to_path_buf());
}
}
None
}
fn push_candidate_package_dir(
candidates: &mut Vec<PathBuf>,
seen: &mut std::collections::HashSet<PathBuf>,
path: PathBuf,
) {
if path.is_dir() && seen.insert(path.clone()) {
candidates.push(path);
}
}
fn resolve_active_patch_target(command_path: &Path) -> Option<PathBuf> {
let resolved_command = resolve_target(command_path);
if matches!(
get_patch_status_for_path(&resolved_command),
Some("patched" | "unpatched")
) {
return Some(resolved_command);
}
let mut candidates = Vec::new();
let mut seen = std::collections::HashSet::new();
for path in [command_path, resolved_command.as_path()] {
if let Some(pkg_dir) = package_dir_from_ancestors(path) {
push_candidate_package_dir(&mut candidates, &mut seen, pkg_dir);
}
if let Some(bin_dir) = path.parent() {
if let Some(prefix) = bin_dir.parent() {
push_candidate_package_dir(
&mut candidates,
&mut seen,
prefix
.join("lib")
.join("node_modules")
.join("@anthropic-ai")
.join("claude-code"),
);
push_candidate_package_dir(
&mut candidates,
&mut seen,
prefix
.join("node_modules")
.join("@anthropic-ai")
.join("claude-code"),
);
}
}
}
candidates
.into_iter()
.find_map(|pkg_dir| find_patch_target_in_pkg(&pkg_dir).map(|p| resolve_target(&p)))
}
#[cfg(target_os = "windows")]
fn find_active_command_path() -> Option<PathBuf> {
run_cmd("where.exe", &["claude"])
.lines()
.next()
.map(PathBuf::from)
.filter(|path| path.is_file())
}
#[cfg(not(target_os = "windows"))]
fn find_active_command_path() -> Option<PathBuf> {
run_cmd("which", &["claude"])
.lines()
.next()
.map(PathBuf::from)
.filter(|path| path.is_file())
}
fn find_active_installation() -> Option<(PathBuf, String)> {
let command_path = find_active_command_path()?;
let patch_target = resolve_active_patch_target(&command_path)?;
Some((
patch_target,
format!("active claude ({})", command_path.display()),
))
}
fn require_active_installation() -> Result<(PathBuf, String), AppError> {
find_active_installation()
.ok_or_else(|| AppError::Config("No active Claude Code installation found in PATH".into()))
}
// ── macOS codesign ──────────────────────────────────────────────────
#[cfg(target_os = "macos")]
fn codesign_adhoc(path: &Path) -> Result<(), AppError> {
let output = Command::new("codesign")
.args(["--force", "--sign", "-"])
.arg(path)
.output()
.map_err(|e| AppError::IoContext {
context: format!("Failed to run codesign for {}", path.display()),
source: e,
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(AppError::Config(format!(
"codesign failed for {}: {}",
path.display(),
stderr.trim()
)));
}
Ok(())
}
#[cfg(not(target_os = "macos"))]
fn codesign_adhoc(_path: &Path) -> Result<(), AppError> {
Ok(())
}
// ── Backup directory helpers ─────────────────────────────────────────
/// Get the centralized backup directory: `~/.cc-switch/toolsearch-backups/`
fn get_backup_dir() -> Result<PathBuf, AppError> {
let dir = crate::config::get_app_config_dir().join("toolsearch-backups");
if !dir.exists() {
std::fs::create_dir_all(&dir).map_err(|e| AppError::io(&dir, e))?;
}
Ok(dir)
}
/// Derive a stable backup filename from the original path using SHA-256.
fn backup_name_for(path: &Path) -> String {
let mut hasher = Sha256::new();
hasher.update(path.to_string_lossy().as_bytes());
to_hex(&hasher.finalize())
}
/// Get the backup file path for a given target file.
fn get_backup_path(path: &Path) -> Result<PathBuf, AppError> {
let dir = get_backup_dir()?;
let name = backup_name_for(path);
Ok(dir.join(format!("{name}.bak")))
}
/// Get the metadata file path (records original path for debugging).
fn get_meta_path(path: &Path) -> Result<PathBuf, AppError> {
let dir = get_backup_dir()?;
let name = backup_name_for(path);
Ok(dir.join(format!("{name}.meta")))
}
// ── Patch / Restore single file ─────────────────────────────────────
fn patch_single_file(path: &Path) -> Result<(), AppError> {
let data = std::fs::read(path).map_err(|e| AppError::io(path, e))?;
let status = get_patch_status(&data);
if status == "patched" {
return Ok(()); // Already patched
}
if status == "unknown" {
return Err(AppError::Config(format!(
"Target pattern not found in {}, possibly incompatible version",
path.display()
)));
}
let (patched_data, count) = patch_bytes(&data)?;
if count == 0 {
return Err(AppError::Config(format!(
"No replacements made in {}",
path.display()
)));
}
// Create backup in centralized directory
let backup_path = get_backup_path(path)?;
std::fs::copy(path, &backup_path).map_err(|e| AppError::io(&backup_path, e))?;
// Write metadata file for debugging
let meta_path = get_meta_path(path)?;
let _ = std::fs::write(&meta_path, path.to_string_lossy().as_bytes());
// Write patched data
if let Err(e) = std::fs::write(path, &patched_data) {
// Try rename trick on Windows
#[cfg(target_os = "windows")]
{
if let Ok(()) = write_via_rename(path, &patched_data) {
codesign_adhoc(path)?;
return Ok(());
}
}
return Err(AppError::io(path, e));
}
codesign_adhoc(path)?;
Ok(())
}
#[cfg(target_os = "windows")]
fn write_via_rename(target: &Path, data: &[u8]) -> Result<(), AppError> {
let tmp_path = target.with_extension("tmp");
let old_path = target.with_extension("old");
let _ = std::fs::remove_file(&tmp_path);
let _ = std::fs::remove_file(&old_path);
std::fs::write(&tmp_path, data).map_err(|e| AppError::io(&tmp_path, e))?;
std::fs::rename(target, &old_path).map_err(|e| AppError::io(target, e))?;
std::fs::rename(&tmp_path, target).map_err(|e| AppError::io(target, e))?;
let _ = std::fs::remove_file(&old_path);
Ok(())
}
fn restore_single_file(path: &Path) -> Result<(), AppError> {
let current_status = get_patch_status_for_path(path);
if matches!(current_status, Some("unpatched")) {
return Ok(());
}
// Try centralized backup directory first
let backup_path = get_backup_path(path)?;
// Fallback: legacy backup path (adjacent `.toolsearch-bak` file)
let legacy_backup = PathBuf::from(format!("{}{}", path.display(), BACKUP_SUFFIX));
let actual_backup = if backup_path.is_file() {
&backup_path
} else if legacy_backup.is_file() {
&legacy_backup
} else {
return Err(AppError::Config(format!(
"No backup found for {}",
path.display()
)));
};
let backup_data = std::fs::read(actual_backup).map_err(|e| AppError::io(actual_backup, e))?;
if let Err(e) = std::fs::write(path, &backup_data) {
#[cfg(target_os = "windows")]
{
if let Ok(()) = write_via_rename(path, &backup_data) {
codesign_adhoc(path)?;
return Ok(());
}
}
return Err(AppError::io(path, e));
}
codesign_adhoc(path)?;
Ok(())
}
// ── Public API ──────────────────────────────────────────────────────
/// Check Tool Search patch status for the current active Claude Code installation.
pub fn check_toolsearch_status() -> Result<ToolSearchStatus, AppError> {
let installations = find_active_installation()
.into_iter()
.map(|(path, source)| {
let data = std::fs::read(&path).unwrap_or_default();
let status = get_patch_status(&data);
// Check centralized backup first, then legacy
let has_backup = get_backup_path(&path).map(|p| p.is_file()).unwrap_or(false)
|| PathBuf::from(format!("{}{}", path.display(), BACKUP_SUFFIX)).is_file();
ClaudeInstallation {
path: path.to_string_lossy().to_string(),
source: source.clone(),
patched: status == "patched",
has_backup,
}
})
.collect::<Vec<_>>();
let any_found = !installations.is_empty();
let all_patched = any_found && installations.iter().all(|i| i.patched);
Ok(ToolSearchStatus {
installations,
all_patched,
any_found,
})
}
/// Apply the Tool Search patch to the current active Claude Code installation.
pub fn apply_toolsearch_patch() -> Result<Vec<PatchResult>, AppError> {
let (path, _) = require_active_installation()?;
Ok(vec![match patch_single_file(&path) {
Ok(()) => PatchResult {
path: path.to_string_lossy().to_string(),
success: true,
error: None,
},
Err(e) => PatchResult {
path: path.to_string_lossy().to_string(),
success: false,
error: Some(e.to_string()),
},
}])
}
/// Restore the current active Claude Code installation from backup.
pub fn restore_toolsearch_patch() -> Result<Vec<PatchResult>, AppError> {
let (path, _) = require_active_installation()?;
Ok(vec![match restore_single_file(&path) {
Ok(()) => PatchResult {
path: path.to_string_lossy().to_string(),
success: true,
error: None,
},
Err(e) => PatchResult {
path: path.to_string_lossy().to_string(),
success: false,
error: Some(e.to_string()),
},
}])
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_patch_bytes_replaces_correctly() {
let input = br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#;
let (patched, count) = patch_bytes(input).unwrap();
assert_eq!(count, 1);
assert_eq!(patched.len(), input.len());
assert!(patched.starts_with(b"return!0/*"));
assert!(patched.ends_with(b"*/}catch{return!0}"));
}
#[test]
fn test_patch_status_detection() {
let unpatched = br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#;
assert_eq!(get_patch_status(unpatched), "unpatched");
let (patched, _) = patch_bytes(unpatched).unwrap();
assert_eq!(get_patch_status(&patched), "patched");
assert_eq!(get_patch_status(b"some random data"), "unknown");
}
#[test]
fn test_build_patched_bytes_length() {
for len in 50..70 {
let result = build_patched_bytes(len).unwrap();
assert_eq!(result.len(), len);
}
}
#[test]
fn test_resolve_active_patch_target_from_npm_style_bin_path() {
let tmp = tempfile::tempdir().unwrap();
let prefix = tmp.path().join("prefix");
let bin_dir = prefix.join("bin");
let pkg_dir = prefix
.join("lib")
.join("node_modules")
.join("@anthropic-ai")
.join("claude-code");
std::fs::create_dir_all(&bin_dir).unwrap();
std::fs::create_dir_all(&pkg_dir).unwrap();
let command_path = bin_dir.join("claude");
std::fs::write(&command_path, b"#!/usr/bin/env node\n").unwrap();
let cli_path = pkg_dir.join("cli.js");
std::fs::write(
&cli_path,
br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#,
)
.unwrap();
assert_eq!(
resolve_active_patch_target(&command_path),
Some(resolve_target(&cli_path))
);
}
#[test]
fn test_restore_single_file_is_noop_when_current_target_is_unpatched() {
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
tmp.path(),
br#"return["api.anthropic.com"].includes(x)}catch{return!1}"#,
)
.unwrap();
assert!(restore_single_file(tmp.path()).is_ok());
}
}
+1 -1
View File
@@ -26,7 +26,7 @@
}
],
"security": {
"csp": "default-src 'self'; img-src 'self' data:; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self' ipc: http://ipc.localhost https: http:",
"csp": "default-src 'self'; img-src 'self' data: https: http:; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self' ipc: http://ipc.localhost https: http:",
"assetProtocol": {
"enable": true,
"scope": []