feat(copilot): add GitHub Enterprise Server support (#2175)

* feat(copilot): add GitHub Enterprise Server support

* fix(copilot): address GHES PR review findings (P1 + 2×P2)

- P1: Use composite account ID (domain:user_id) for GHES to prevent
  cross-instance ID collisions; github.com keeps plain numeric ID for
  backward compatibilit
- P2-a: Use get_api_endpoint() for model list URL with automatic
  fallback to static URL when dynamic endpoint resolution fails
- P2-b: Add normalize_github_domain() as backend SSOT for domain
  normalization (lowercase, strip protocol/path/query, reject userinfo)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
hotelbe
2026-04-19 20:29:46 +08:00
committed by GitHub
parent 9871d3d1eb
commit 87635e7fc6
12 changed files with 441 additions and 68 deletions
+9 -2
View File
@@ -18,6 +18,7 @@ pub struct ManagedAuthAccount {
pub avatar_url: Option<String>,
pub authenticated_at: i64,
pub is_default: bool,
pub github_domain: String,
}
#[derive(Debug, Clone, serde::Serialize)]
@@ -59,6 +60,7 @@ fn map_account(
login: account.login,
avatar_url: account.avatar_url,
authenticated_at: account.authenticated_at,
github_domain: account.github_domain,
}
}
@@ -79,6 +81,7 @@ fn map_device_code_response(
#[tauri::command(rename_all = "camelCase")]
pub async fn auth_start_login(
auth_provider: String,
github_domain: Option<String>,
copilot_state: State<'_, CopilotAuthState>,
codex_state: State<'_, CodexOAuthState>,
) -> Result<ManagedAuthDeviceCodeResponse, String> {
@@ -87,7 +90,7 @@ pub async fn auth_start_login(
AUTH_PROVIDER_GITHUB_COPILOT => {
let auth_manager = copilot_state.0.read().await;
let response = auth_manager
.start_device_flow()
.start_device_flow(github_domain.as_deref())
.await
.map_err(|e| e.to_string())?;
Ok(map_device_code_response(auth_provider, response))
@@ -108,6 +111,7 @@ pub async fn auth_start_login(
pub async fn auth_poll_for_account(
auth_provider: String,
device_code: String,
github_domain: Option<String>,
copilot_state: State<'_, CopilotAuthState>,
codex_state: State<'_, CodexOAuthState>,
) -> Result<Option<ManagedAuthAccount>, String> {
@@ -115,7 +119,10 @@ pub async fn auth_poll_for_account(
match auth_provider {
AUTH_PROVIDER_GITHUB_COPILOT => {
let auth_manager = copilot_state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
match auth_manager
.poll_for_token(&device_code, github_domain.as_deref())
.await
{
Ok(account) => {
let default_account_id = auth_manager.get_status().await.default_account_id;
Ok(account.map(|account| {
+12 -3
View File
@@ -20,11 +20,12 @@ pub struct CopilotAuthState(pub Arc<RwLock<CopilotAuthManager>>);
/// 返回设备码和用户码,用于 OAuth 认证
#[tauri::command]
pub async fn copilot_start_device_flow(
github_domain: Option<String>,
state: State<'_, CopilotAuthState>,
) -> Result<GitHubDeviceCodeResponse, String> {
let auth_manager = state.0.read().await;
auth_manager
.start_device_flow()
.start_device_flow(github_domain.as_deref())
.await
.map_err(|e| e.to_string())
}
@@ -36,10 +37,14 @@ pub async fn copilot_start_device_flow(
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_poll_for_auth(
device_code: String,
github_domain: Option<String>,
state: State<'_, CopilotAuthState>,
) -> Result<bool, String> {
let auth_manager = state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
match auth_manager
.poll_for_token(&device_code, github_domain.as_deref())
.await
{
Ok(Some(_account)) => {
log::info!("[CopilotAuth] 用户已授权");
Ok(true)
@@ -61,10 +66,14 @@ pub async fn copilot_poll_for_auth(
#[tauri::command(rename_all = "camelCase")]
pub async fn copilot_poll_for_account(
device_code: String,
github_domain: Option<String>,
state: State<'_, CopilotAuthState>,
) -> Result<Option<GitHubAccount>, String> {
let auth_manager = state.0.write().await;
match auth_manager.poll_for_token(&device_code).await {
match auth_manager
.poll_for_token(&device_code, github_domain.as_deref())
.await
{
Ok(account) => Ok(account),
Err(crate::proxy::providers::copilot_auth::CopilotAuthError::AuthorizationPending) => {
Ok(None)
@@ -203,6 +203,7 @@ impl From<&CodexAccountData> for GitHubAccount {
.unwrap_or_else(|| format!("ChatGPT ({})", &data.account_id)),
avatar_url: None,
authenticated_at: data.authenticated_at,
github_domain: "github.com".to_string(),
}
}
}
+328 -54
View File
@@ -24,26 +24,114 @@ use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// GitHub OAuth 客户端 IDVS Code 使用的 ID
/// GitHub OAuth 客户端 IDVS Code- 用于 github.com
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// GitHub OAuth 客户端 ID(与 OpenCode 相同)- 在所有 GHES Copilot 实例上预注册
const GITHUB_CLIENT_ID_GHES: &str = "Ov23li8tweQw6odWQebz";
/// 默认 GitHub 域名
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// 根据域名选择 OAuth 客户端 ID
fn github_client_id(domain: &str) -> &'static str {
if domain == DEFAULT_GITHUB_DOMAIN {
GITHUB_CLIENT_ID
} else {
GITHUB_CLIENT_ID_GHES
}
}
fn default_github_domain() -> String {
DEFAULT_GITHUB_DOMAIN.to_string()
}
/// GitHub 设备码 URL
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
fn github_device_code_url(domain: &str) -> String {
format!("https://{domain}/login/device/code")
}
/// GitHub OAuth Token URL
const GITHUB_OAUTH_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
fn github_oauth_token_url(domain: &str) -> String {
format!("https://{domain}/login/oauth/access_token")
}
/// GitHub API 基础 URLgithub.com 用 api.github.comGHES 用 {domain}/api/v3
fn github_api_base(domain: &str) -> String {
if domain == DEFAULT_GITHUB_DOMAIN {
"https://api.github.com".to_string()
} else {
format!("https://{domain}/api/v3")
}
}
/// Copilot Token URL
const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
fn copilot_token_url(domain: &str) -> String {
format!("{}/copilot_internal/v2/token", github_api_base(domain))
}
/// GitHub User API URL
const GITHUB_USER_URL: &str = "https://api.github.com/user";
fn github_user_url(domain: &str) -> String {
format!("{}/user", github_api_base(domain))
}
/// Copilot 使用量 API URL
fn copilot_usage_url(domain: &str) -> String {
format!("{}/copilot_internal/user", github_api_base(domain))
}
/// Copilot API 基础地址(github.com 用 api.githubcopilot.comGHES 用 copilot-api.{domain}
fn copilot_api_base(domain: &str) -> String {
if domain == DEFAULT_GITHUB_DOMAIN {
"https://api.githubcopilot.com".to_string()
} else {
format!("https://copilot-api.{domain}")
}
}
/// Token 刷新提前量(秒)
const TOKEN_REFRESH_BUFFER_SECONDS: i64 = 60;
/// Copilot API 端点
const COPILOT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
/// 判断是否为 GitHub Enterprise Server(非 github.com
fn is_ghes(domain: &str) -> bool {
domain != DEFAULT_GITHUB_DOMAIN
}
/// 归一化 GitHub 域名(SSOT):
/// - 小写化
/// - 剥离协议(https:// http://
/// - 剥离尾斜杠、path、query、fragment
/// - 拒绝包含 userinfo@)的输入
/// - 保留端口号(如有)
fn normalize_github_domain(raw: &str) -> Result<String, CopilotAuthError> {
let s = raw.trim();
// 剥离协议
let s = s
.strip_prefix("https://")
.or_else(|| s.strip_prefix("http://"))
.unwrap_or(s);
// 取 host 部分(到第一个 / 或 ? 或 #)
let host = s.split(&['/', '?', '#'][..]).next().unwrap_or(s);
// 拒绝 userinfo
if host.contains('@') {
return Err(CopilotAuthError::InvalidDomain(raw.to_string()));
}
let normalized = host.to_lowercase();
if normalized.is_empty() {
return Err(CopilotAuthError::InvalidDomain(raw.to_string()));
}
Ok(normalized)
}
/// 生成复合账号 ID,确保不同 GHES 实例的 user ID 不会冲突。
/// github.com 账号保持原格式(向后兼容),GHES 账号使用 `domain:user_id` 格式。
fn composite_account_id(domain: &str, user_id: u64) -> String {
if domain == DEFAULT_GITHUB_DOMAIN {
user_id.to_string()
} else {
format!("{}:{}", domain, user_id)
}
}
/// Copilot API Header 常量
pub const COPILOT_EDITOR_VERSION: &str = "vscode/1.110.1";
@@ -52,12 +140,6 @@ pub const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.38.2";
pub const COPILOT_API_VERSION: &str = "2025-10-01";
pub const COPILOT_INTEGRATION_ID: &str = "vscode-chat";
/// Copilot 使用量 API URL
const COPILOT_USAGE_URL: &str = "https://api.github.com/copilot_internal/user";
/// 默认 Copilot API 端点
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
/// Copilot 使用量响应
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CopilotUsageResponse {
@@ -169,6 +251,9 @@ pub enum CopilotAuthError {
#[error("账号不存在: {0}")]
AccountNotFound(String),
#[error("无效的 GitHub 域名: {0}")]
InvalidDomain(String),
}
impl From<reqwest::Error> for CopilotAuthError {
@@ -253,15 +338,19 @@ pub struct GitHubAccount {
pub avatar_url: Option<String>,
/// 认证时间戳
pub authenticated_at: i64,
/// GitHub 域名(github.com 或 GHES 域名)
#[serde(default = "default_github_domain")]
pub github_domain: String,
}
impl From<&GitHubAccountData> for GitHubAccount {
fn from(data: &GitHubAccountData) -> Self {
GitHubAccount {
id: data.user.id.to_string(),
id: composite_account_id(&data.github_domain, data.user.id),
login: data.user.login.clone(),
avatar_url: data.user.avatar_url.clone(),
authenticated_at: data.authenticated_at,
github_domain: data.github_domain.clone(),
}
}
}
@@ -295,6 +384,9 @@ struct GitHubAccountData {
pub user: GitHubUser,
/// 认证时间戳
pub authenticated_at: i64,
/// GitHub 域名(github.com 或 GHES 域名)
#[serde(default = "default_github_domain")]
pub github_domain: String,
}
/// 持久化存储结构(v3 多账号 + 默认账号格式)
@@ -437,14 +529,16 @@ impl CopilotAuthManager {
&self,
github_token: String,
user: GitHubUser,
github_domain: String,
) -> Result<GitHubAccount, CopilotAuthError> {
let account_id = user.id.to_string();
let account_id = composite_account_id(&github_domain, user.id);
let now = chrono::Utc::now().timestamp();
let account_data = GitHubAccountData {
github_token,
user: user.clone(),
authenticated_at: now,
github_domain: github_domain.clone(),
};
let account = GitHubAccount {
@@ -452,6 +546,7 @@ impl CopilotAuthManager {
login: user.login.clone(),
avatar_url: user.avatar_url.clone(),
authenticated_at: now,
github_domain,
};
{
@@ -497,15 +592,25 @@ impl CopilotAuthManager {
// ==================== 设备码流程 ====================
/// 启动设备码流程
pub async fn start_device_flow(&self) -> Result<GitHubDeviceCodeResponse, CopilotAuthError> {
log::info!("[CopilotAuth] 启动设备码流程");
pub async fn start_device_flow(
&self,
github_domain: Option<&str>,
) -> Result<GitHubDeviceCodeResponse, CopilotAuthError> {
let domain = match github_domain {
Some(d) => normalize_github_domain(d)?,
None => DEFAULT_GITHUB_DOMAIN.to_string(),
};
log::info!("[CopilotAuth] 启动设备码流程 (domain: {domain})");
let response = self
.http_client
.post(GITHUB_DEVICE_CODE_URL)
.post(github_device_code_url(&domain))
.header("Accept", "application/json")
.header("User-Agent", COPILOT_USER_AGENT)
.form(&[("client_id", GITHUB_CLIENT_ID), ("scope", "read:user")])
.form(&[
("client_id", github_client_id(&domain)),
("scope", "read:user"),
])
.send()
.await?;
@@ -534,16 +639,21 @@ impl CopilotAuthManager {
pub async fn poll_for_token(
&self,
device_code: &str,
github_domain: Option<&str>,
) -> Result<Option<GitHubAccount>, CopilotAuthError> {
log::debug!("[CopilotAuth] 轮询 OAuth Token");
let domain = match github_domain {
Some(d) => normalize_github_domain(d)?,
None => DEFAULT_GITHUB_DOMAIN.to_string(),
};
log::debug!("[CopilotAuth] 轮询 OAuth Token (domain: {domain})");
let response = self
.http_client
.post(GITHUB_OAUTH_TOKEN_URL)
.post(github_oauth_token_url(&domain))
.header("Accept", "application/json")
.header("User-Agent", COPILOT_USER_AGENT)
.form(&[
("client_id", GITHUB_CLIENT_ID),
("client_id", github_client_id(&domain)),
("device_code", device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
])
@@ -578,14 +688,28 @@ impl CopilotAuthManager {
log::info!("[CopilotAuth] OAuth Token 获取成功");
// 获取用户信息
let user = self.fetch_user_info_with_token(&access_token).await?;
// 验证 Copilot 订阅(获取 Copilot Token
self.fetch_copilot_token_with_github_token(&access_token, &user.id.to_string())
let user = self
.fetch_user_info_with_token(&access_token, &domain)
.await?;
// GHES 无需换取 Copilot Token,直接使用 OAuth token 作为 Bearer
// 参考 OpenCode 的实现:GHE Copilot 直接用 OAuth token 调用 copilot-api.{domain}
if !is_ghes(&domain) {
// github.com:验证 Copilot 订阅(获取 Copilot Token
self.fetch_copilot_token_with_github_token(
&access_token,
&user.id.to_string(),
&domain,
)
.await?;
} else {
log::info!("[CopilotAuth] GHES 账号,跳过 Copilot Token 兑换,直接使用 OAuth token");
}
// 添加账号
let account = self.add_account_internal(access_token, user).await?;
let account = self
.add_account_internal(access_token, user, domain)
.await?;
Ok(Some(account))
}
@@ -600,6 +724,16 @@ impl CopilotAuthManager {
// 确保迁移完成
self.ensure_migration_complete().await?;
// GHES 账号直接使用 GitHub OAuth token,无需 Copilot token 交换
let domain = self.get_account_domain(account_id).await;
if is_ghes(&domain) {
let accounts = self.accounts.read().await;
return accounts
.get(account_id)
.map(|a| a.github_token.clone())
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()));
}
// 检查缓存的 token
{
let tokens = self.copilot_tokens.read().await;
@@ -627,16 +761,16 @@ impl CopilotAuthManager {
}
// 获取账号的 GitHub token
let github_token = {
let (github_token, domain) = {
let accounts = self.accounts.read().await;
accounts
let account = accounts
.get(account_id)
.map(|a| a.github_token.clone())
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?;
(account.github_token.clone(), account.github_domain.clone())
};
// 刷新 Copilot token
self.fetch_copilot_token_with_github_token(&github_token, account_id)
self.fetch_copilot_token_with_github_token(&github_token, account_id, &domain)
.await?;
// 返回新 token
@@ -687,11 +821,19 @@ impl CopilotAuthManager {
) -> Result<Vec<CopilotModel>, CopilotAuthError> {
let copilot_token = self.get_valid_token_for_account(account_id).await?;
// 使用 get_api_endpoint() 动态解析 Copilot API 基础 URL。
// 对于 github.com 账号,会查询 /copilot_internal/user 获取 endpoints.api 字段。
// 对于 GHES 账号,/copilot_internal/user 可能不返回 endpoints——此时
// get_api_endpoint() 会回退到 copilot_api_base(&domain),与之前的静态 URL
// 拼接结果一致。该回退行为是安全且符合预期的。
let api_base = self.get_api_endpoint(account_id).await;
let models_url = format!("{}/models", api_base);
log::info!("[CopilotAuth] 获取账号 {account_id} 的 Copilot 可用模型");
let response = self
.http_client
.get(COPILOT_MODELS_URL)
.get(&models_url)
.header("Authorization", format!("Bearer {copilot_token}"))
.header("Content-Type", "application/json")
.header("copilot-integration-id", "vscode-chat")
@@ -767,19 +909,19 @@ impl CopilotAuthManager {
&self,
account_id: &str,
) -> Result<CopilotUsageResponse, CopilotAuthError> {
let github_token = {
let (github_token, domain) = {
let accounts = self.accounts.read().await;
accounts
let account = accounts
.get(account_id)
.map(|a| a.github_token.clone())
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?;
(account.github_token.clone(), account.github_domain.clone())
};
log::info!("[CopilotAuth] 获取账号 {account_id} 的 Copilot 使用量");
let response = self
.http_client
.get(COPILOT_USAGE_URL)
.get(copilot_usage_url(&domain))
.header("Authorization", format!("token {github_token}"))
.header("Content-Type", "application/json")
.header("editor-version", COPILOT_EDITOR_VERSION)
@@ -862,7 +1004,8 @@ impl CopilotAuthManager {
log::debug!(
"[CopilotAuth] 获取账号 {account_id} 动态 API 端点失败: {e},使用默认值"
);
DEFAULT_COPILOT_API_ENDPOINT.to_string()
let domain = self.get_account_domain(account_id).await;
copilot_api_base(&domain)
}
}
}
@@ -873,24 +1016,27 @@ impl CopilotAuthManager {
match self.resolve_default_account_id().await {
Some(id) => self.get_api_endpoint(&id).await,
None => DEFAULT_COPILOT_API_ENDPOINT.to_string(),
None => {
// 无账号时回退到 github.com 的默认端点
copilot_api_base(DEFAULT_GITHUB_DOMAIN)
}
}
}
async fn fetch_and_cache_endpoint(&self, account_id: &str) -> Result<String, CopilotAuthError> {
let github_token = {
let (github_token, domain) = {
let accounts = self.accounts.read().await;
accounts
let account = accounts
.get(account_id)
.map(|a| a.github_token.clone())
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?
.ok_or_else(|| CopilotAuthError::AccountNotFound(account_id.to_string()))?;
(account.github_token.clone(), account.github_domain.clone())
};
log::debug!("[CopilotAuth] 为账号 {account_id} 惰性拉取动态 API 端点");
let response = self
.http_client
.get(COPILOT_USAGE_URL)
.get(copilot_usage_url(&domain))
.header("Authorization", format!("token {github_token}"))
.header("Content-Type", "application/json")
.header("editor-version", COPILOT_EDITOR_VERSION)
@@ -918,7 +1064,7 @@ impl CopilotAuthManager {
let endpoint = match usage.endpoints {
Some(endpoints) => endpoints.api.clone(),
None => DEFAULT_COPILOT_API_ENDPOINT.to_string(),
None => copilot_api_base(&domain),
};
// 缓存端点(包括默认值),避免重复请求
@@ -1075,6 +1221,15 @@ impl CopilotAuthManager {
Self::fallback_default_account_id(&accounts)
}
/// 获取指定账号的 GitHub 域名
async fn get_account_domain(&self, account_id: &str) -> String {
let accounts = self.accounts.read().await;
accounts
.get(account_id)
.map(|a| a.github_domain.clone())
.unwrap_or_else(|| DEFAULT_GITHUB_DOMAIN.to_string())
}
async fn get_refresh_lock(&self, account_id: &str) -> Arc<Mutex<()>> {
{
let refresh_locks = self.refresh_locks.read().await;
@@ -1155,10 +1310,11 @@ impl CopilotAuthManager {
async fn fetch_user_info_with_token(
&self,
github_token: &str,
domain: &str,
) -> Result<GitHubUser, CopilotAuthError> {
let response = self
.http_client
.get(GITHUB_USER_URL)
.get(github_user_url(domain))
.header("Authorization", format!("token {github_token}"))
.header("User-Agent", COPILOT_USER_AGENT)
.header("Editor-Version", COPILOT_EDITOR_VERSION)
@@ -1185,12 +1341,13 @@ impl CopilotAuthManager {
&self,
github_token: &str,
account_id: &str,
domain: &str,
) -> Result<(), CopilotAuthError> {
log::debug!("[CopilotAuth] 获取账号 {account_id} 的 Copilot Token");
log::debug!("[CopilotAuth] 获取账号 {account_id} 的 Copilot Token (domain: {domain})");
let response = self
.http_client
.get(COPILOT_TOKEN_URL)
.get(copilot_token_url(domain))
.header("Authorization", format!("token {github_token}"))
.header("User-Agent", COPILOT_USER_AGENT)
.header("Editor-Version", COPILOT_EDITOR_VERSION)
@@ -1284,20 +1441,32 @@ impl CopilotAuthManager {
log::info!("[CopilotAuth] 执行旧格式迁移");
// 获取用户信息
match self.fetch_user_info_with_token(&legacy_token).await {
match self
.fetch_user_info_with_token(&legacy_token, DEFAULT_GITHUB_DOMAIN)
.await
{
Ok(user) => {
let account_id = user.id.to_string();
let account_id = composite_account_id(DEFAULT_GITHUB_DOMAIN, user.id);
// 尝试获取 Copilot token 验证订阅
if let Err(e) = self
.fetch_copilot_token_with_github_token(&legacy_token, &account_id)
.fetch_copilot_token_with_github_token(
&legacy_token,
&account_id,
DEFAULT_GITHUB_DOMAIN,
)
.await
{
log::warn!("[CopilotAuth] 迁移时验证 Copilot 订阅失败: {e}");
}
// 添加账号
self.add_account_internal(legacy_token, user).await?;
self.add_account_internal(
legacy_token,
user,
DEFAULT_GITHUB_DOMAIN.to_string(),
)
.await?;
self.set_migration_error(None).await;
log::info!("[CopilotAuth] 旧格式迁移完成");
@@ -1387,6 +1556,7 @@ mod tests {
login: "testuser".to_string(),
avatar_url: Some("https://example.com/avatar.png".to_string()),
authenticated_at: 1234567890,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
}],
default_account_id: Some("12345".to_string()),
migration_error: None,
@@ -1420,6 +1590,7 @@ mod tests {
avatar_url: Some("https://example.com/alice.png".to_string()),
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
accounts.insert(
@@ -1432,6 +1603,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000001,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
@@ -1479,6 +1651,7 @@ mod tests {
avatar_url: Some("https://example.com/avatar.png".to_string()),
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
};
let account = GitHubAccount::from(&data);
@@ -1504,6 +1677,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
accounts.insert(
@@ -1516,6 +1690,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000001,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
@@ -1546,6 +1721,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
}
@@ -1630,6 +1806,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
}
@@ -1664,6 +1841,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
}
@@ -1746,6 +1924,7 @@ mod tests {
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: DEFAULT_GITHUB_DOMAIN.to_string(),
},
);
}
@@ -1801,7 +1980,7 @@ mod tests {
let manager = CopilotAuthManager::new(temp_dir.path().to_path_buf());
let endpoint = manager.get_api_endpoint("12345").await;
assert_eq!(endpoint, DEFAULT_COPILOT_API_ENDPOINT);
assert_eq!(endpoint, copilot_api_base(DEFAULT_GITHUB_DOMAIN));
}
#[tokio::test]
@@ -1817,4 +1996,99 @@ mod tests {
other => panic!("期望 AccountNotFound 错误,实际: {other:?}"),
}
}
#[test]
fn test_normalize_github_domain() {
// 基本用法
assert_eq!(normalize_github_domain("github.com").unwrap(), "github.com");
assert_eq!(
normalize_github_domain("company.ghe.com").unwrap(),
"company.ghe.com"
);
// 剥离协议
assert_eq!(
normalize_github_domain("https://company.ghe.com").unwrap(),
"company.ghe.com"
);
assert_eq!(
normalize_github_domain("http://company.ghe.com").unwrap(),
"company.ghe.com"
);
// 小写化
assert_eq!(normalize_github_domain("GitHub.COM").unwrap(), "github.com");
assert_eq!(
normalize_github_domain("Company.GHE.Com").unwrap(),
"company.ghe.com"
);
// 剥离尾斜杠和 path
assert_eq!(
normalize_github_domain("company.ghe.com/").unwrap(),
"company.ghe.com"
);
assert_eq!(
normalize_github_domain("company.ghe.com/api/v3").unwrap(),
"company.ghe.com"
);
// 剥离 query 和 fragment
assert_eq!(
normalize_github_domain("company.ghe.com?foo=bar").unwrap(),
"company.ghe.com"
);
assert_eq!(
normalize_github_domain("company.ghe.com#section").unwrap(),
"company.ghe.com"
);
// 保留端口
assert_eq!(
normalize_github_domain("company.ghe.com:8443").unwrap(),
"company.ghe.com:8443"
);
// 拒绝 userinfo
assert!(normalize_github_domain("user@company.ghe.com").is_err());
// 拒绝空输入
assert!(normalize_github_domain("").is_err());
assert!(normalize_github_domain(" ").is_err());
}
#[test]
fn test_composite_account_id() {
// github.com 保持原格式(向后兼容)
assert_eq!(composite_account_id("github.com", 12345), "12345");
// GHES 使用复合格式
assert_eq!(
composite_account_id("company.ghe.com", 12345),
"company.ghe.com:12345"
);
// 不同 GHES 实例,相同 user ID,不冲突
assert_ne!(
composite_account_id("a.ghe.com", 1),
composite_account_id("b.ghe.com", 1)
);
}
#[test]
fn test_github_account_from_data_ghes_uses_composite_id() {
let data = GitHubAccountData {
github_token: "gho_test".to_string(),
user: GitHubUser {
login: "testuser".to_string(),
id: 99999,
avatar_url: None,
},
authenticated_at: 1700000000,
github_domain: "company.ghe.com".to_string(),
};
let account = GitHubAccount::from(&data);
assert_eq!(account.id, "company.ghe.com:99999");
}
}
@@ -3,6 +3,7 @@ import { useTranslation } from "react-i18next";
import { Button } from "@/components/ui/button";
import { Badge } from "@/components/ui/badge";
import { Label } from "@/components/ui/label";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
@@ -45,6 +46,19 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
}) => {
const { t } = useTranslation();
const [copied, setCopied] = React.useState(false);
const [deploymentType, setDeploymentType] = React.useState<
"github.com" | "enterprise"
>("github.com");
const [enterpriseDomain, setEnterpriseDomain] = React.useState("");
// 根据部署类型计算实际的 GitHub 域名
const effectiveGithubDomain =
deploymentType === "enterprise" && enterpriseDomain.trim()
? enterpriseDomain
.trim()
.replace(/^https?:\/\//, "")
.replace(/\/$/, "")
: undefined;
const {
accounts,
@@ -63,7 +77,7 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
setDefaultAccount,
cancelAuth,
logout,
} = useCopilotAuth();
} = useCopilotAuth(effectiveGithubDomain);
// 复制用户码
const copyUserCode = async () => {
@@ -113,6 +127,41 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
</Badge>
</div>
{/* GitHub 部署类型选择 */}
<div className="space-y-2">
<Label className="text-sm text-muted-foreground">
{t("copilot.deploymentType", "GitHub 部署类型")}
</Label>
<Select
value={deploymentType}
onValueChange={(v) =>
setDeploymentType(v as "github.com" | "enterprise")
}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="github.com">
{t("copilot.deploymentGitHubCom", "GitHub.com")}
</SelectItem>
<SelectItem value="enterprise">
{t("copilot.deploymentEnterprise", "GitHub Enterprise Server")}
</SelectItem>
</SelectContent>
</Select>
{deploymentType === "enterprise" && (
<Input
placeholder={t(
"copilot.enterpriseDomainPlaceholder",
"例如:company.ghe.com",
)}
value={enterpriseDomain}
onChange={(e) => setEnterpriseDomain(e.target.value)}
/>
)}
</div>
{migrationError && (
<p className="text-sm text-amber-600 dark:text-amber-400">
{t("copilot.migrationFailed", {
@@ -179,6 +228,12 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
{t("copilot.defaultAccount", "默认")}
</Badge>
)}
{account.github_domain &&
account.github_domain !== "github.com" && (
<Badge variant="outline" className="text-xs">
{account.github_domain}
</Badge>
)}
{selectedAccountId === account.id && (
<Badge variant="outline" className="text-xs">
{t("copilot.selected", "已选中")}
@@ -223,6 +278,7 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
onClick={addAccount}
className="w-full"
variant="outline"
disabled={deploymentType === "enterprise" && !enterpriseDomain.trim()}
>
<Github className="mr-2 h-4 w-4" />
{t("copilot.loginWithGitHub", "使用 GitHub 登录")}
@@ -236,7 +292,10 @@ export const CopilotAuthSection: React.FC<CopilotAuthSectionProps> = ({
onClick={addAccount}
className="w-full"
variant="outline"
disabled={isAddingAccount}
disabled={
isAddingAccount ||
(deploymentType === "enterprise" && !enterpriseDomain.trim())
}
>
<Plus className="mr-2 h-4 w-4" />
{t("copilot.addAnotherAccount", "添加其他账号")}
@@ -1,8 +1,8 @@
import type { GitHubAccount } from "@/lib/api";
import { useManagedAuth } from "./useManagedAuth";
export function useCopilotAuth() {
const managedAuth = useManagedAuth("github_copilot");
export function useCopilotAuth(githubDomain?: string) {
const managedAuth = useManagedAuth("github_copilot", githubDomain);
const defaultAccount =
managedAuth.accounts.find(
(account) => account.id === managedAuth.defaultAccountId,
@@ -10,7 +10,10 @@ import type {
type PollingState = "idle" | "polling" | "success" | "error";
export function useManagedAuth(authProvider: ManagedAuthProvider) {
export function useManagedAuth(
authProvider: ManagedAuthProvider,
githubDomain?: string,
) {
const queryClient = useQueryClient();
const queryKey = ["managed-auth-status", authProvider];
@@ -52,7 +55,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) {
}, [stopPolling]);
const startLoginMutation = useMutation({
mutationFn: () => authApi.authStartLogin(authProvider),
mutationFn: () => authApi.authStartLogin(authProvider, githubDomain),
onSuccess: async (response) => {
setDeviceCode(response);
setPollingState("polling");
@@ -87,6 +90,7 @@ export function useManagedAuth(authProvider: ManagedAuthProvider) {
const newAccount = await authApi.authPollForAccount(
authProvider,
response.device_code,
githubDomain,
);
if (newAccount) {
stopPolling();
+5 -1
View File
@@ -889,7 +889,11 @@
"retry": "Retry",
"copyCode": "Copy code",
"migrationFailed": "Legacy auth migration failed: {{error}}",
"loadModelsFailed": "Failed to load Copilot models"
"loadModelsFailed": "Failed to load Copilot models",
"deploymentType": "GitHub Deployment Type",
"deploymentGitHubCom": "GitHub.com",
"deploymentEnterprise": "GitHub Enterprise Server",
"enterpriseDomainPlaceholder": "e.g. company.ghe.com"
},
"codexOauth": {
"authStatus": "Auth status",
+5 -1
View File
@@ -889,7 +889,11 @@
"retry": "再試行",
"copyCode": "コードをコピー",
"migrationFailed": "旧認証データの移行に失敗しました: {{error}}",
"loadModelsFailed": "Copilot モデル一覧の読み込みに失敗しました"
"loadModelsFailed": "Copilot モデル一覧の読み込みに失敗しました",
"deploymentType": "GitHub デプロイメントタイプ",
"deploymentGitHubCom": "GitHub.com",
"deploymentEnterprise": "GitHub Enterprise Server",
"enterpriseDomainPlaceholder": "例: company.ghe.com"
},
"codexOauth": {
"authStatus": "認証状態",
+5 -1
View File
@@ -890,7 +890,11 @@
"retry": "重试",
"copyCode": "复制代码",
"migrationFailed": "旧认证数据迁移失败:{{error}}",
"loadModelsFailed": "加载 Copilot 模型列表失败"
"loadModelsFailed": "加载 Copilot 模型列表失败",
"deploymentType": "GitHub 部署类型",
"deploymentGitHubCom": "GitHub.com",
"deploymentEnterprise": "GitHub Enterprise Server",
"enterpriseDomainPlaceholder": "例如:company.ghe.com"
},
"codexOauth": {
"authStatus": "认证状态",
+5
View File
@@ -9,6 +9,7 @@ export interface ManagedAuthAccount {
avatar_url: string | null;
authenticated_at: number;
is_default: boolean;
github_domain: string;
}
export interface ManagedAuthStatus {
@@ -30,19 +31,23 @@ export interface ManagedAuthDeviceCodeResponse {
export async function authStartLogin(
authProvider: ManagedAuthProvider,
githubDomain?: string,
): Promise<ManagedAuthDeviceCodeResponse> {
return invoke<ManagedAuthDeviceCodeResponse>("auth_start_login", {
authProvider,
githubDomain: githubDomain || null,
});
}
export async function authPollForAccount(
authProvider: ManagedAuthProvider,
deviceCode: string,
githubDomain?: string,
): Promise<ManagedAuthAccount | null> {
return invoke<ManagedAuthAccount | null>("auth_poll_for_account", {
authProvider,
deviceCode,
githubDomain: githubDomain || null,
});
}
+2
View File
@@ -30,6 +30,8 @@ export interface GitHubAccount {
avatar_url: string | null;
/** 认证时间戳(Unix 秒) */
authenticated_at: number;
/** GitHub 域名(github.com 或 GHES 域名) */
github_domain: string;
}
/**