Files
cc-switch/src-tauri/src/services/provider/mod.rs
2025-12-11 12:13:27 +08:00

679 lines
24 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Provider service module
//!
//! Handles provider CRUD operations, switching, and configuration management.
mod endpoints;
mod gemini_auth;
mod live;
mod usage;
use indexmap::IndexMap;
use regex::Regex;
use serde::Deserialize;
use serde_json::Value;
use crate::app_config::AppType;
use crate::error::AppError;
use crate::provider::{Provider, UsageResult};
use crate::services::mcp::McpService;
use crate::settings::CustomEndpoint;
use crate::store::AppState;
// Re-export sub-module functions for external access
pub use live::{import_default_config, read_live_settings, sync_current_to_live};
// Internal re-exports (pub(crate))
pub(crate) use live::write_live_snapshot;
// Internal re-exports
use live::write_gemini_live;
use usage::validate_usage_script;
/// Provider business logic service.
///
/// A field-less unit struct used purely as a namespace: every operation in its
/// `impl` block is an associated function (none takes `self`), and the ones that
/// need application state receive it as an explicit `&AppState` argument.
pub struct ProviderService;
// Unit tests for the private helpers of `ProviderService`.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// A Codex provider whose settings contain only a `config` entry (no `auth` key)
// must be rejected, and the error text must mention "auth".
#[test]
fn validate_provider_settings_rejects_missing_auth() {
let provider = Provider::with_id(
"codex".into(),
"Codex".into(),
json!({ "config": "base_url = \"https://example.com\"" }),
None,
);
let err = ProviderService::validate_provider_settings(&AppType::Codex, &provider)
.expect_err("missing auth should be rejected");
assert!(
err.to_string().contains("auth"),
"expected auth error, got {err:?}"
);
}
// For a Claude provider the API key is read from `env.ANTHROPIC_AUTH_TOKEN`
// and the base URL from `env.ANTHROPIC_BASE_URL`.
#[test]
fn extract_credentials_returns_expected_values() {
let provider = Provider::with_id(
"claude".into(),
"Claude".into(),
json!({
"env": {
"ANTHROPIC_AUTH_TOKEN": "token",
"ANTHROPIC_BASE_URL": "https://claude.example"
}
}),
None,
);
let (api_key, base_url) =
ProviderService::extract_credentials(&provider, &AppType::Claude).unwrap();
assert_eq!(api_key, "token");
assert_eq!(base_url, "https://claude.example");
}
}
impl ProviderService {
/// Normalize legacy Claude model keys on `provider`, in place.
///
/// Does nothing unless `app_type` is `AppType::Claude`. For Claude providers the
/// settings JSON is handed to `normalize_claude_models_in_value`.
fn normalize_provider_if_claude(app_type: &AppType, provider: &mut Provider) {
    if matches!(app_type, AppType::Claude) {
        // The helper only mutates the value on paths where it also reports a change,
        // so it can work directly on the stored settings. Normalizing a scratch copy
        // and copying it back would deep-clone the whole JSON tree for no benefit.
        normalize_claude_models_in_value(&mut provider.settings_config);
    }
}
/// List all providers for an app type.
///
/// Forwards to the database layer using the string form of `app_type` and returns
/// its result unchanged: the stored providers as an ordered map, or the database
/// error.
pub fn list(
state: &AppState,
app_type: AppType,
) -> Result<IndexMap<String, Provider>, AppError> {
state.db.get_all_providers(app_type.as_str())
}
/// Get the current provider ID.
///
/// Uses the *effective* current provider ID (one whose existence has been verified):
/// the local settings are consulted first and, after validation, the lookup falls
/// back to the database `is_current` field. This lets multiple devices pick their
/// own provider in cloud-sync setups while guaranteeing the returned ID is valid.
///
/// Returns an empty string when no effective current provider exists.
pub fn current(state: &AppState, app_type: AppType) -> Result<String, AppError> {
    let effective = crate::settings::get_effective_current_provider(&state.db, &app_type)?;
    Ok(effective.unwrap_or_default())
}
/// Add a new provider.
///
/// The provider is normalized (Claude model keys), validated and saved to the
/// database. When the database has no current provider for this app type, the new
/// provider is also marked as current there and written out through
/// `write_live_snapshot`.
///
/// Returns `Ok(true)` on success; any validation, database or live-write error is
/// propagated.
pub fn add(state: &AppState, app_type: AppType, provider: Provider) -> Result<bool, AppError> {
    let app_key = app_type.as_str();

    let mut provider = provider;
    Self::normalize_provider_if_claude(&app_type, &mut provider);
    Self::validate_provider_settings(&app_type, &provider)?;

    state.db.save_provider(app_key, &provider)?;

    // Only the "no current provider yet" case triggers activation and a live write.
    let has_current = state.db.get_current_provider(app_key)?.is_some();
    if !has_current {
        state.db.set_current_provider(app_key, &provider.id)?;
        write_live_snapshot(&app_type, &provider)?;
    }

    Ok(true)
}
/// Update an existing provider.
///
/// The provider is normalized (Claude model keys), validated and saved to the
/// database. If it is the effective current provider (local settings first, not
/// just the database flag), it is also written out through `write_live_snapshot`
/// and the enabled MCP configuration is re-synced.
///
/// Returns `Ok(true)` on success; errors from any step are propagated.
pub fn update(
    state: &AppState,
    app_type: AppType,
    provider: Provider,
) -> Result<bool, AppError> {
    let mut provider = provider;
    Self::normalize_provider_if_claude(&app_type, &mut provider);
    Self::validate_provider_settings(&app_type, &provider)?;

    // Resolve "is this the active provider?" before touching the database.
    let is_current = crate::settings::get_effective_current_provider(&state.db, &app_type)?
        .map_or(false, |current_id| current_id == provider.id);

    state.db.save_provider(app_type.as_str(), &provider)?;

    if !is_current {
        return Ok(true);
    }

    write_live_snapshot(&app_type, &provider)?;
    McpService::sync_all_enabled(state)?;
    Ok(true)
}
/// Delete a provider.
///
/// Both the local settings and the database are checked for the current provider,
/// so a provider that is in use on either side cannot be deleted; in that case an
/// `AppError::Message` is returned and nothing is removed.
pub fn delete(state: &AppState, app_type: AppType, id: &str) -> Result<(), AppError> {
    let local_current = crate::settings::get_current_provider(&app_type);
    let db_current = state.db.get_current_provider(app_type.as_str())?;

    let used_locally = local_current.as_deref() == Some(id);
    let used_in_db = db_current.as_deref() == Some(id);

    if !(used_locally || used_in_db) {
        return state.db.delete_provider(app_type.as_str(), id);
    }

    Err(AppError::Message(
        "无法删除当前正在使用的供应商".to_string(),
    ))
}
/// Switch the active provider for `app_type` to `id`.
///
/// Flow:
/// 1. Fail with `AppError::Message` if `id` is not among the stored providers.
/// 2. Read the proxy takeover flag and ask the proxy service whether it is running.
/// 3. If the flag is set AND the proxy is running, hot-switch only: update the
///    database current provider, the proxy target provider and the local settings,
///    then return. No live config is written and MCP is not synced.
/// 4. Otherwise run the normal switch (`switch_normal`), first clearing the takeover
///    flag if it was left set while the proxy is not running.
pub fn switch(state: &AppState, app_type: AppType, id: &str) -> Result<(), AppError> {
    let providers = state.db.get_all_providers(app_type.as_str())?;
    if !providers.contains_key(id) {
        return Err(AppError::Message(format!("供应商 {id} 不存在")));
    }

    // This function is synchronous, so the async queries are driven to completion
    // with a blocking executor. A failed flag read counts as "takeover not active".
    let takeover_flag =
        futures::executor::block_on(state.db.is_live_takeover_active()).unwrap_or(false);
    let proxy_running = futures::executor::block_on(state.proxy_service.is_running());

    if !(takeover_flag && proxy_running) {
        // Reaching here with the flag set means the proxy is not running: the flag
        // is stale, so clear it (best effort) before the normal switch.
        if takeover_flag {
            log::warn!("检测到代理接管标志残留(代理已停止),清除标志并执行正常切换");
            let _ = futures::executor::block_on(state.db.set_live_takeover_active(false));
        }
        return Self::switch_normal(state, app_type, id, &providers);
    }

    // Proxy takeover mode: only the routing target changes.
    log::info!(
        "代理接管模式:热切换 {} 的目标供应商为 {}",
        app_type.as_str(),
        id
    );
    state.db.set_current_provider(app_type.as_str(), id)?;
    // The proxy router selects its provider through the proxy-target field.
    state.db.set_proxy_target_provider(app_type.as_str(), id)?;
    // Keep the device-level setting consistent with the database.
    crate::settings::set_current_provider(&app_type, Some(id))?;
    Ok(())
}
/// Normal switch flow (non-proxy mode)
fn switch_normal(
state: &AppState,
app_type: AppType,
id: &str,
providers: &indexmap::IndexMap<String, Provider>,
) -> Result<(), AppError> {
let provider = providers
.get(id)
.ok_or_else(|| AppError::Message(format!("供应商 {id} 不存在")))?;
// Backfill: Backfill current live config to current provider
// Use effective current provider (validated existence) to ensure backfill targets valid provider
let current_id = crate::settings::get_effective_current_provider(&state.db, &app_type)?;
if let Some(current_id) = current_id {
if current_id != id {
// Only backfill when switching to a different provider
if let Ok(live_config) = read_live_settings(app_type.clone()) {
if let Some(mut current_provider) = providers.get(&current_id).cloned() {
current_provider.settings_config = live_config;
// Ignore backfill failure, don't affect switch flow
let _ = state.db.save_provider(app_type.as_str(), &current_provider);
}
}
}
}
// Update local settings (device-level, takes priority)
crate::settings::set_current_provider(&app_type, Some(id))?;
// Update database is_current (as default for new devices)
state.db.set_current_provider(app_type.as_str(), id)?;
// Sync to live (write_gemini_live handles security flag internally for Gemini)
write_live_snapshot(&app_type, provider)?;
// Sync MCP
McpService::sync_all_enabled(state)?;
Ok(())
}
/// Set the proxy target provider for an app type.
///
/// Fails with `AppError::Message` when `id` is not among the stored providers;
/// otherwise records `id` as the proxy target in the database.
pub fn set_proxy_target(state: &AppState, app_type: AppType, id: &str) -> Result<(), AppError> {
    let known = state.db.get_all_providers(app_type.as_str())?;
    match known.get(id) {
        Some(_) => {
            state.db.set_proxy_target_provider(app_type.as_str(), id)?;
            Ok(())
        }
        None => Err(AppError::Message(format!("供应商 {id} 不存在"))),
    }
}
/// Sync the current provider to the live configuration.
///
/// Thin wrapper around the free function `sync_current_to_live` imported from the
/// `live` sub-module. An unqualified call inside an `impl` never resolves to an
/// associated function, so this forwards to the module-level function rather than
/// recursing into itself.
pub fn sync_current_to_live(state: &AppState) -> Result<(), AppError> {
sync_current_to_live(state)
}
/// Import the default configuration from the live files.
///
/// Thin wrapper around the free function `import_default_config` imported from the
/// `live` sub-module (the unqualified call resolves to that function, not to this
/// associated function).
///
/// Returns `Ok(true)` if imported, `Ok(false)` if skipped.
pub fn import_default_config(state: &AppState, app_type: AppType) -> Result<bool, AppError> {
import_default_config(state, app_type)
}
/// Read the current live settings for an app type as a JSON value.
///
/// Thin wrapper around the free function `read_live_settings` imported from the
/// `live` sub-module (the unqualified call resolves to that function, not to this
/// associated function).
pub fn read_live_settings(app_type: AppType) -> Result<Value, AppError> {
read_live_settings(app_type)
}
/// Get the list of custom endpoints for a provider.
///
/// Thin wrapper: forwards all arguments to `endpoints::get_custom_endpoints` and
/// returns its result unchanged.
pub fn get_custom_endpoints(
state: &AppState,
app_type: AppType,
provider_id: &str,
) -> Result<Vec<CustomEndpoint>, AppError> {
endpoints::get_custom_endpoints(state, app_type, provider_id)
}
/// Add a custom endpoint URL to a provider.
///
/// Thin wrapper: forwards all arguments to `endpoints::add_custom_endpoint` and
/// returns its result unchanged.
pub fn add_custom_endpoint(
state: &AppState,
app_type: AppType,
provider_id: &str,
url: String,
) -> Result<(), AppError> {
endpoints::add_custom_endpoint(state, app_type, provider_id, url)
}
/// Remove a custom endpoint URL from a provider.
///
/// Thin wrapper: forwards all arguments to `endpoints::remove_custom_endpoint` and
/// returns its result unchanged.
pub fn remove_custom_endpoint(
state: &AppState,
app_type: AppType,
provider_id: &str,
url: String,
) -> Result<(), AppError> {
endpoints::remove_custom_endpoint(state, app_type, provider_id, url)
}
/// Update the last-used timestamp of a provider endpoint.
///
/// Thin wrapper: forwards all arguments to `endpoints::update_endpoint_last_used`
/// and returns its result unchanged.
pub fn update_endpoint_last_used(
state: &AppState,
app_type: AppType,
provider_id: &str,
url: String,
) -> Result<(), AppError> {
endpoints::update_endpoint_last_used(state, app_type, provider_id, url)
}
/// Update provider sort order
pub fn update_sort_order(
state: &AppState,
app_type: AppType,
updates: Vec<ProviderSortUpdate>,
) -> Result<bool, AppError> {
let mut providers = state.db.get_all_providers(app_type.as_str())?;
for update in updates {
if let Some(provider) = providers.get_mut(&update.id) {
provider.sort_index = Some(update.sort_index);
state.db.save_provider(app_type.as_str(), provider)?;
}
}
Ok(true)
}
/// Query usage information for a provider.
///
/// Thin async wrapper: forwards all arguments to `usage::query_usage`, awaits it and
/// returns its result unchanged.
pub async fn query_usage(
state: &AppState,
app_type: AppType,
provider_id: &str,
) -> Result<UsageResult, AppError> {
usage::query_usage(state, app_type, provider_id).await
}
/// Test a usage script against a provider.
///
/// Thin async wrapper: forwards every argument, in the same order, to
/// `usage::test_usage_script`, awaits it and returns its result unchanged.
// The parameter list mirrors the wrapped function one-to-one, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub async fn test_usage_script(
state: &AppState,
app_type: AppType,
provider_id: &str,
script_code: &str,
timeout: u64,
api_key: Option<&str>,
base_url: Option<&str>,
access_token: Option<&str>,
user_id: Option<&str>,
) -> Result<UsageResult, AppError> {
usage::test_usage_script(
state,
app_type,
provider_id,
script_code,
timeout,
api_key,
base_url,
access_token,
user_id,
)
.await
}
/// Write a Gemini provider to the live configuration.
///
/// Thin crate-visible wrapper around the free function `write_gemini_live` imported
/// from the `live` sub-module (the unqualified call resolves to that function, not
/// to this associated function).
pub(crate) fn write_gemini_live(provider: &Provider) -> Result<(), AppError> {
write_gemini_live(provider)
}
fn validate_provider_settings(app_type: &AppType, provider: &Provider) -> Result<(), AppError> {
match app_type {
AppType::Claude => {
if !provider.settings_config.is_object() {
return Err(AppError::localized(
"provider.claude.settings.not_object",
"Claude 配置必须是 JSON 对象",
"Claude configuration must be a JSON object",
));
}
}
AppType::Codex => {
let settings = provider.settings_config.as_object().ok_or_else(|| {
AppError::localized(
"provider.codex.settings.not_object",
"Codex 配置必须是 JSON 对象",
"Codex configuration must be a JSON object",
)
})?;
let auth = settings.get("auth").ok_or_else(|| {
AppError::localized(
"provider.codex.auth.missing",
format!("供应商 {} 缺少 auth 配置", provider.id),
format!("Provider {} is missing auth configuration", provider.id),
)
})?;
if !auth.is_object() {
return Err(AppError::localized(
"provider.codex.auth.not_object",
format!("供应商 {} 的 auth 配置必须是 JSON 对象", provider.id),
format!(
"Provider {} auth configuration must be a JSON object",
provider.id
),
));
}
if let Some(config_value) = settings.get("config") {
if !(config_value.is_string() || config_value.is_null()) {
return Err(AppError::localized(
"provider.codex.config.invalid_type",
"Codex config 字段必须是字符串",
"Codex config field must be a string",
));
}
if let Some(cfg_text) = config_value.as_str() {
crate::codex_config::validate_config_toml(cfg_text)?;
}
}
}
AppType::Gemini => {
use crate::gemini_config::validate_gemini_settings;
validate_gemini_settings(&provider.settings_config)?
}
}
// Validate and clean UsageScript configuration (common for all app types)
if let Some(meta) = &provider.meta {
if let Some(usage_script) = &meta.usage_script {
validate_usage_script(usage_script)?;
}
}
Ok(())
}
/// Extract `(api_key, base_url)` from a provider's settings.
///
/// Per-app sources:
/// - **Claude**: both values come from the `env` object. The key is the first of
///   `ANTHROPIC_AUTH_TOKEN` / `ANTHROPIC_API_KEY` that holds a string; the URL is
///   `ANTHROPIC_BASE_URL`.
/// - **Codex**: the key is `auth.OPENAI_API_KEY`; the URL is the quoted value of the
///   first `base_url = "..."` assignment found in the `config` TOML text.
/// - **Gemini**: settings are converted with `json_to_env`; the key is
///   `GEMINI_API_KEY` and the URL is `GOOGLE_GEMINI_BASE_URL`, defaulting to
///   `https://generativelanguage.googleapis.com` when absent.
///
/// # Errors
///
/// Returns a localized `AppError` when a required section or value is missing or,
/// for Codex, when `base_url` is mentioned but not in the expected quoted form.
// NOTE(review): within this file the function is only called from the unit tests,
// which is presumably why the dead-code lint is silenced — confirm before removing.
#[allow(dead_code)]
fn extract_credentials(
    provider: &Provider,
    app_type: &AppType,
) -> Result<(String, String), AppError> {
    match app_type {
        AppType::Claude => {
            let env = provider
                .settings_config
                .get("env")
                .and_then(|v| v.as_object())
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.claude.env.missing",
                        "配置格式错误: 缺少 env",
                        "Invalid configuration: missing env section",
                    )
                })?;

            // Apply the string check per key, so a present-but-non-string
            // ANTHROPIC_AUTH_TOKEN (e.g. null) does not hide a usable
            // ANTHROPIC_API_KEY.
            let api_key = ["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]
                .iter()
                .find_map(|key| env.get(*key).and_then(|v| v.as_str()))
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.claude.api_key.missing",
                        "缺少 API Key",
                        "API key is missing",
                    )
                })?
                .to_string();

            let base_url = env
                .get("ANTHROPIC_BASE_URL")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.claude.base_url.missing",
                        "缺少 ANTHROPIC_BASE_URL 配置",
                        "Missing ANTHROPIC_BASE_URL configuration",
                    )
                })?
                .to_string();

            Ok((api_key, base_url))
        }
        AppType::Codex => {
            let auth = provider
                .settings_config
                .get("auth")
                .and_then(|v| v.as_object())
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.codex.auth.missing",
                        "配置格式错误: 缺少 auth",
                        "Invalid configuration: missing auth section",
                    )
                })?;

            let api_key = auth
                .get("OPENAI_API_KEY")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.codex.api_key.missing",
                        "缺少 API Key",
                        "API key is missing",
                    )
                })?
                .to_string();

            // A missing or non-string `config` is treated as empty TOML text.
            let config_toml = provider
                .settings_config
                .get("config")
                .and_then(|v| v.as_str())
                .unwrap_or("");

            // Distinguish "not mentioned at all" from "mentioned but malformed".
            if !config_toml.contains("base_url") {
                return Err(AppError::localized(
                    "provider.codex.base_url.missing",
                    "config.toml 中缺少 base_url 配置",
                    "base_url is missing from config.toml",
                ));
            }

            let re = Regex::new(r#"base_url\s*=\s*["']([^"']+)["']"#).map_err(|e| {
                AppError::localized(
                    "provider.regex_init_failed",
                    format!("正则初始化失败: {e}"),
                    format!("Failed to initialize regex: {e}"),
                )
            })?;
            let base_url = re
                .captures(config_toml)
                .and_then(|caps| caps.get(1))
                .map(|m| m.as_str().to_string())
                .ok_or_else(|| {
                    AppError::localized(
                        "provider.codex.base_url.invalid",
                        "config.toml 中 base_url 格式错误",
                        "base_url in config.toml has invalid format",
                    )
                })?;

            Ok((api_key, base_url))
        }
        AppType::Gemini => {
            use crate::gemini_config::json_to_env;

            let env_map = json_to_env(&provider.settings_config)?;

            let api_key = env_map.get("GEMINI_API_KEY").cloned().ok_or_else(|| {
                AppError::localized(
                    "gemini.missing_api_key",
                    "缺少 GEMINI_API_KEY",
                    "Missing GEMINI_API_KEY",
                )
            })?;

            let base_url = env_map
                .get("GOOGLE_GEMINI_BASE_URL")
                .cloned()
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_string());

            Ok((api_key, base_url))
        }
    }
}
}
/// Normalize Claude model keys in a JSON value.
///
/// Operates on the `env` object of `settings`; if there is no such object the value
/// is left untouched and `false` is returned. Otherwise:
///
/// - Each of `ANTHROPIC_DEFAULT_HAIKU_MODEL`, `ANTHROPIC_DEFAULT_SONNET_MODEL` and
///   `ANTHROPIC_DEFAULT_OPUS_MODEL` that is absent is filled from a fallback, when
///   one exists. Haiku prefers the legacy `ANTHROPIC_SMALL_FAST_MODEL` and then
///   `ANTHROPIC_MODEL`; Sonnet and Opus prefer `ANTHROPIC_MODEL` and then the legacy
///   key. Only string values count as fallbacks.
/// - A key that is already present, whatever its type, is never overwritten.
/// - The legacy `ANTHROPIC_SMALL_FAST_MODEL` key is removed.
///
/// Returns `true` if anything was inserted or removed.
pub(crate) fn normalize_claude_models_in_value(settings: &mut Value) -> bool {
    const LEGACY_SMALL_FAST: &str = "ANTHROPIC_SMALL_FAST_MODEL";

    let env = match settings.get_mut("env").and_then(Value::as_object_mut) {
        Some(map) => map,
        None => return false,
    };

    // Snapshot the fallback sources before the map is modified.
    let string_at = |key: &str| env.get(key).and_then(Value::as_str).map(str::to_owned);
    let model = string_at("ANTHROPIC_MODEL");
    let small_fast = string_at(LEGACY_SMALL_FAST);

    // A default is only written when its key is absent, so an existing value can
    // never take part in the result; the fallback chains alone decide what to insert.
    let haiku_fallback = small_fast.clone().or_else(|| model.clone());
    let general_fallback = model.or(small_fast);

    let defaults = [
        ("ANTHROPIC_DEFAULT_HAIKU_MODEL", haiku_fallback),
        ("ANTHROPIC_DEFAULT_SONNET_MODEL", general_fallback.clone()),
        ("ANTHROPIC_DEFAULT_OPUS_MODEL", general_fallback),
    ];

    let mut changed = false;
    for (key, fallback) in defaults {
        if env.contains_key(key) {
            continue;
        }
        if let Some(value) = fallback {
            env.insert(key.to_owned(), Value::String(value));
            changed = true;
        }
    }

    // Dropping the legacy key counts as a change as well.
    changed |= env.remove(LEGACY_SMALL_FAST).is_some();
    changed
}
/// A single sort-order change for one provider.
///
/// Deserialized from input in which the index field is spelled `sortIndex`.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderSortUpdate {
/// ID of the provider whose position changes.
pub id: String,
/// New sort index to store on the provider.
#[serde(rename = "sortIndex")]
pub sort_index: usize,
}