mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-23 17:45:28 +08:00
Fix URL construction when base_url already contains the endpoint path (e.g., base_url="https://api.example.com/v1/chat/completions" with endpoint="/v1/chat/completions" should not result in double path).
307 lines
9.9 KiB
Rust
307 lines
9.9 KiB
Rust
//! Codex (OpenAI) Provider Adapter
|
||
//!
|
||
//! 仅透传模式,支持直连 OpenAI API
|
||
//!
|
||
//! ## 客户端检测
|
||
//! 支持检测官方 Codex 客户端 (codex_vscode, codex_cli_rs)
|
||
|
||
use super::{AuthInfo, AuthStrategy, ProviderAdapter};
|
||
use crate::provider::Provider;
|
||
use crate::proxy::error::ProxyError;
|
||
use regex::Regex;
|
||
use reqwest::RequestBuilder;
|
||
use std::sync::LazyLock;
|
||
|
||
/// 官方 Codex 客户端 User-Agent 正则
|
||
#[allow(dead_code)]
|
||
static CODEX_CLIENT_REGEX: LazyLock<Regex> =
|
||
LazyLock::new(|| Regex::new(r"^(codex_vscode|codex_cli_rs)/[\d.]+").unwrap());
|
||
|
||
/// Codex (OpenAI) provider adapter.
///
/// A stateless unit type: all behaviour lives in its inherent methods and trait implementations,
/// so values are free to copy and compare by debug output.
#[derive(Debug, Clone, Copy)]
pub struct CodexAdapter;
|
||
|
||
impl CodexAdapter {
|
||
pub fn new() -> Self {
|
||
Self
|
||
}
|
||
|
||
/// 检测是否为官方 Codex 客户端
|
||
///
|
||
/// 匹配 User-Agent 模式: `^(codex_vscode|codex_cli_rs)/[\d.]+`
|
||
#[allow(dead_code)]
|
||
pub fn is_official_client(user_agent: &str) -> bool {
|
||
CODEX_CLIENT_REGEX.is_match(user_agent)
|
||
}
|
||
|
||
/// 从 Provider 配置中提取 API Key
|
||
fn extract_key(&self, provider: &Provider) -> Option<String> {
|
||
// 1. 尝试从 env 中获取
|
||
if let Some(env) = provider.settings_config.get("env") {
|
||
if let Some(key) = env.get("OPENAI_API_KEY").and_then(|v| v.as_str()) {
|
||
return Some(key.to_string());
|
||
}
|
||
}
|
||
|
||
// 2. 尝试从 auth 中获取 (Codex CLI 格式)
|
||
if let Some(auth) = provider.settings_config.get("auth") {
|
||
if let Some(key) = auth.get("OPENAI_API_KEY").and_then(|v| v.as_str()) {
|
||
return Some(key.to_string());
|
||
}
|
||
}
|
||
|
||
// 3. 尝试直接获取
|
||
if let Some(key) = provider
|
||
.settings_config
|
||
.get("apiKey")
|
||
.or_else(|| provider.settings_config.get("api_key"))
|
||
.and_then(|v| v.as_str())
|
||
{
|
||
return Some(key.to_string());
|
||
}
|
||
|
||
// 4. 尝试从 config 对象中获取
|
||
if let Some(config) = provider.settings_config.get("config") {
|
||
if let Some(key) = config
|
||
.get("api_key")
|
||
.or_else(|| config.get("apiKey"))
|
||
.and_then(|v| v.as_str())
|
||
{
|
||
return Some(key.to_string());
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
}
|
||
|
||
impl Default for CodexAdapter {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
impl ProviderAdapter for CodexAdapter {
|
||
fn name(&self) -> &'static str {
|
||
"Codex"
|
||
}
|
||
|
||
fn extract_base_url(&self, provider: &Provider) -> Result<String, ProxyError> {
|
||
// 1. 尝试直接获取 base_url 字段
|
||
if let Some(url) = provider
|
||
.settings_config
|
||
.get("base_url")
|
||
.and_then(|v| v.as_str())
|
||
{
|
||
return Ok(url.trim_end_matches('/').to_string());
|
||
}
|
||
|
||
// 2. 尝试 baseURL
|
||
if let Some(url) = provider
|
||
.settings_config
|
||
.get("baseURL")
|
||
.and_then(|v| v.as_str())
|
||
{
|
||
return Ok(url.trim_end_matches('/').to_string());
|
||
}
|
||
|
||
// 3. 尝试从 config 对象中获取
|
||
if let Some(config) = provider.settings_config.get("config") {
|
||
if let Some(url) = config.get("base_url").and_then(|v| v.as_str()) {
|
||
return Ok(url.trim_end_matches('/').to_string());
|
||
}
|
||
|
||
// 尝试解析 TOML 字符串格式
|
||
if let Some(config_str) = config.as_str() {
|
||
if let Some(start) = config_str.find("base_url = \"") {
|
||
let rest = &config_str[start + 12..];
|
||
if let Some(end) = rest.find('"') {
|
||
return Ok(rest[..end].trim_end_matches('/').to_string());
|
||
}
|
||
}
|
||
if let Some(start) = config_str.find("base_url = '") {
|
||
let rest = &config_str[start + 12..];
|
||
if let Some(end) = rest.find('\'') {
|
||
return Ok(rest[..end].trim_end_matches('/').to_string());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Err(ProxyError::ConfigError(
|
||
"Codex Provider 缺少 base_url 配置".to_string(),
|
||
))
|
||
}
|
||
|
||
fn extract_auth(&self, provider: &Provider) -> Option<AuthInfo> {
|
||
self.extract_key(provider)
|
||
.map(|key| AuthInfo::new(key, AuthStrategy::Bearer))
|
||
}
|
||
|
||
fn build_url(&self, base_url: &str, endpoint: &str) -> String {
|
||
let base_trimmed = base_url.trim_end_matches('/');
|
||
let endpoint_trimmed = endpoint.trim_start_matches('/');
|
||
|
||
// 检查 base_url 是否已包含 endpoint 的核心路径
|
||
// 例如:base_url = "https://api.example.com/v1/chat/completions"
|
||
// endpoint = "/v1/chat/completions"
|
||
// 此时不应再拼接,直接返回 base_url
|
||
let endpoint_core = endpoint_trimmed
|
||
.trim_start_matches("v1/")
|
||
.trim_start_matches("v1");
|
||
let endpoint_core = endpoint_core.trim_start_matches('/');
|
||
|
||
// 如果 base_url 已经以 endpoint 核心路径结尾,直接返回 base_url
|
||
if !endpoint_core.is_empty() && base_trimmed.ends_with(endpoint_core) {
|
||
return base_trimmed.to_string();
|
||
}
|
||
|
||
let mut url = format!("{base_trimmed}/{endpoint_trimmed}");
|
||
|
||
// 去除重复的 /v1/v1
|
||
if url.contains("/v1/v1") {
|
||
url = url.replace("/v1/v1", "/v1");
|
||
}
|
||
|
||
url
|
||
}
|
||
|
||
fn add_auth_headers(&self, request: RequestBuilder, auth: &AuthInfo) -> RequestBuilder {
|
||
request.header("Authorization", format!("Bearer {}", auth.api_key))
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal Codex provider whose settings are `config`.
    fn create_provider(config: serde_json::Value) -> Provider {
        Provider {
            settings_config: config,
            id: String::from("test"),
            name: String::from("Test Codex"),
            category: Some(String::from("codex")),
            in_failover_queue: false,
            website_url: None,
            created_at: None,
            sort_index: None,
            notes: None,
            meta: None,
            icon: None,
            icon_color: None,
        }
    }

    #[test]
    fn test_extract_base_url_direct() {
        let provider = create_provider(json!({ "base_url": "https://api.openai.com/v1" }));

        let resolved = CodexAdapter::new().extract_base_url(&provider);

        assert_eq!(resolved.unwrap(), "https://api.openai.com/v1");
    }

    #[test]
    fn test_extract_auth_from_auth_field() {
        let provider = create_provider(json!({
            "auth": { "OPENAI_API_KEY": "sk-test-key-12345678" }
        }));

        let info = CodexAdapter::new().extract_auth(&provider).unwrap();

        assert_eq!(info.api_key, "sk-test-key-12345678");
        assert_eq!(info.strategy, AuthStrategy::Bearer);
    }

    #[test]
    fn test_extract_auth_from_env() {
        let provider = create_provider(json!({
            "env": { "OPENAI_API_KEY": "sk-env-key-12345678" }
        }));

        let info = CodexAdapter::new().extract_auth(&provider).unwrap();

        assert_eq!(info.api_key, "sk-env-key-12345678");
    }

    #[test]
    fn test_build_url() {
        let built = CodexAdapter::new().build_url("https://api.openai.com/v1", "/responses");

        assert_eq!(built, "https://api.openai.com/v1/responses");
    }

    #[test]
    fn test_build_url_dedup_v1() {
        // Both the base URL and the endpoint carry a `/v1` segment.
        let built = CodexAdapter::new().build_url("https://www.packyapi.com/v1", "/v1/responses");

        assert_eq!(built, "https://www.packyapi.com/v1/responses");
    }

    #[test]
    fn test_build_url_base_already_has_chat_completions() {
        // The base URL already ends with chat/completions, so nothing is appended.
        let base = "https://api.example.com/v1/chat/completions";

        let built = CodexAdapter::new().build_url(base, "/v1/chat/completions");

        assert_eq!(built, "https://api.example.com/v1/chat/completions");
    }

    #[test]
    fn test_build_url_base_already_has_responses() {
        // The base URL already ends with responses, so nothing is appended.
        let base = "https://api.example.com/v1/responses";

        let built = CodexAdapter::new().build_url(base, "/v1/responses");

        assert_eq!(built, "https://api.example.com/v1/responses");
    }

    #[test]
    fn test_build_url_base_without_endpoint() {
        // The base URL lacks the endpoint, so the two are joined as usual.
        let built =
            CodexAdapter::new().build_url("https://api.example.com/v1", "/v1/chat/completions");

        assert_eq!(built, "https://api.example.com/v1/chat/completions");
    }

    // Official client detection

    #[test]
    fn test_is_official_client_vscode() {
        for agent in ["codex_vscode/1.0.0", "codex_vscode/2.3.4", "codex_vscode/0.1"] {
            assert!(CodexAdapter::is_official_client(agent));
        }
    }

    #[test]
    fn test_is_official_client_cli() {
        for agent in ["codex_cli_rs/1.0.0", "codex_cli_rs/0.5.2"] {
            assert!(CodexAdapter::is_official_client(agent));
        }
    }

    #[test]
    fn test_is_not_official_client() {
        let foreign_agents = [
            "Mozilla/5.0",
            "curl/7.68.0",
            "python-requests/2.25.1",
            "codex_other/1.0.0",
            "",
        ];

        for agent in foreign_agents {
            assert!(!CodexAdapter::is_official_client(agent));
        }
    }

    #[test]
    fn test_is_official_client_partial_match() {
        // The client name has to appear at the very start of the User-Agent.
        for agent in ["some codex_vscode/1.0.0", "prefix_codex_cli_rs/1.0.0"] {
            assert!(!CodexAdapter::is_official_client(agent));
        }
    }
}
|