diff --git a/src-tauri/src/codex_config.rs b/src-tauri/src/codex_config.rs index 642e4886..b84caab7 100644 --- a/src-tauri/src/codex_config.rs +++ b/src-tauri/src/codex_config.rs @@ -9,6 +9,7 @@ use crate::error::AppError; use serde_json::Value; use std::fs; use std::path::Path; +use toml_edit::DocumentMut; /// 获取 Codex 配置目录路径 pub fn get_codex_config_dir() -> PathBuf { @@ -135,3 +136,335 @@ pub fn read_and_validate_codex_config_text() -> Result { validate_config_toml(&s)?; Ok(s) } + +/// Update a field in Codex config.toml using toml_edit (syntax-preserving). +/// +/// Supported fields: +/// - `"base_url"`: writes to `[model_providers.].base_url` if `model_provider` exists, +/// otherwise falls back to top-level `base_url`. +/// - `"model"`: writes to top-level `model` field. +/// +/// Empty value removes the field. +pub fn update_codex_toml_field(toml_str: &str, field: &str, value: &str) -> Result { + let mut doc = toml_str + .parse::() + .map_err(|e| format!("TOML parse error: {e}"))?; + + let trimmed = value.trim(); + + match field { + "base_url" => { + let model_provider = doc + .get("model_provider") + .and_then(|item| item.as_str()) + .map(str::to_string); + + if let Some(provider_key) = model_provider { + // Ensure [model_providers] table exists + if doc.get("model_providers").is_none() { + doc["model_providers"] = toml_edit::table(); + } + + if let Some(model_providers) = doc["model_providers"].as_table_mut() { + // Ensure [model_providers.] table exists + if !model_providers.contains_key(&provider_key) { + model_providers[&provider_key] = toml_edit::table(); + } + + if let Some(provider_table) = model_providers[&provider_key].as_table_mut() { + if trimmed.is_empty() { + provider_table.remove("base_url"); + } else { + provider_table["base_url"] = toml_edit::value(trimmed); + } + return Ok(doc.to_string()); + } + } + } + + // Fallback: no model_provider or structure mismatch → top-level base_url + if trimmed.is_empty() { + doc.as_table_mut().remove("base_url"); + } else { + doc["base_url"] = toml_edit::value(trimmed); + } + } + "model" => { + if trimmed.is_empty() { + doc.as_table_mut().remove("model"); + } else { + doc["model"] = toml_edit::value(trimmed); + } + } + _ => return Err(format!("unsupported field: {field}")), + } + + Ok(doc.to_string()) +} + +/// Remove `base_url` from the active model_provider section only if it matches `predicate`. +/// Also removes top-level `base_url` if it matches. +/// Used by proxy cleanup to strip local proxy URLs without touching user-configured URLs. +pub fn remove_codex_toml_base_url_if(toml_str: &str, predicate: impl Fn(&str) -> bool) -> String { + let mut doc = match toml_str.parse::() { + Ok(doc) => doc, + Err(_) => return toml_str.to_string(), + }; + + let model_provider = doc + .get("model_provider") + .and_then(|item| item.as_str()) + .map(str::to_string); + + if let Some(provider_key) = model_provider { + if let Some(model_providers) = doc + .get_mut("model_providers") + .and_then(|v| v.as_table_mut()) + { + if let Some(provider_table) = model_providers + .get_mut(provider_key.as_str()) + .and_then(|v| v.as_table_mut()) + { + let should_remove = provider_table + .get("base_url") + .and_then(|item| item.as_str()) + .map(&predicate) + .unwrap_or(false); + if should_remove { + provider_table.remove("base_url"); + } + } + } + } + + // Fallback: also clean up top-level base_url if it matches + let should_remove_root = doc + .get("base_url") + .and_then(|item| item.as_str()) + .map(&predicate) + .unwrap_or(false); + if should_remove_root { + doc.as_table_mut().remove("base_url"); + } + + doc.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn base_url_writes_into_correct_model_provider_section() { + let input = r#"model_provider = "any" +model = "gpt-5.1-codex" + +[model_providers.any] +name = "any" +wire_api = "responses" +"#; + + let result = update_codex_toml_field(input, "base_url", "https://example.com/v1").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + let base_url = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .and_then(|v| v.get("base_url")) + .and_then(|v| v.as_str()) + .expect("base_url should be in model_providers.any"); + assert_eq!(base_url, "https://example.com/v1"); + + // Should NOT have top-level base_url + assert!(parsed.get("base_url").is_none()); + + // wire_api preserved + let wire_api = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .and_then(|v| v.get("wire_api")) + .and_then(|v| v.as_str()); + assert_eq!(wire_api, Some("responses")); + } + + #[test] + fn base_url_creates_section_when_missing() { + let input = r#"model_provider = "custom" +model = "gpt-4" +"#; + + let result = update_codex_toml_field(input, "base_url", "https://custom.api/v1").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + let base_url = parsed + .get("model_providers") + .and_then(|v| v.get("custom")) + .and_then(|v| v.get("base_url")) + .and_then(|v| v.as_str()) + .expect("should create section and set base_url"); + assert_eq!(base_url, "https://custom.api/v1"); + } + + #[test] + fn base_url_falls_back_to_top_level_without_model_provider() { + let input = r#"model = "gpt-4" +"#; + + let result = update_codex_toml_field(input, "base_url", "https://fallback.api/v1").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + let base_url = parsed + .get("base_url") + .and_then(|v| v.as_str()) + .expect("should set top-level base_url"); + assert_eq!(base_url, "https://fallback.api/v1"); + } + + #[test] + fn clearing_base_url_removes_only_from_correct_section() { + let input = r#"model_provider = "any" + +[model_providers.any] +name = "any" +base_url = "https://old.api/v1" +wire_api = "responses" + +[mcp_servers.context7] +command = "npx" +"#; + + let result = update_codex_toml_field(input, "base_url", "").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + // base_url removed from model_providers.any + let any_section = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .expect("model_providers.any should exist"); + assert!(any_section.get("base_url").is_none()); + + // wire_api preserved + assert_eq!( + any_section.get("wire_api").and_then(|v| v.as_str()), + Some("responses") + ); + + // mcp_servers untouched + assert!(parsed.get("mcp_servers").is_some()); + } + + #[test] + fn model_field_operates_on_top_level() { + let input = r#"model_provider = "any" +model = "gpt-4" + +[model_providers.any] +name = "any" +"#; + + let result = update_codex_toml_field(input, "model", "gpt-5").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + assert_eq!(parsed.get("model").and_then(|v| v.as_str()), Some("gpt-5")); + + // Clear model + let result2 = update_codex_toml_field(&result, "model", "").unwrap(); + let parsed2: toml::Value = toml::from_str(&result2).unwrap(); + assert!(parsed2.get("model").is_none()); + } + + #[test] + fn preserves_comments_and_whitespace() { + let input = r#"# My Codex config +model_provider = "any" +model = "gpt-4" + +# Provider section +[model_providers.any] +name = "any" +base_url = "https://old.api/v1" +"#; + + let result = update_codex_toml_field(input, "base_url", "https://new.api/v1").unwrap(); + + // Comments should be preserved + assert!(result.contains("# My Codex config")); + assert!(result.contains("# Provider section")); + } + + #[test] + fn does_not_misplace_when_profiles_section_follows() { + let input = r#"model_provider = "any" + +[model_providers.any] +name = "any" +base_url = "https://old.api/v1" + +[profiles.default] +model = "gpt-4" +"#; + + let result = update_codex_toml_field(input, "base_url", "https://new.api/v1").unwrap(); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + // base_url in correct section + let base_url = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .and_then(|v| v.get("base_url")) + .and_then(|v| v.as_str()); + assert_eq!(base_url, Some("https://new.api/v1")); + + // profiles section untouched + let profile_model = parsed + .get("profiles") + .and_then(|v| v.get("default")) + .and_then(|v| v.get("model")) + .and_then(|v| v.as_str()); + assert_eq!(profile_model, Some("gpt-4")); + } + + #[test] + fn remove_base_url_if_predicate() { + let input = r#"model_provider = "any" + +[model_providers.any] +name = "any" +base_url = "http://127.0.0.1:5000/v1" +wire_api = "responses" +"#; + + let result = + remove_codex_toml_base_url_if(input, |url| url.starts_with("http://127.0.0.1")); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + let any_section = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .unwrap(); + assert!(any_section.get("base_url").is_none()); + assert_eq!( + any_section.get("wire_api").and_then(|v| v.as_str()), + Some("responses") + ); + } + + #[test] + fn remove_base_url_if_keeps_non_matching() { + let input = r#"model_provider = "any" + +[model_providers.any] +base_url = "https://production.api/v1" +"#; + + let result = + remove_codex_toml_base_url_if(input, |url| url.starts_with("http://127.0.0.1")); + let parsed: toml::Value = toml::from_str(&result).unwrap(); + + let base_url = parsed + .get("model_providers") + .and_then(|v| v.get("any")) + .and_then(|v| v.get("base_url")) + .and_then(|v| v.as_str()); + assert_eq!(base_url, Some("https://production.api/v1")); + } +} diff --git a/src-tauri/src/services/proxy.rs b/src-tauri/src/services/proxy.rs index ccb755fc..02ce4a58 100644 --- a/src-tauri/src/services/proxy.rs +++ b/src-tauri/src/services/proxy.rs @@ -19,7 +19,7 @@ use tokio::sync::RwLock; /// 用于接管 Live 配置时的占位符(避免客户端提示缺少 key,同时不泄露真实 Token) const PROXY_TOKEN_PLACEHOLDER: &str = "PROXY_MANAGED"; -/// 代理接管模式下需要从 Claude Live 配置中移除的“模型覆盖”字段。 +/// 代理接管模式下需要从 Claude Live 配置中移除的"模型覆盖"字段。 /// /// 原因:接管模式切换供应商时不会写回 Live 配置,如果保留这些字段, /// Claude Code 会继续以旧模型名发起请求,导致新供应商不支持时失败。 @@ -52,7 +52,7 @@ impl ProxyService { /// 清理接管模式下 Claude Live 配置中的模型覆盖字段。 /// - /// 这可以避免“接管开启后切换供应商仍使用旧模型”的问题。 + /// 这可以避免"接管开启后切换供应商仍使用旧模型"的问题。 /// 注意:此方法不会修改 Token/Base URL 的接管占位符,仅移除模型字段。 pub fn cleanup_claude_model_overrides_in_live(&self) -> Result<(), String> { let mut config = self.read_claude_live()?; @@ -1162,7 +1162,7 @@ impl ProxyService { ) -> Result<(), String> { let app_type_str = app_type.as_str(); - // 1) 优先从 Live 备份恢复(这是“原始 Live”的唯一可靠来源) + // 1) 优先从 Live 备份恢复(这是"原始 Live"的唯一可靠来源) let backup = self .db .get_live_backup(app_type_str) @@ -1181,7 +1181,7 @@ impl ProxyService { return Ok(()); } - // 2.1) 优先从 SSOT(当前供应商)重建 Live(比“清理字段”更可用) + // 2.1) 优先从 SSOT(当前供应商)重建 Live(比"清理字段"更可用) match self.restore_live_from_ssot_for_app(app_type) { Ok(true) => { log::info!("{app_type_str} Live 配置已从 SSOT 恢复(无备份兜底)"); @@ -1358,51 +1358,9 @@ impl ProxyService { Ok(()) } + /// Remove local proxy base_url from TOML(委托给 codex_config 共享实现) fn remove_local_toml_base_url(toml_str: &str) -> String { - use toml_edit::DocumentMut; - - let mut doc = match toml_str.parse::() { - Ok(doc) => doc, - Err(_) => return toml_str.to_string(), - }; - - let model_provider = doc - .get("model_provider") - .and_then(|item| item.as_str()) - .map(str::to_string); - - if let Some(provider_key) = model_provider { - if let Some(model_providers) = doc - .get_mut("model_providers") - .and_then(|v| v.as_table_mut()) - { - if let Some(provider_table) = model_providers - .get_mut(provider_key.as_str()) - .and_then(|v| v.as_table_mut()) - { - let should_remove = provider_table - .get("base_url") - .and_then(|item| item.as_str()) - .map(Self::is_local_proxy_url) - .unwrap_or(false); - if should_remove { - provider_table.remove("base_url"); - } - } - } - } - - // 兜底:清理顶层 base_url(仅当它看起来像本地代理地址) - let should_remove_root = doc - .get("base_url") - .and_then(|item| item.as_str()) - .map(Self::is_local_proxy_url) - .unwrap_or(false); - if should_remove_root { - doc.as_table_mut().remove("base_url"); - } - - doc.to_string() + crate::codex_config::remove_codex_toml_base_url_if(toml_str, Self::is_local_proxy_url) } fn cleanup_gemini_takeover_placeholders_in_live(&self) -> Result<(), String> { @@ -1459,7 +1417,7 @@ impl ProxyService { Ok(()) } - /// 检测 Live 配置是否处于“被接管”的残留状态 + /// 检测 Live 配置是否处于"被接管"的残留状态 /// /// 用于兜底处理:当数据库备份缺失但 Live 文件已经写成代理占位符时, /// 启动流程可以据此触发恢复逻辑。 @@ -1675,49 +1633,10 @@ impl ProxyService { // ==================== Live 配置读写辅助方法 ==================== - /// 更新 TOML 字符串中的 base_url + /// 更新 TOML 字符串中的 base_url(委托给 codex_config 共享实现) fn update_toml_base_url(toml_str: &str, new_url: &str) -> String { - use toml_edit::DocumentMut; - - let mut doc = match toml_str.parse::() { - Ok(doc) => doc, - Err(_) => return toml_str.to_string(), - }; - - // Codex 的 config.toml 通常是: - // model_provider = "any" - // - // [model_providers.any] - // base_url = "https://.../v1" - // - // 所以接管时要“精准”修改当前 model_provider 对应的 model_providers..base_url, - // 避免写错位置导致 Codex 仍然走旧地址。 - let model_provider = doc - .get("model_provider") - .and_then(|item| item.as_str()) - .map(str::to_string); - - if let Some(provider_key) = model_provider { - if doc.get("model_providers").is_none() { - doc["model_providers"] = toml_edit::table(); - } - - if let Some(model_providers) = doc["model_providers"].as_table_mut() { - if !model_providers.contains_key(&provider_key) { - model_providers[&provider_key] = toml_edit::table(); - } - - if let Some(provider_table) = model_providers[&provider_key].as_table_mut() { - provider_table["base_url"] = toml_edit::value(new_url); - return doc.to_string(); - } - } - } - - // 兜底:如果没有 model_provider 或结构不符合预期,则退回修改顶层 base_url。 - doc["base_url"] = toml_edit::value(new_url); - - doc.to_string() + crate::codex_config::update_codex_toml_field(toml_str, "base_url", new_url) + .unwrap_or_else(|_| toml_str.to_string()) } fn read_claude_live(&self) -> Result { @@ -2228,7 +2147,7 @@ model = "gpt-5.1-codex" db.set_current_provider("claude", "a") .expect("set current provider"); - // 模拟“已接管”状态:存在 Live 备份(内容不重要,会被热切换更新) + // 模拟"已接管"状态:存在 Live 备份(内容不重要,会被热切换更新) db.save_live_backup("claude", "{\"env\":{}}") .await .expect("seed live backup"); diff --git a/src/components/providers/AddProviderDialog.tsx b/src/components/providers/AddProviderDialog.tsx index a40befdc..f8fc8e5b 100644 --- a/src/components/providers/AddProviderDialog.tsx +++ b/src/components/providers/AddProviderDialog.tsx @@ -17,6 +17,7 @@ import { UniversalProviderPanel } from "@/components/universal"; import { providerPresets } from "@/config/claudeProviderPresets"; import { codexProviderPresets } from "@/config/codexProviderPresets"; import { geminiProviderPresets } from "@/config/geminiProviderPresets"; +import { extractCodexBaseUrl } from "@/utils/providerConfigUtils"; import type { OpenClawSuggestedDefaults } from "@/config/openclawProviderPresets"; import type { UniversalProviderPreset } from "@/config/universalProviderPresets"; @@ -179,11 +180,9 @@ export function AddProviderDialog({ } else if (appId === "codex") { const config = parsedConfig.config as string | undefined; if (config) { - const baseUrlMatch = config.match( - /base_url\s*=\s*["']([^"']+)["']/, - ); - if (baseUrlMatch?.[1]) { - addUrl(baseUrlMatch[1]); + const extractedBaseUrl = extractCodexBaseUrl(config); + if (extractedBaseUrl) { + addUrl(extractedBaseUrl); } } } else if (appId === "gemini") { diff --git a/src/components/providers/ProviderCard.tsx b/src/components/providers/ProviderCard.tsx index 96d8b460..3301fbc9 100644 --- a/src/components/providers/ProviderCard.tsx +++ b/src/components/providers/ProviderCard.tsx @@ -13,6 +13,7 @@ import { ProviderIcon } from "@/components/ProviderIcon"; import UsageFooter from "@/components/UsageFooter"; import { ProviderHealthBadge } from "@/components/providers/ProviderHealthBadge"; import { FailoverPriorityBadge } from "@/components/providers/FailoverPriorityBadge"; +import { extractCodexBaseUrl } from "@/utils/providerConfigUtils"; import { useProviderHealth } from "@/lib/query/failover"; import { useUsageQuery } from "@/lib/query/queries"; @@ -76,9 +77,9 @@ const extractApiUrl = (provider: Provider, fallbackText: string) => { const baseUrl = (config as Record)?.config; if (typeof baseUrl === "string" && baseUrl.includes("base_url")) { - const match = baseUrl.match(/base_url\s*=\s*['"]([^'"]+)['"]/); - if (match?.[1]) { - return match[1]; + const extractedBaseUrl = extractCodexBaseUrl(baseUrl); + if (extractedBaseUrl) { + return extractedBaseUrl; } } } diff --git a/src/components/providers/forms/OpenClawFormFields.tsx b/src/components/providers/forms/OpenClawFormFields.tsx index 8c35e0df..314f7024 100644 --- a/src/components/providers/forms/OpenClawFormFields.tsx +++ b/src/components/providers/forms/OpenClawFormFields.tsx @@ -383,7 +383,9 @@ export function OpenClawFormFields({ className="flex items-center gap-1.5 cursor-pointer select-none" > { const current = model.input ?? ["text"]; const next = checked diff --git a/src/components/providers/forms/hooks/useCodexCommonConfig.ts b/src/components/providers/forms/hooks/useCodexCommonConfig.ts index fb65053d..228ab419 100644 --- a/src/components/providers/forms/hooks/useCodexCommonConfig.ts +++ b/src/components/providers/forms/hooks/useCodexCommonConfig.ts @@ -285,7 +285,13 @@ export function useCodexCommonConfig({ isUpdatingFromCommonConfig.current = false; }, 0); }, - [codexConfig, commonConfigSnippet, onConfigChange, parseCommonConfigSnippet, t], + [ + codexConfig, + commonConfigSnippet, + onConfigChange, + parseCommonConfigSnippet, + t, + ], ); // 处理通用配置片段变化 diff --git a/src/components/providers/forms/hooks/useOpenclawFormState.ts b/src/components/providers/forms/hooks/useOpenclawFormState.ts index b7652570..7f32424c 100644 --- a/src/components/providers/forms/hooks/useOpenclawFormState.ts +++ b/src/components/providers/forms/hooks/useOpenclawFormState.ts @@ -171,18 +171,15 @@ export function useOpenclawFormState({ [updateOpenclawConfig], ); - const resetOpenclawState = useCallback( - (config?: OpenClawProviderConfig) => { - setOpenclawProviderKey(""); - setOpenclawBaseUrl(config?.baseUrl || ""); - setOpenclawApiKey(config?.apiKey || ""); - setOpenclawApi(config?.api || "openai-completions"); - setOpenclawModels(config?.models || []); - const ua = config?.headers ? "User-Agent" in config.headers : false; - setOpenclawUserAgent(ua); - }, - [], - ); + const resetOpenclawState = useCallback((config?: OpenClawProviderConfig) => { + setOpenclawProviderKey(""); + setOpenclawBaseUrl(config?.baseUrl || ""); + setOpenclawApiKey(config?.apiKey || ""); + setOpenclawApi(config?.api || "openai-completions"); + setOpenclawModels(config?.models || []); + const ua = config?.headers ? "User-Agent" in config.headers : false; + setOpenclawUserAgent(ua); + }, []); return { openclawProviderKey, diff --git a/src/components/providers/forms/hooks/useSpeedTestEndpoints.ts b/src/components/providers/forms/hooks/useSpeedTestEndpoints.ts index ab9f54b5..c503289c 100644 --- a/src/components/providers/forms/hooks/useSpeedTestEndpoints.ts +++ b/src/components/providers/forms/hooks/useSpeedTestEndpoints.ts @@ -3,6 +3,7 @@ import type { AppId } from "@/lib/api"; import type { ProviderPreset } from "@/config/claudeProviderPresets"; import type { CodexProviderPreset } from "@/config/codexProviderPresets"; import type { ProviderMeta, EndpointCandidate } from "@/types"; +import { extractCodexBaseUrl } from "@/utils/providerConfigUtils"; type PresetEntry = { id: string; @@ -128,10 +129,9 @@ export function useSpeedTestEndpoints({ } | undefined; const configStr = initialCodexConfig?.config ?? ""; - // 从 TOML 中提取 base_url - const match = /base_url\s*=\s*["']([^"']+)["']/i.exec(configStr); - if (match?.[1]) { - add(match[1]); + const extractedBaseUrl = extractCodexBaseUrl(configStr); + if (extractedBaseUrl) { + add(extractedBaseUrl); } // 3. 预设中的 endpointCandidates @@ -141,11 +141,9 @@ export function useSpeedTestEndpoints({ const preset = entry.preset as CodexProviderPreset; // 添加预设自己的 baseUrl const presetConfig = preset.config || ""; - const presetMatch = /base_url\s*=\s*["']([^"']+)["']/i.exec( - presetConfig, - ); - if (presetMatch?.[1]) { - add(presetMatch[1]); + const presetBaseUrl = extractCodexBaseUrl(presetConfig); + if (presetBaseUrl) { + add(presetBaseUrl); } // 添加预设的候选端点 if (preset.endpointCandidates) { diff --git a/src/utils/providerConfigUtils.ts b/src/utils/providerConfigUtils.ts index b35dcabe..84f5e941 100644 --- a/src/utils/providerConfigUtils.ts +++ b/src/utils/providerConfigUtils.ts @@ -1,7 +1,7 @@ // 供应商配置处理工具函数 import type { TemplateValueConfig } from "../config/claudeProviderPresets"; -import { normalizeQuotes, normalizeTomlText } from "@/utils/textNormalization"; +import { normalizeTomlText } from "@/utils/textNormalization"; import { parse as parseToml, stringify as stringifyToml } from "smol-toml"; const isPlainObject = (value: unknown): value is Record => { @@ -414,17 +414,234 @@ export const hasTomlCommonConfigSnippet = ( // ========== Codex base_url utils ========== +const TOML_SECTION_HEADER_PATTERN = /^\s*\[([^\]\r\n]+)\]\s*$/; +const TOML_BASE_URL_PATTERN = + /^\s*base_url\s*=\s*(["'])([^"'\r\n]+)\1\s*(?:#.*)?$/; +const TOML_MODEL_PATTERN = /^\s*model\s*=\s*(["'])([^"'\r\n]+)\1\s*(?:#.*)?$/; +const TOML_MODEL_PROVIDER_LINE_PATTERN = + /^\s*model_provider\s*=\s*(["'])([^"'\r\n]+)\1\s*(?:#.*)?$/; +const TOML_MODEL_PROVIDER_PATTERN = + /^\s*model_provider\s*=\s*(["'])([^"'\r\n]+)\1\s*(?:#.*)?$/m; + +interface TomlSectionRange { + bodyEndIndex: number; + bodyStartIndex: number; +} + +interface TomlAssignmentMatch { + index: number; + sectionName?: string; + value: string; +} + +const finalizeTomlText = (lines: string[]): string => + lines + .join("\n") + .replace(/\n{3,}/g, "\n\n") + .replace(/^\n+/, ""); + +const getTomlSectionRange = ( + lines: string[], + sectionName: string, +): TomlSectionRange | undefined => { + let headerLineIndex = -1; + + for (let index = 0; index < lines.length; index += 1) { + const match = lines[index].match(TOML_SECTION_HEADER_PATTERN); + if (!match) { + continue; + } + + if (headerLineIndex === -1) { + if (match[1] === sectionName) { + headerLineIndex = index; + } + continue; + } + + return { + bodyStartIndex: headerLineIndex + 1, + bodyEndIndex: index, + }; + } + + if (headerLineIndex === -1) { + return undefined; + } + + return { + bodyStartIndex: headerLineIndex + 1, + bodyEndIndex: lines.length, + }; +}; + +const getTopLevelEndIndex = (lines: string[]): number => { + const firstSectionIndex = lines.findIndex((line) => + TOML_SECTION_HEADER_PATTERN.test(line), + ); + return firstSectionIndex === -1 ? lines.length : firstSectionIndex; +}; + +const getTomlSectionInsertIndex = ( + lines: string[], + sectionRange: TomlSectionRange, +): number => { + let insertIndex = sectionRange.bodyEndIndex; + while ( + insertIndex > sectionRange.bodyStartIndex && + lines[insertIndex - 1].trim() === "" + ) { + insertIndex -= 1; + } + return insertIndex; +}; + +const getCodexModelProviderName = (configText: string): string | undefined => { + const match = configText.match(TOML_MODEL_PROVIDER_PATTERN); + const providerName = match?.[2]?.trim(); + return providerName || undefined; +}; + +const getCodexProviderSectionName = ( + configText: string, +): string | undefined => { + const providerName = getCodexModelProviderName(configText); + return providerName ? `model_providers.${providerName}` : undefined; +}; + +const findTomlAssignmentInRange = ( + lines: string[], + pattern: RegExp, + startIndex: number, + endIndex: number, + sectionName?: string, +): TomlAssignmentMatch | undefined => { + for (let index = startIndex; index < endIndex; index += 1) { + const match = lines[index].match(pattern); + if (match?.[2]) { + return { + index, + sectionName, + value: match[2], + }; + } + } + + return undefined; +}; + +const findTomlAssignments = ( + lines: string[], + pattern: RegExp, +): TomlAssignmentMatch[] => { + const assignments: TomlAssignmentMatch[] = []; + let currentSectionName: string | undefined; + + lines.forEach((line, index) => { + const sectionMatch = line.match(TOML_SECTION_HEADER_PATTERN); + if (sectionMatch) { + currentSectionName = sectionMatch[1]; + return; + } + + const match = line.match(pattern); + if (!match?.[2]) { + return; + } + + assignments.push({ + index, + sectionName: currentSectionName, + value: match[2], + }); + }); + + return assignments; +}; + +const isMcpServerSection = (sectionName?: string): boolean => + sectionName === "mcp_servers" || + sectionName?.startsWith("mcp_servers.") === true; + +const isOtherProviderSection = ( + sectionName: string | undefined, + targetSectionName: string | undefined, +): boolean => + Boolean( + sectionName && + sectionName !== targetSectionName && + (sectionName === "model_providers" || + sectionName.startsWith("model_providers.")), + ); + +const getRecoverableBaseUrlAssignments = ( + assignments: TomlAssignmentMatch[], + targetSectionName: string | undefined, +): TomlAssignmentMatch[] => + assignments.filter( + ({ sectionName }) => + sectionName !== targetSectionName && + !isMcpServerSection(sectionName) && + !isOtherProviderSection(sectionName, targetSectionName), + ); + +const getTopLevelModelProviderLineIndex = (lines: string[]): number => { + const topLevelEndIndex = getTopLevelEndIndex(lines); + + for (let index = 0; index < topLevelEndIndex; index += 1) { + if (TOML_MODEL_PROVIDER_LINE_PATTERN.test(lines[index])) { + return index; + } + } + + return -1; +}; + // 从 Codex 的 TOML 配置文本中提取 base_url(支持单/双引号) export const extractCodexBaseUrl = ( configText: string | undefined | null, ): string | undefined => { try { const raw = typeof configText === "string" ? configText : ""; - // 归一化中文/全角引号,避免正则提取失败 - const text = normalizeQuotes(raw); + const text = normalizeTomlText(raw); if (!text) return undefined; - const m = text.match(/base_url\s*=\s*(['"])([^'\"]+)\1/); - return m && m[2] ? m[2] : undefined; + + const lines = text.split("\n"); + const targetSectionName = getCodexProviderSectionName(text); + + if (targetSectionName) { + const sectionRange = getTomlSectionRange(lines, targetSectionName); + if (sectionRange) { + const match = findTomlAssignmentInRange( + lines, + TOML_BASE_URL_PATTERN, + sectionRange.bodyStartIndex, + sectionRange.bodyEndIndex, + targetSectionName, + ); + if (match?.value) { + return match.value; + } + } + } + + const topLevelMatch = findTomlAssignmentInRange( + lines, + TOML_BASE_URL_PATTERN, + 0, + getTopLevelEndIndex(lines), + ); + if (topLevelMatch?.value) { + return topLevelMatch.value; + } + + const fallbackAssignments = getRecoverableBaseUrlAssignments( + findTomlAssignments(lines, TOML_BASE_URL_PATTERN), + targetSectionName, + ); + return fallbackAssignments.length === 1 + ? fallbackAssignments[0].value + : undefined; } catch { return undefined; } @@ -451,36 +668,107 @@ export const setCodexBaseUrl = ( baseUrl: string, ): string => { const trimmed = baseUrl.trim(); - // 归一化原文本中的引号(既能匹配,也能输出稳定格式) - const normalizedText = normalizeQuotes(configText); + const normalizedText = normalizeTomlText(configText); + const lines = normalizedText ? normalizedText.split("\n") : []; + const targetSectionName = getCodexProviderSectionName(normalizedText); + const allAssignments = findTomlAssignments(lines, TOML_BASE_URL_PATTERN); + const recoverableAssignments = getRecoverableBaseUrlAssignments( + allAssignments, + targetSectionName, + ); - // 允许清空:当 baseUrl 为空时,移除 base_url 行 if (!trimmed) { if (!normalizedText) return normalizedText; - const next = normalizedText - .split("\n") - .filter((line) => !/^\s*base_url\s*=/.test(line)) - .join("\n") - // 避免移除后留下过多空行 - .replace(/\n{3,}/g, "\n\n") - // 避免开头出现空行 - .replace(/^\n+/, ""); - return next; + + if (targetSectionName) { + const sectionRange = getTomlSectionRange(lines, targetSectionName); + const targetMatch = sectionRange + ? findTomlAssignmentInRange( + lines, + TOML_BASE_URL_PATTERN, + sectionRange.bodyStartIndex, + sectionRange.bodyEndIndex, + targetSectionName, + ) + : undefined; + + if (targetMatch) { + lines.splice(targetMatch.index, 1); + return finalizeTomlText(lines); + } + } + + if (recoverableAssignments.length === 1) { + lines.splice(recoverableAssignments[0].index, 1); + return finalizeTomlText(lines); + } + + return finalizeTomlText(lines); } const normalizedUrl = trimmed.replace(/\s+/g, ""); const replacementLine = `base_url = "${normalizedUrl}"`; - const pattern = /base_url\s*=\s*(["'])([^"']+)\1/; - if (pattern.test(normalizedText)) { - return normalizedText.replace(pattern, replacementLine); + if (targetSectionName) { + let targetSectionRange = getTomlSectionRange(lines, targetSectionName); + const targetMatch = targetSectionRange + ? findTomlAssignmentInRange( + lines, + TOML_BASE_URL_PATTERN, + targetSectionRange.bodyStartIndex, + targetSectionRange.bodyEndIndex, + targetSectionName, + ) + : undefined; + + if (targetMatch) { + lines[targetMatch.index] = replacementLine; + return finalizeTomlText(lines); + } + + if (recoverableAssignments.length === 1) { + lines.splice(recoverableAssignments[0].index, 1); + targetSectionRange = getTomlSectionRange(lines, targetSectionName); + } + + if (targetSectionRange) { + const insertIndex = getTomlSectionInsertIndex(lines, targetSectionRange); + lines.splice(insertIndex, 0, replacementLine); + return finalizeTomlText(lines); + } + + if (lines.length > 0 && lines[lines.length - 1].trim() !== "") { + lines.push(""); + } + lines.push(`[${targetSectionName}]`, replacementLine); + return finalizeTomlText(lines); } - const prefix = - normalizedText && !normalizedText.endsWith("\n") - ? `${normalizedText}\n` - : normalizedText; - return `${prefix}${replacementLine}\n`; + const topLevelEndIndex = getTopLevelEndIndex(lines); + const topLevelMatch = findTomlAssignmentInRange( + lines, + TOML_BASE_URL_PATTERN, + 0, + topLevelEndIndex, + ); + if (topLevelMatch) { + lines[topLevelMatch.index] = replacementLine; + return finalizeTomlText(lines); + } + + const modelProviderIndex = getTopLevelModelProviderLineIndex(lines); + if (modelProviderIndex !== -1) { + lines.splice(modelProviderIndex + 1, 0, replacementLine); + return finalizeTomlText(lines); + } + + if (lines.length === 0) { + return `${replacementLine}\n`; + } + + const insertIndex = topLevelEndIndex; + lines.splice(insertIndex, 0, replacementLine); + return finalizeTomlText(lines); }; // ========== Codex model name utils ========== @@ -491,13 +779,16 @@ export const extractCodexModelName = ( ): string | undefined => { try { const raw = typeof configText === "string" ? configText : ""; - // 归一化中文/全角引号,避免正则提取失败 - const text = normalizeQuotes(raw); + const text = normalizeTomlText(raw); if (!text) return undefined; - - // 匹配 model = "xxx" 或 model = 'xxx' - const m = text.match(/^model\s*=\s*(['"])([^'"]+)\1/m); - return m && m[2] ? m[2] : undefined; + const lines = text.split("\n"); + const topLevelMatch = findTomlAssignmentInRange( + lines, + TOML_MODEL_PATTERN, + 0, + getTopLevelEndIndex(lines), + ); + return topLevelMatch?.value; } catch { return undefined; } @@ -509,47 +800,40 @@ export const setCodexModelName = ( modelName: string, ): string => { const trimmed = modelName.trim(); - // 归一化原文本中的引号(既能匹配,也能输出稳定格式) - const normalizedText = normalizeQuotes(configText); + const normalizedText = normalizeTomlText(configText); + const lines = normalizedText ? normalizedText.split("\n") : []; + const topLevelEndIndex = getTopLevelEndIndex(lines); + const topLevelMatch = findTomlAssignmentInRange( + lines, + TOML_MODEL_PATTERN, + 0, + topLevelEndIndex, + ); - // 允许清空:当 modelName 为空时,移除 model 行 if (!trimmed) { if (!normalizedText) return normalizedText; - const next = normalizedText - .split("\n") - .filter((line) => !/^\s*model\s*=/.test(line)) - .join("\n") - .replace(/\n{3,}/g, "\n\n") - .replace(/^\n+/, ""); - return next; + if (topLevelMatch) { + lines.splice(topLevelMatch.index, 1); + } + return finalizeTomlText(lines); } const replacementLine = `model = "${trimmed}"`; - const pattern = /^model\s*=\s*["']([^"']+)["']/m; - - if (pattern.test(normalizedText)) { - return normalizedText.replace(pattern, replacementLine); + if (topLevelMatch) { + lines[topLevelMatch.index] = replacementLine; + return finalizeTomlText(lines); } - // 如果不存在 model 字段,尝试在 model_provider 之后插入 - // 如果 model_provider 也不存在,则插入到开头 - const providerPattern = /^model_provider\s*=\s*["'][^"']+["']/m; - const match = normalizedText.match(providerPattern); - - if (match && match.index !== undefined) { - // 在 model_provider 行之后插入 - const endOfLine = normalizedText.indexOf("\n", match.index); - if (endOfLine !== -1) { - return ( - normalizedText.slice(0, endOfLine + 1) + - replacementLine + - "\n" + - normalizedText.slice(endOfLine + 1) - ); - } + const modelProviderIndex = getTopLevelModelProviderLineIndex(lines); + if (modelProviderIndex !== -1) { + lines.splice(modelProviderIndex + 1, 0, replacementLine); + return finalizeTomlText(lines); } - // 在文件开头插入 - const lines = normalizedText.split("\n"); - return `${replacementLine}\n${lines.join("\n")}`; + if (lines.length === 0) { + return `${replacementLine}\n`; + } + + lines.splice(topLevelEndIndex, 0, replacementLine); + return finalizeTomlText(lines); }; diff --git a/tests/utils/providerConfigUtils.codex.test.ts b/tests/utils/providerConfigUtils.codex.test.ts index 13dfa2ae..28e741ab 100644 --- a/tests/utils/providerConfigUtils.codex.test.ts +++ b/tests/utils/providerConfigUtils.codex.test.ts @@ -22,17 +22,21 @@ describe("Codex TOML utils", () => { expect(extractCodexModelName(output)).toBe("gpt-5-codex"); }); - it("removes model line when set to empty", () => { + it("removes only the top-level model line when set to empty", () => { const input = [ 'model_provider = "openai"', 'base_url = "https://api.example.com/v1"', 'model = "gpt-5-codex"', "", + "[profiles.default]", + 'model = "profile-model"', + "", ].join("\n"); const output = setCodexModelName(input, ""); - expect(output).not.toMatch(/^\s*model\s*=/m); + expect(output).not.toMatch(/^model\s*=\s*"gpt-5-codex"$/m); + expect(output).toMatch(/^\[profiles\.default\]\nmodel = "profile-model"$/m); expect(extractCodexModelName(output)).toBeUndefined(); expect(extractCodexBaseUrl(output)).toBe("https://api.example.com/v1"); }); @@ -51,5 +55,97 @@ describe("Codex TOML utils", () => { const output2 = setCodexModelName(output1, " new-model \n"); expect(extractCodexModelName(output2)).toBe("new-model"); }); -}); + it("reads and writes base_url in the active provider section", () => { + const input = [ + 'model_provider = "custom"', + 'model = "gpt-5.4"', + "", + "[model_providers.custom]", + 'name = "custom"', + 'wire_api = "responses"', + "", + "[profiles.default]", + 'approval_policy = "never"', + "", + ].join("\n"); + + const output = setCodexBaseUrl(input, "https://api.example.com/v1"); + + expect(output).toContain( + '[model_providers.custom]\nname = "custom"\nwire_api = "responses"\nbase_url = "https://api.example.com/v1"', + ); + expect(extractCodexBaseUrl(output)).toBe("https://api.example.com/v1"); + }); + + it("recovers a single misplaced base_url from another section", () => { + const input = [ + 'model_provider = "custom"', + 'model = "gpt-5.4"', + "", + "[model_providers.custom]", + 'name = "custom"', + 'wire_api = "responses"', + "", + "[profiles.default]", + 'approval_policy = "never"', + 'base_url = "https://wrong.example/v1"', + "", + ].join("\n"); + + expect(extractCodexBaseUrl(input)).toBe("https://wrong.example/v1"); + + const output = setCodexBaseUrl(input, "https://fixed.example/v1"); + + expect(output).toContain( + '[model_providers.custom]\nname = "custom"\nwire_api = "responses"\nbase_url = "https://fixed.example/v1"', + ); + expect(output).not.toContain("https://wrong.example/v1"); + expect(output.match(/base_url\s*=/g)).toHaveLength(1); + }); + + it("does not treat mcp_servers base_url as provider base_url", () => { + const input = [ + 'model_provider = "azure"', + 'model = "gpt-4"', + "", + "[model_providers.azure]", + 'name = "Azure OpenAI"', + 'wire_api = "responses"', + "", + "[mcp_servers.my_server]", + 'base_url = "http://localhost:8080"', + "", + ].join("\n"); + + expect(extractCodexBaseUrl(input)).toBeUndefined(); + + const output = setCodexBaseUrl(input, "https://new.azure/v1"); + + expect(output).toContain( + '[model_providers.azure]\nname = "Azure OpenAI"\nwire_api = "responses"\nbase_url = "https://new.azure/v1"', + ); + expect(output).toContain( + '[mcp_servers.my_server]\nbase_url = "http://localhost:8080"', + ); + }); + + it("reads model only from the top-level config", () => { + const input = [ + 'model_provider = "custom"', + "", + "[profiles.default]", + 'model = "profile-model"', + "", + ].join("\n"); + + expect(extractCodexModelName(input)).toBeUndefined(); + }); + + it("handles single-quoted values", () => { + const input = "base_url = 'https://api.example.com/v1'\nmodel = 'gpt-5'\n"; + + expect(extractCodexBaseUrl(input)).toBe("https://api.example.com/v1"); + expect(extractCodexModelName(input)).toBe("gpt-5"); + }); +});