diff --git a/src-tauri/src/commands/auth.rs b/src-tauri/src/commands/auth.rs index e95c9b234..c3036023a 100644 --- a/src-tauri/src/commands/auth.rs +++ b/src-tauri/src/commands/auth.rs @@ -18,6 +18,7 @@ pub struct ManagedAuthAccount { pub avatar_url: Option, pub authenticated_at: i64, pub is_default: bool, + pub github_domain: String, } #[derive(Debug, Clone, serde::Serialize)] @@ -59,6 +60,7 @@ fn map_account( login: account.login, avatar_url: account.avatar_url, authenticated_at: account.authenticated_at, + github_domain: account.github_domain, } } @@ -79,6 +81,7 @@ fn map_device_code_response( #[tauri::command(rename_all = "camelCase")] pub async fn auth_start_login( auth_provider: String, + github_domain: Option, copilot_state: State<'_, CopilotAuthState>, codex_state: State<'_, CodexOAuthState>, ) -> Result { @@ -87,7 +90,7 @@ pub async fn auth_start_login( AUTH_PROVIDER_GITHUB_COPILOT => { let auth_manager = copilot_state.0.read().await; let response = auth_manager - .start_device_flow() + .start_device_flow(github_domain.as_deref()) .await .map_err(|e| e.to_string())?; Ok(map_device_code_response(auth_provider, response)) @@ -108,6 +111,7 @@ pub async fn auth_start_login( pub async fn auth_poll_for_account( auth_provider: String, device_code: String, + github_domain: Option, copilot_state: State<'_, CopilotAuthState>, codex_state: State<'_, CodexOAuthState>, ) -> Result, String> { @@ -115,7 +119,10 @@ pub async fn auth_poll_for_account( match auth_provider { AUTH_PROVIDER_GITHUB_COPILOT => { let auth_manager = copilot_state.0.write().await; - match auth_manager.poll_for_token(&device_code).await { + match auth_manager + .poll_for_token(&device_code, github_domain.as_deref()) + .await + { Ok(account) => { let default_account_id = auth_manager.get_status().await.default_account_id; Ok(account.map(|account| { diff --git a/src-tauri/src/commands/copilot.rs b/src-tauri/src/commands/copilot.rs index 7e36e8460..fb7104aa8 100644 --- a/src-tauri/src/commands/copilot.rs +++ b/src-tauri/src/commands/copilot.rs @@ -20,11 +20,12 @@ pub struct CopilotAuthState(pub Arc>); /// 返回设备码和用户码,用于 OAuth 认证 #[tauri::command] pub async fn copilot_start_device_flow( + github_domain: Option, state: State<'_, CopilotAuthState>, ) -> Result { let auth_manager = state.0.read().await; auth_manager - .start_device_flow() + .start_device_flow(github_domain.as_deref()) .await .map_err(|e| e.to_string()) } @@ -36,10 +37,14 @@ pub async fn copilot_start_device_flow( #[tauri::command(rename_all = "camelCase")] pub async fn copilot_poll_for_auth( device_code: String, + github_domain: Option, state: State<'_, CopilotAuthState>, ) -> Result { let auth_manager = state.0.write().await; - match auth_manager.poll_for_token(&device_code).await { + match auth_manager + .poll_for_token(&device_code, github_domain.as_deref()) + .await + { Ok(Some(_account)) => { log::info!("[CopilotAuth] 用户已授权"); Ok(true) @@ -61,10 +66,14 @@ pub async fn copilot_poll_for_auth( #[tauri::command(rename_all = "camelCase")] pub async fn copilot_poll_for_account( device_code: String, + github_domain: Option, state: State<'_, CopilotAuthState>, ) -> Result, String> { let auth_manager = state.0.write().await; - match auth_manager.poll_for_token(&device_code).await { + match auth_manager + .poll_for_token(&device_code, github_domain.as_deref()) + .await + { Ok(account) => Ok(account), Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => { Ok(None) diff --git a/src-tauri/src/proxy/providers/codex_oauth_auth.rs b/src-tauri/src/proxy/providers/codex_oauth_auth.rs index 508e7cd4a..945473f01 100644 --- a/src-tauri/src/proxy/providers/codex_oauth_auth.rs +++ b/src-tauri/src/proxy/providers/codex_oauth_auth.rs @@ -203,6 +203,7 @@ impl From<&CodexAccountData> for GitHubAccount { .unwrap_or_else(|| format!("ChatGPT ({})", &data.account_id)), avatar_url: None, authenticated_at: data.authenticated_at, + github_domain: "github.com".to_string(), } } } diff --git a/src-tauri/src/proxy/providers/copilot_auth.rs b/src-tauri/src/proxy/providers/copilot_auth.rs index 44a232712..c82011fc6 100644 --- a/src-tauri/src/proxy/providers/copilot_auth.rs +++ b/src-tauri/src/proxy/providers/copilot_auth.rs @@ -24,26 +24,114 @@ use std::path::PathBuf; use std::sync::Arc; use tokio::sync::{Mutex, RwLock}; -/// GitHub OAuth 客户端 ID(VS Code 使用的 ID) +/// GitHub OAuth 客户端 ID(VS Code)- 用于 github.com const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; +/// GitHub OAuth 客户端 ID(与 OpenCode 相同)- 在所有 GHES Copilot 实例上预注册 +const GITHUB_CLIENT_ID_GHES: &str = "Ov23li8tweQw6odWQebz"; + +/// 默认 GitHub 域名 +const DEFAULT_GITHUB_DOMAIN: &str = "github.com"; + +/// 根据域名选择 OAuth 客户端 ID +fn github_client_id(domain: &str) -> &'static str { + if domain == DEFAULT_GITHUB_DOMAIN { + GITHUB_CLIENT_ID + } else { + GITHUB_CLIENT_ID_GHES + } +} + +fn default_github_domain() -> String { + DEFAULT_GITHUB_DOMAIN.to_string() +} + /// GitHub 设备码 URL -const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +fn github_device_code_url(domain: &str) -> String { + format!("https://{domain}/login/device/code") +} /// GitHub OAuth Token URL -const GITHUB_OAUTH_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +fn github_oauth_token_url(domain: &str) -> String { + format!("https://{domain}/login/oauth/access_token") +} + +/// GitHub API 基础 URL(github.com 用 api.github.com,GHES 用 {domain}/api/v3) +fn github_api_base(domain: &str) -> String { + if domain == DEFAULT_GITHUB_DOMAIN { + "https://api.github.com".to_string() + } else { + format!("https://{domain}/api/v3") + } +} /// Copilot Token URL -const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +fn copilot_token_url(domain: &str) -> String { + format!("{}/copilot_internal/v2/token", github_api_base(domain)) +} /// GitHub User API URL -const GITHUB_USER_URL: &str = "https://api.github.com/user"; +fn github_user_url(domain: &str) -> String { + format!("{}/user", github_api_base(domain)) +} + +/// Copilot 使用量 API URL +fn copilot_usage_url(domain: &str) -> String { + format!("{}/copilot_internal/user", github_api_base(domain)) +} + +/// Copilot API 基础地址(github.com 用 api.githubcopilot.com,GHES 用 copilot-api.{domain}) +fn copilot_api_base(domain: &str) -> String { + if domain == DEFAULT_GITHUB_DOMAIN { + "https://api.githubcopilot.com".to_string() + } else { + format!("https://copilot-api.{domain}") + } +} /// Token 刷新提前量(秒) const TOKEN_REFRESH_BUFFER_SECONDS: i64 = 60; -/// Copilot API 端点 -const COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models"; +/// 判断是否为 GitHub Enterprise Server(非 github.com) +fn is_ghes(domain: &str) -> bool { + domain != DEFAULT_GITHUB_DOMAIN +} + +/// 归一化 GitHub 域名(SSOT): +/// - 小写化 +/// - 剥离协议(https:// http://) +/// - 剥离尾斜杠、path、query、fragment +/// - 拒绝包含 userinfo(@)的输入 +/// - 保留端口号(如有) +fn normalize_github_domain(raw: &str) -> Result { + let s = raw.trim(); + // 剥离协议 + let s = s + .strip_prefix("https://") + .or_else(|| s.strip_prefix("http://")) + .unwrap_or(s); + // 取 host 部分(到第一个 / 或 ? 或 #) + let host = s.split(&['/', '?', '#'][..]).next().unwrap_or(s); + // 拒绝 userinfo + if host.contains('@') { + return Err(CopilotAuthError::InvalidDomain(raw.to_string())); + } + let normalized = host.to_lowercase(); + if normalized.is_empty() { + return Err(CopilotAuthError::InvalidDomain(raw.to_string())); + } + Ok(normalized) +} + +/// 生成复合账号 ID,确保不同 GHES 实例的 user ID 不会冲突。 +/// github.com 账号保持原格式(向后兼容),GHES 账号使用 `domain:user_id` 格式。 +fn composite_account_id(domain: &str, user_id: u64) -> String { + if domain == DEFAULT_GITHUB_DOMAIN { + user_id.to_string() + } else { + format!("{}:{}", domain, user_id) + } +} /// Copilot API Header 常量 pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.110.1"; @@ -52,12 +140,6 @@ pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.38.2"; pub const COPILOT_API_VERSION: &str = "2025-10-01"; pub const COPILOT_INTEGRATION_ID: &str = "vscode-chat"; -/// Copilot 使用量 API URL -const COPILOT_USAGE_URL: &str = "https://api.github.com/copilot_internal/user"; - -/// 默认 Copilot API 端点 -const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com"; - /// Copilot 使用量响应 #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CopilotUsageResponse { @@ -169,6 +251,9 @@ pub enum CopilotAuthError { #[error("账号不存在: {0}")] AccountNotFound(String), + + #[error("无效的 GitHub 域名: {0}")] + InvalidDomain(String), } impl From for CopilotAuthError { @@ -253,15 +338,19 @@ pub struct GitHubAccount { pub avatar_url: Option, /// 认证时间戳 pub authenticated_at: i64, + /// GitHub 域名(github.com 或 GHES 域名) + #[serde(default = "default_github_domain")] + pub github_domain: String, } impl From<&GitHubAccountData> for GitHubAccount { fn from(data: &GitHubAccountData) -> Self { GitHubAccount { - id: data.user.id.to_string(), + id: composite_account_id(&data.github_domain, data.user.id), login: data.user.login.clone(), avatar_url: data.user.avatar_url.clone(), authenticated_at: data.authenticated_at, + github_domain: data.github_domain.clone(), } } } @@ -295,6 +384,9 @@ struct GitHubAccountData { pub user: GitHubUser, /// 认证时间戳 pub authenticated_at: i64, + /// GitHub 域名(github.com 或 GHES 域名) + #[serde(default = "default_github_domain")] + pub github_domain: String, } /// 持久化存储结构(v3 多账号 + 默认账号格式) @@ -437,14 +529,16 @@ impl CopilotAuthManager { &self, github_token: String, user: GitHubUser, + github_domain: String, ) -> Result { - let account_id = user.id.to_string(); + let account_id = composite_account_id(&github_domain, user.id); let now = chrono::Utc::now().timestamp(); let account_data = GitHubAccountData { github_token, user: user.clone(), authenticated_at: now, + github_domain: github_domain.clone(), }; let account = GitHubAccount { @@ -452,6 +546,7 @@ impl CopilotAuthManager { login: user.login.clone(), avatar_url: user.avatar_url.clone(), authenticated_at: now, + github_domain, }; { @@ -497,15 +592,25 @@ impl CopilotAuthManager { // ==================== 设备码流程 ==================== /// 启动设备码流程 - pub async fn start_device_flow(&self) -> Result { - log::info!("[CopilotAuth] 启动设备码流程"); + pub async fn start_device_flow( + &self, + github_domain: Option<&str>, + ) -> Result { + let domain = match github_domain { + Some(d) => normalize_github_domain(d)?, + None => DEFAULT_GITHUB_DOMAIN.to_string(), + }; + log::info!("[CopilotAuth] 启动设备码流程 (domain: {domain})"); let response = self .http_client - .post(GITHUB_DEVICE_CODE_URL) + .post(github_device_code_url(&domain)) .header("Accept", "application/json") .header("User-Agent", COPILOT_USER_AGENT) - .form(&[("client_id", GITHUB_CLIENT_ID), ("scope", "read:user")]) + .form(&[ + ("client_id", github_client_id(&domain)), + ("scope", "read:user"), + ]) .send() .await?; @@ -534,16 +639,21 @@ impl CopilotAuthManager { pub async fn poll_for_token( &self, device_code: &str, + github_domain: Option<&str>, ) -> Result, CopilotAuthError> { - log::debug!("[CopilotAuth] 轮询 OAuth Token"); + let domain = match github_domain { + Some(d) => normalize_github_domain(d)?, + None => DEFAULT_GITHUB_DOMAIN.to_string(), + }; + log::debug!("[CopilotAuth] 轮询 OAuth Token (domain: {domain})"); let response = self .http_client - .post(GITHUB_OAUTH_TOKEN_URL) + .post(github_oauth_token_url(&domain)) .header("Accept", "application/json") .header("User-Agent", COPILOT_USER_AGENT) .form(&[ - ("client_id", GITHUB_CLIENT_ID), + ("client_id", github_client_id(&domain)), ("device_code", device_code), ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), ]) @@ -578,14 +688,28 @@ impl CopilotAuthManager { log::info!("[CopilotAuth] OAuth Token 获取成功"); // 获取用户信息 - let user = self.fetch_user_info_with_token(&access_token).await?; - - // 验证 Copilot 订阅(获取 Copilot Token) - self.fetch_copilot_token_with_github_token(&access_token, &user.id.to_string()) + let user = self + .fetch_user_info_with_token(&access_token, &domain) .await?; + // GHES 无需换取 Copilot Token,直接使用 OAuth token 作为 Bearer + // 参考 OpenCode 的实现:GHE Copilot 直接用 OAuth token 调用 copilot-api.{domain} + if !is_ghes(&domain) { + // github.com:验证 Copilot 订阅(获取 Copilot Token) + self.fetch_copilot_token_with_github_token( + &access_token, + &user.id.to_string(), + &domain, + ) + .await?; + } else { + log::info!("[CopilotAuth] GHES 账号,跳过 Copilot Token 兑换,直接使用 OAuth token"); + } + // 添加账号 - let account = self.add_account_internal(access_token, user).await?; + let account = self + .add_account_internal(access_token, user, domain) + .await?; Ok(Some(account)) } @@ -600,6 +724,16 @@ impl CopilotAuthManager { // 确保迁移完成 self.ensure_migration_complete().await?; + // GHES 账号直接使用 GitHub OAuth token,无需 Copilot token 交换 + let domain = self.get_account_domain(account_id).await; + if is_ghes(&domain) { + let accounts = self.accounts.read().await; + return accounts + .get(account_id) + .map(|a| a.github_token.clone()) + .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string())); + } + // 检查缓存的 token { let tokens = self.copilot_tokens.read().await; @@ -627,16 +761,16 @@ impl CopilotAuthManager { } // 获取账号的 GitHub token - let github_token = { + let (github_token, domain) = { let accounts = self.accounts.read().await; - accounts + let account = accounts .get(account_id) - .map(|a| a.github_token.clone()) - .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))? + .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?; + (account.github_token.clone(), account.github_domain.clone()) }; // 刷新 Copilot token - self.fetch_copilot_token_with_github_token(&github_token, account_id) + self.fetch_copilot_token_with_github_token(&github_token, account_id, &domain) .await?; // 返回新 token @@ -687,11 +821,19 @@ impl CopilotAuthManager { ) -> Result, CopilotAuthError> { let copilot_token = self.get_valid_token_for_account(account_id).await?; + // 使用 get_api_endpoint() 动态解析 Copilot API 基础 URL。 + // 对于 github.com 账号,会查询 /copilot_internal/user 获取 endpoints.api 字段。 + // 对于 GHES 账号,/copilot_internal/user 可能不返回 endpoints——此时 + // get_api_endpoint() 会回退到 copilot_api_base(&domain),与之前的静态 URL + // 拼接结果一致。该回退行为是安全且符合预期的。 + let api_base = self.get_api_endpoint(account_id).await; + let models_url = format!("{}/models", api_base); + log::info!("[CopilotAuth] 获取账号 {account_id} 的 Copilot 可用模型"); let response = self .http_client - .get(COPILOT_MODELS_URL) + .get(&models_url) .header("Authorization", format!("Bearer {copilot_token}")) .header("Content-Type", "application/json") .header("copilot-integration-id", "vscode-chat") @@ -767,19 +909,19 @@ impl CopilotAuthManager { &self, account_id: &str, ) -> Result { - let github_token = { + let (github_token, domain) = { let accounts = self.accounts.read().await; - accounts + let account = accounts .get(account_id) - .map(|a| a.github_token.clone()) - .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))? + .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?; + (account.github_token.clone(), account.github_domain.clone()) }; log::info!("[CopilotAuth] 获取账号 {account_id} 的 Copilot 使用量"); let response = self .http_client - .get(COPILOT_USAGE_URL) + .get(copilot_usage_url(&domain)) .header("Authorization", format!("token {github_token}")) .header("Content-Type", "application/json") .header("editor-version", COPILOT_EDITOR_VERSION) @@ -862,7 +1004,8 @@ impl CopilotAuthManager { log::debug!( "[CopilotAuth] 获取账号 {account_id} 动态 API 端点失败: {e},使用默认值" ); - DEFAULT_COPILOT_API_ENDPOINT.to_string() + let domain = self.get_account_domain(account_id).await; + copilot_api_base(&domain) } } } @@ -873,24 +1016,27 @@ impl CopilotAuthManager { match self.resolve_default_account_id().await { Some(id) => self.get_api_endpoint(&id).await, - None => DEFAULT_COPILOT_API_ENDPOINT.to_string(), + None => { + // 无账号时回退到 github.com 的默认端点 + copilot_api_base(DEFAULT_GITHUB_DOMAIN) + } } } async fn fetch_and_cache_endpoint(&self, account_id: &str) -> Result { - let github_token = { + let (github_token, domain) = { let accounts = self.accounts.read().await; - accounts + let account = accounts .get(account_id) - .map(|a| a.github_token.clone()) - .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))? + .ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?; + (account.github_token.clone(), account.github_domain.clone()) }; log::debug!("[CopilotAuth] 为账号 {account_id} 惰性拉取动态 API 端点"); let response = self .http_client - .get(COPILOT_USAGE_URL) + .get(copilot_usage_url(&domain)) .header("Authorization", format!("token {github_token}")) .header("Content-Type", "application/json") .header("editor-version", COPILOT_EDITOR_VERSION) @@ -918,7 +1064,7 @@ impl CopilotAuthManager { let endpoint = match usage.endpoints { Some(endpoints) => endpoints.api.clone(), - None => DEFAULT_COPILOT_API_ENDPOINT.to_string(), + None => copilot_api_base(&domain), }; // 缓存端点(包括默认值),避免重复请求 @@ -1075,6 +1221,15 @@ impl CopilotAuthManager { Self::fallback_default_account_id(&accounts) } + /// 获取指定账号的 GitHub 域名 + async fn get_account_domain(&self, account_id: &str) -> String { + let accounts = self.accounts.read().await; + accounts + .get(account_id) + .map(|a| a.github_domain.clone()) + .unwrap_or_else(|| DEFAULT_GITHUB_DOMAIN.to_string()) + } + async fn get_refresh_lock(&self, account_id: &str) -> Arc> { { let refresh_locks = self.refresh_locks.read().await; @@ -1155,10 +1310,11 @@ impl CopilotAuthManager { async fn fetch_user_info_with_token( &self, github_token: &str, + domain: &str, ) -> Result { let response = self .http_client - .get(GITHUB_USER_URL) + .get(github_user_url(domain)) .header("Authorization", format!("token {github_token}")) .header("User-Agent", COPILOT_USER_AGENT) .header("Editor-Version", COPILOT_EDITOR_VERSION) @@ -1185,12 +1341,13 @@ impl CopilotAuthManager { &self, github_token: &str, account_id: &str, + domain: &str, ) -> Result<(), CopilotAuthError> { - log::debug!("[CopilotAuth] 获取账号 {account_id} 的 Copilot Token"); + log::debug!("[CopilotAuth] 获取账号 {account_id} 的 Copilot Token (domain: {domain})"); let response = self .http_client - .get(COPILOT_TOKEN_URL) + .get(copilot_token_url(domain)) .header("Authorization", format!("token {github_token}")) .header("User-Agent", COPILOT_USER_AGENT) .header("Editor-Version", COPILOT_EDITOR_VERSION) @@ -1284,20 +1441,32 @@ impl CopilotAuthManager { log::info!("[CopilotAuth] 执行旧格式迁移"); // 获取用户信息 - match self.fetch_user_info_with_token(&legacy_token).await { + match self + .fetch_user_info_with_token(&legacy_token, DEFAULT_GITHUB_DOMAIN) + .await + { Ok(user) => { - let account_id = user.id.to_string(); + let account_id = composite_account_id(DEFAULT_GITHUB_DOMAIN, user.id); // 尝试获取 Copilot token 验证订阅 if let Err(e) = self - .fetch_copilot_token_with_github_token(&legacy_token, &account_id) + .fetch_copilot_token_with_github_token( + &legacy_token, + &account_id, + DEFAULT_GITHUB_DOMAIN, + ) .await { log::warn!("[CopilotAuth] 迁移时验证 Copilot 订阅失败: {e}"); } // 添加账号 - self.add_account_internal(legacy_token, user).await?; + self.add_account_internal( + legacy_token, + user, + DEFAULT_GITHUB_DOMAIN.to_string(), + ) + .await?; self.set_migration_error(None).await; log::info!("[CopilotAuth] 旧格式迁移完成"); @@ -1387,6 +1556,7 @@ mod tests { login: "testuser".to_string(), avatar_url: Some("https://example.com/avatar.png".to_string()), authenticated_at: 1234567890, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }], default_account_id: Some("12345".to_string()), migration_error: None, @@ -1420,6 +1590,7 @@ mod tests { avatar_url: Some("https://example.com/alice.png".to_string()), }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); accounts.insert( @@ -1432,6 +1603,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000001, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); @@ -1479,6 +1651,7 @@ mod tests { avatar_url: Some("https://example.com/avatar.png".to_string()), }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }; let account = GitHubAccount::from(&data); @@ -1504,6 +1677,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); accounts.insert( @@ -1516,6 +1690,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000001, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); @@ -1546,6 +1721,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); } @@ -1630,6 +1806,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); } @@ -1664,6 +1841,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); } @@ -1746,6 +1924,7 @@ mod tests { avatar_url: None, }, authenticated_at: 1700000000, + github_domain: DEFAULT_GITHUB_DOMAIN.to_string(), }, ); } @@ -1801,7 +1980,7 @@ mod tests { let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf()); let endpoint = manager.get_api_endpoint("12345").await; - assert_eq!(endpoint, DEFAULT_COPILOT_API_ENDPOINT); + assert_eq!(endpoint, copilot_api_base(DEFAULT_GITHUB_DOMAIN)); } #[tokio::test] @@ -1817,4 +1996,99 @@ mod tests { other => panic!("期望 AccountNotFound 错误,实际: {other:?}"), } } + + #[test] + fn test_normalize_github_domain() { + // 基本用法 + assert_eq!(normalize_github_domain("github.com").unwrap(), "github.com"); + assert_eq!( + normalize_github_domain("company.ghe.com").unwrap(), + "company.ghe.com" + ); + + // 剥离协议 + assert_eq!( + normalize_github_domain("https://company.ghe.com").unwrap(), + "company.ghe.com" + ); + assert_eq!( + normalize_github_domain("http://company.ghe.com").unwrap(), + "company.ghe.com" + ); + + // 小写化 + assert_eq!(normalize_github_domain("GitHub.COM").unwrap(), "github.com"); + assert_eq!( + normalize_github_domain("Company.GHE.Com").unwrap(), + "company.ghe.com" + ); + + // 剥离尾斜杠和 path + assert_eq!( + normalize_github_domain("company.ghe.com/").unwrap(), + "company.ghe.com" + ); + assert_eq!( + normalize_github_domain("company.ghe.com/api/v3").unwrap(), + "company.ghe.com" + ); + + // 剥离 query 和 fragment + assert_eq!( + normalize_github_domain("company.ghe.com?foo=bar").unwrap(), + "company.ghe.com" + ); + assert_eq!( + normalize_github_domain("company.ghe.com#section").unwrap(), + "company.ghe.com" + ); + + // 保留端口 + assert_eq!( + normalize_github_domain("company.ghe.com:8443").unwrap(), + "company.ghe.com:8443" + ); + + // 拒绝 userinfo + assert!(normalize_github_domain("user@company.ghe.com").is_err()); + + // 拒绝空输入 + assert!(normalize_github_domain("").is_err()); + assert!(normalize_github_domain(" ").is_err()); + } + + #[test] + fn test_composite_account_id() { + // github.com 保持原格式(向后兼容) + assert_eq!(composite_account_id("github.com", 12345), "12345"); + + // GHES 使用复合格式 + assert_eq!( + composite_account_id("company.ghe.com", 12345), + "company.ghe.com:12345" + ); + + // 不同 GHES 实例,相同 user ID,不冲突 + assert_ne!( + composite_account_id("a.ghe.com", 1), + composite_account_id("b.ghe.com", 1) + ); + } + + #[test] + fn test_github_account_from_data_ghes_uses_composite_id() { + let data = GitHubAccountData { + github_token: "gho_test".to_string(), + user: GitHubUser { + login: "testuser".to_string(), + id: 99999, + avatar_url: None, + }, + authenticated_at: 1700000000, + github_domain: "company.ghe.com".to_string(), + }; + + let account = GitHubAccount::from(&data); + assert_eq!(account.id, "company.ghe.com:99999"); + } } diff --git a/src/components/providers/forms/CopilotAuthSection.tsx b/src/components/providers/forms/CopilotAuthSection.tsx index 608eb6d0f..c328c3fc0 100644 --- a/src/components/providers/forms/CopilotAuthSection.tsx +++ b/src/components/providers/forms/CopilotAuthSection.tsx @@ -3,6 +3,7 @@ import { useTranslation } from "react-i18next"; import { Button } from "@/components/ui/button"; import { Badge } from "@/components/ui/badge"; import { Label } from "@/components/ui/label"; +import { Input } from "@/components/ui/input"; import { Select, SelectContent, @@ -45,6 +46,19 @@ export const CopilotAuthSection: React.FC = ({ }) => { const { t } = useTranslation(); const [copied, setCopied] = React.useState(false); + const [deploymentType, setDeploymentType] = React.useState< + "github.com" | "enterprise" + >("github.com"); + const [enterpriseDomain, setEnterpriseDomain] = React.useState(""); + + // 根据部署类型计算实际的 GitHub 域名 + const effectiveGithubDomain = + deploymentType === "enterprise" && enterpriseDomain.trim() + ? enterpriseDomain + .trim() + .replace(/^https?:\/\//, "") + .replace(/\/$/, "") + : undefined; const { accounts, @@ -63,7 +77,7 @@ export const CopilotAuthSection: React.FC = ({ setDefaultAccount, cancelAuth, logout, - } = useCopilotAuth(); + } = useCopilotAuth(effectiveGithubDomain); // 复制用户码 const copyUserCode = async () => { @@ -113,6 +127,41 @@ export const CopilotAuthSection: React.FC = ({ + {/* GitHub 部署类型选择 */} +
+ + + {deploymentType === "enterprise" && ( + setEnterpriseDomain(e.target.value)} + /> + )} +
+ {migrationError && (

{t("copilot.migrationFailed", { @@ -179,6 +228,12 @@ export const CopilotAuthSection: React.FC = ({ {t("copilot.defaultAccount", "默认")} )} + {account.github_domain && + account.github_domain !== "github.com" && ( + + {account.github_domain} + + )} {selectedAccountId === account.id && ( {t("copilot.selected", "已选中")} @@ -223,6 +278,7 @@ export const CopilotAuthSection: React.FC = ({ onClick={addAccount} className="w-full" variant="outline" + disabled={deploymentType === "enterprise" && !enterpriseDomain.trim()} > {t("copilot.loginWithGitHub", "使用 GitHub 登录")} @@ -236,7 +292,10 @@ export const CopilotAuthSection: React.FC = ({ onClick={addAccount} className="w-full" variant="outline" - disabled={isAddingAccount} + disabled={ + isAddingAccount || + (deploymentType === "enterprise" && !enterpriseDomain.trim()) + } > {t("copilot.addAnotherAccount", "添加其他账号")} diff --git a/src/components/providers/forms/hooks/useCopilotAuth.ts b/src/components/providers/forms/hooks/useCopilotAuth.ts index 7b2c45010..8600d7d65 100644 --- a/src/components/providers/forms/hooks/useCopilotAuth.ts +++ b/src/components/providers/forms/hooks/useCopilotAuth.ts @@ -1,8 +1,8 @@ import type { GitHubAccount } from "@/lib/api"; import { useManagedAuth } from "./useManagedAuth"; -export function useCopilotAuth() { - const managedAuth = useManagedAuth("github_copilot"); +export function useCopilotAuth(githubDomain?: string) { + const managedAuth = useManagedAuth("github_copilot", githubDomain); const defaultAccount = managedAuth.accounts.find( (account) => account.id === managedAuth.defaultAccountId, diff --git a/src/components/providers/forms/hooks/useManagedAuth.ts b/src/components/providers/forms/hooks/useManagedAuth.ts index 138f3f726..86360b397 100644 --- a/src/components/providers/forms/hooks/useManagedAuth.ts +++ b/src/components/providers/forms/hooks/useManagedAuth.ts @@ -10,7 +10,10 @@ import type { type PollingState = "idle" | "polling" | "success" | "error"; -export function useManagedAuth(authProvider: ManagedAuthProvider) { +export function useManagedAuth( + authProvider: ManagedAuthProvider, + githubDomain?: string, +) { const queryClient = useQueryClient(); const queryKey = ["managed-auth-status", authProvider]; @@ -52,7 +55,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) { }, [stopPolling]); const startLoginMutation = useMutation({ - mutationFn: () => authApi.authStartLogin(authProvider), + mutationFn: () => authApi.authStartLogin(authProvider, githubDomain), onSuccess: async (response) => { setDeviceCode(response); setPollingState("polling"); @@ -87,6 +90,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) { const newAccount = await authApi.authPollForAccount( authProvider, response.device_code, + githubDomain, ); if (newAccount) { stopPolling(); diff --git a/src/i18n/locales/en.json b/src/i18n/locales/en.json index 0872a4e93..b7e4aae87 100644 --- a/src/i18n/locales/en.json +++ b/src/i18n/locales/en.json @@ -889,7 +889,11 @@ "retry": "Retry", "copyCode": "Copy code", "migrationFailed": "Legacy auth migration failed: {{error}}", - "loadModelsFailed": "Failed to load Copilot models" + "loadModelsFailed": "Failed to load Copilot models", + "deploymentType": "GitHub Deployment Type", + "deploymentGitHubCom": "GitHub.com", + "deploymentEnterprise": "GitHub Enterprise Server", + "enterpriseDomainPlaceholder": "e.g. company.ghe.com" }, "codexOauth": { "authStatus": "Auth status", diff --git a/src/i18n/locales/ja.json b/src/i18n/locales/ja.json index d7299c95c..c76362a6d 100644 --- a/src/i18n/locales/ja.json +++ b/src/i18n/locales/ja.json @@ -889,7 +889,11 @@ "retry": "再試行", "copyCode": "コードをコピー", "migrationFailed": "旧認証データの移行に失敗しました: {{error}}", - "loadModelsFailed": "Copilot モデル一覧の読み込みに失敗しました" + "loadModelsFailed": "Copilot モデル一覧の読み込みに失敗しました", + "deploymentType": "GitHub デプロイメントタイプ", + "deploymentGitHubCom": "GitHub.com", + "deploymentEnterprise": "GitHub Enterprise Server", + "enterpriseDomainPlaceholder": "例: company.ghe.com" }, "codexOauth": { "authStatus": "認証状態", diff --git a/src/i18n/locales/zh.json b/src/i18n/locales/zh.json index 93899f663..f3cbbeb17 100644 --- a/src/i18n/locales/zh.json +++ b/src/i18n/locales/zh.json @@ -890,7 +890,11 @@ "retry": "重试", "copyCode": "复制代码", "migrationFailed": "旧认证数据迁移失败:{{error}}", - "loadModelsFailed": "加载 Copilot 模型列表失败" + "loadModelsFailed": "加载 Copilot 模型列表失败", + "deploymentType": "GitHub 部署类型", + "deploymentGitHubCom": "GitHub.com", + "deploymentEnterprise": "GitHub Enterprise Server", + "enterpriseDomainPlaceholder": "例如:company.ghe.com" }, "codexOauth": { "authStatus": "认证状态", diff --git a/src/lib/api/auth.ts b/src/lib/api/auth.ts index a4d840ffa..294661802 100644 --- a/src/lib/api/auth.ts +++ b/src/lib/api/auth.ts @@ -9,6 +9,7 @@ export interface ManagedAuthAccount { avatar_url: string | null; authenticated_at: number; is_default: boolean; + github_domain: string; } export interface ManagedAuthStatus { @@ -30,19 +31,23 @@ export interface ManagedAuthDeviceCodeResponse { export async function authStartLogin( authProvider: ManagedAuthProvider, + githubDomain?: string, ): Promise { return invoke("auth_start_login", { authProvider, + githubDomain: githubDomain || null, }); } export async function authPollForAccount( authProvider: ManagedAuthProvider, deviceCode: string, + githubDomain?: string, ): Promise { return invoke("auth_poll_for_account", { authProvider, deviceCode, + githubDomain: githubDomain || null, }); } diff --git a/src/lib/api/copilot.ts b/src/lib/api/copilot.ts index 09eb55b00..39089574a 100644 --- a/src/lib/api/copilot.ts +++ b/src/lib/api/copilot.ts @@ -30,6 +30,8 @@ export interface GitHubAccount { avatar_url: string | null; /** 认证时间戳(Unix 秒) */ authenticated_at: number; + /** GitHub 域名(github.com 或 GHES 域名) */ + github_domain: string; } /**