diff --git a/src-tauri/src/database/migration.rs b/src-tauri/src/database/migration.rs index 51c60e41..c96eb4e2 100644 --- a/src-tauri/src/database/migration.rs +++ b/src-tauri/src/database/migration.rs @@ -153,12 +153,14 @@ impl Database { tx: &rusqlite::Transaction<'_>, config: &MultiAppConfig, ) -> Result<(), AppError> { - let migrate_app_prompts = - |prompts_map: &std::collections::HashMap, - app_type: &str| - -> Result<(), AppError> { - for (id, prompt) in prompts_map { - tx.execute( + let migrate_app_prompts = |prompts_map: &std::collections::HashMap< + String, + crate::prompt::Prompt, + >, + app_type: &str| + -> Result<(), AppError> { + for (id, prompt) in prompts_map { + tx.execute( "INSERT OR REPLACE INTO prompts ( id, app_type, name, content, description, enabled, created_at, updated_at ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", @@ -174,9 +176,9 @@ impl Database { ], ) .map_err(|e| AppError::Database(format!("Migrate prompt failed: {e}")))?; - } - Ok(()) - }; + } + Ok(()) + }; migrate_app_prompts(&config.prompts.claude.prompts, "claude")?; migrate_app_prompts(&config.prompts.codex.prompts, "codex")?; diff --git a/src-tauri/src/database/schema.rs b/src-tauri/src/database/schema.rs index f0750aed..fd75818b 100644 --- a/src-tauri/src/database/schema.rs +++ b/src-tauri/src/database/schema.rs @@ -226,13 +226,13 @@ impl Database { Self::add_column_if_missing(conn, "skills", "installed_at", "INTEGER NOT NULL DEFAULT 0")?; // skill_repos 表 - Self::add_column_if_missing(conn, "skill_repos", "branch", "TEXT NOT NULL DEFAULT 'main'")?; Self::add_column_if_missing( conn, "skill_repos", - "enabled", - "BOOLEAN NOT NULL DEFAULT 1", + "branch", + "TEXT NOT NULL DEFAULT 'main'", )?; + Self::add_column_if_missing(conn, "skill_repos", "enabled", "BOOLEAN NOT NULL DEFAULT 1")?; Self::add_column_if_missing(conn, "skill_repos", "skills_path", "TEXT")?; Ok(()) @@ -247,9 +247,7 @@ impl Database { pub(crate) fn set_user_version(conn: &Connection, version: i32) -> Result<(), AppError> { if version < 0 { - return Err(AppError::Database( - "user_version 不能为负数".to_string(), - )); + return Err(AppError::Database("user_version 不能为负数".to_string())); } let sql = format!("PRAGMA user_version = {version};"); conn.execute(&sql, []) @@ -261,10 +259,7 @@ impl Database { if s.is_empty() { return Err(AppError::Database(format!("{kind} 不能为空"))); } - if !s - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_') - { + if !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { return Err(AppError::Database(format!( "非法{kind}: {s},仅允许字母、数字和下划线" ))); @@ -292,7 +287,11 @@ impl Database { Ok(false) } - pub(crate) fn has_column(conn: &Connection, table: &str, column: &str) -> Result { + pub(crate) fn has_column( + conn: &Connection, + table: &str, + column: &str, + ) -> Result { Self::validate_identifier(table, "表名")?; Self::validate_identifier(column, "列名")?; diff --git a/src-tauri/src/database/tests.rs b/src-tauri/src/database/tests.rs index e7676847..a7684cf2 100644 --- a/src-tauri/src/database/tests.rs +++ b/src-tauri/src/database/tests.rs @@ -108,8 +108,8 @@ fn migration_rejects_future_version() { Database::create_tables_on_conn(&conn).expect("create tables"); Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version"); - let err = Database::apply_schema_migrations_on_conn(&conn) - .expect_err("should reject higher version"); + let err = + Database::apply_schema_migrations_on_conn(&conn).expect_err("should reject higher version"); assert!( err.to_string().contains("数据库版本过新"), "unexpected error: {err}" @@ -168,10 +168,7 @@ fn migration_aligns_column_defaults_and_types() { let is_current = get_column_info(&conn, "providers", "is_current"); assert_eq!(is_current.r#type, "BOOLEAN"); assert_eq!(is_current.notnull, 1); - assert_eq!( - normalize_default(&is_current.default).as_deref(), - Some("0") - ); + assert_eq!(normalize_default(&is_current.default).as_deref(), Some("0")); let tags = get_column_info(&conn, "mcp_servers", "tags"); assert_eq!(tags.r#type, "TEXT"); @@ -181,10 +178,7 @@ fn migration_aligns_column_defaults_and_types() { let enabled = get_column_info(&conn, "prompts", "enabled"); assert_eq!(enabled.r#type, "BOOLEAN"); assert_eq!(enabled.notnull, 1); - assert_eq!( - normalize_default(&enabled.default).as_deref(), - Some("1") - ); + assert_eq!(normalize_default(&enabled.default).as_deref(), Some("1")); let installed_at = get_column_info(&conn, "skills", "installed_at"); assert_eq!(installed_at.r#type, "INTEGER"); diff --git a/src-tauri/src/deeplink/tests.rs b/src-tauri/src/deeplink/tests.rs index f77ea0f5..40568154 100644 --- a/src-tauri/src/deeplink/tests.rs +++ b/src-tauri/src/deeplink/tests.rs @@ -307,8 +307,7 @@ fn test_import_prompt_allows_space_in_base64_content() { let db = Arc::new(Database::memory().expect("create memory db")); let state = AppState::new(db.clone()); - let prompt_id = - import_prompt_from_deeplink(&state, request.clone()).expect("import prompt"); + let prompt_id = import_prompt_from_deeplink(&state, request.clone()).expect("import prompt"); let prompts = state.db.get_prompts("codex").expect("get prompts"); let prompt = prompts.get(&prompt_id).expect("prompt saved"); diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 6a2e0f09..311879f4 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -223,11 +223,9 @@ fn create_tray_menu( let providers = app_state.db.get_all_providers(app_type_str)?; // 使用有效的当前供应商 ID(验证存在性,自动清理失效 ID) - let current_id = crate::settings::get_effective_current_provider( - &app_state.db, - §ion.app_type, - )? - .unwrap_or_default(); + let current_id = + crate::settings::get_effective_current_provider(&app_state.db, §ion.app_type)? + .unwrap_or_default(); let manager = crate::provider::ProviderManager { providers, @@ -1033,21 +1031,19 @@ fn show_migration_error_dialog(app: &tauri::AppHandle, error: &str) -> bool { let message = if is_chinese_locale() { format!( - "从旧版本迁移配置时发生错误:\n\n{}\n\n\ + "从旧版本迁移配置时发生错误:\n\n{error}\n\n\ 您的数据尚未丢失,旧配置文件仍然保留。\n\ 建议回退到旧版本 CC Switch 以保护数据。\n\n\ 点击「重试」重新尝试迁移\n\ - 点击「退出」关闭程序(可回退版本后重新打开)", - error + 点击「退出」关闭程序(可回退版本后重新打开)" ) } else { format!( - "An error occurred while migrating configuration:\n\n{}\n\n\ + "An error occurred while migrating configuration:\n\n{error}\n\n\ Your data is NOT lost - the old config file is still preserved.\n\ Consider rolling back to an older CC Switch version.\n\n\ Click 'Retry' to attempt migration again\n\ - Click 'Exit' to close the program", - error + Click 'Exit' to close the program" ) }; diff --git a/src-tauri/src/services/provider/live.rs b/src-tauri/src/services/provider/live.rs index 2450f764..b4db7221 100644 --- a/src-tauri/src/services/provider/live.rs +++ b/src-tauri/src/services/provider/live.rs @@ -100,12 +100,13 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re write_json_file(&path, &provider.settings_config)?; } AppType::Codex => { - let obj = provider.settings_config.as_object().ok_or_else(|| { - AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string()) - })?; - let auth = obj.get("auth").ok_or_else(|| { - AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string()) - })?; + let obj = provider + .settings_config + .as_object() + .ok_or_else(|| AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string()))?; + let auth = obj + .get("auth") + .ok_or_else(|| AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string()))?; let config_str = obj.get("config").and_then(|v| v.as_str()).ok_or_else(|| { AppError::Config("Codex 供应商配置缺少 'config' 字段或不是字符串".to_string()) })?; @@ -113,8 +114,7 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re let auth_path = get_codex_auth_path(); write_json_file(&auth_path, auth)?; let config_path = get_codex_config_path(); - std::fs::write(&config_path, config_str) - .map_err(|e| AppError::io(&config_path, e))?; + std::fs::write(&config_path, config_str).map_err(|e| AppError::io(&config_path, e))?; } AppType::Gemini => { // Delegate to write_gemini_live which handles env file writing correctly @@ -132,11 +132,11 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re pub fn sync_current_to_live(state: &AppState) -> Result<(), AppError> { for app_type in [AppType::Claude, AppType::Codex, AppType::Gemini] { // Use validated effective current provider - let current_id = match crate::settings::get_effective_current_provider(&state.db, &app_type)? - { - Some(id) => id, - None => continue, - }; + let current_id = + match crate::settings::get_effective_current_provider(&state.db, &app_type)? { + Some(id) => id, + None => continue, + }; let providers = state.db.get_all_providers(app_type.as_str())?; if let Some(provider) = providers.get(¤t_id) { diff --git a/src-tauri/src/services/provider/usage.rs b/src-tauri/src/services/provider/usage.rs index 2f28db3f..ee95c00c 100644 --- a/src-tauri/src/services/provider/usage.rs +++ b/src-tauri/src/services/provider/usage.rs @@ -168,9 +168,7 @@ pub(crate) fn validate_usage_script(script: &UsageScript) -> Result<(), AppError if interval > 1440 { return Err(AppError::localized( "usage_script.interval_too_large", - format!( - "自动查询间隔不能超过 1440 分钟(24小时),当前值: {interval}" - ), + format!("自动查询间隔不能超过 1440 分钟(24小时),当前值: {interval}"), format!( "Auto query interval cannot exceed 1440 minutes (24 hours), current: {interval}" ), diff --git a/src-tauri/src/services/skill.rs b/src-tauri/src/services/skill.rs index f771cfd7..23a3e4d2 100644 --- a/src-tauri/src/services/skill.rs +++ b/src-tauri/src/services/skill.rs @@ -214,56 +214,8 @@ impl SkillService { temp_dir.clone() }; - // 遍历目标目录 - for entry in fs::read_dir(&scan_dir)? { - let entry = entry?; - let path = entry.path(); - - if !path.is_dir() { - continue; - } - - let skill_md = path.join("SKILL.md"); - if !skill_md.exists() { - continue; - } - - // 解析技能元数据 - match self.parse_skill_metadata(&skill_md) { - Ok(meta) => { - // 安全地获取目录名 - let Some(dir_name) = path.file_name() else { - log::warn!("Failed to get directory name from path: {path:?}"); - continue; - }; - let directory = dir_name.to_string_lossy().to_string(); - - // 构建 README URL(考虑 skillsPath) - let readme_path = if let Some(ref skills_path) = repo.skills_path { - format!("{}/{}", skills_path.trim_matches('/'), directory) - } else { - directory.clone() - }; - - skills.push(Skill { - key: format!("{}/{}:{}", repo.owner, repo.name, directory), - name: meta.name.unwrap_or_else(|| directory.clone()), - description: meta.description.unwrap_or_default(), - directory, - readme_url: Some(format!( - "https://github.com/{}/{}/tree/{}/{}", - repo.owner, repo.name, repo.branch, readme_path - )), - installed: false, - repo_owner: Some(repo.owner.clone()), - repo_name: Some(repo.name.clone()), - repo_branch: Some(repo.branch.clone()), - skills_path: repo.skills_path.clone(), - }); - } - Err(e) => log::warn!("解析 {} 元数据失败: {}", skill_md.display(), e), - } - } + // 递归扫描目录查找所有技能 + self.scan_dir_recursive(&scan_dir, &scan_dir, repo, &mut skills)?; // 清理临时目录 let _ = fs::remove_dir_all(&temp_dir); @@ -271,6 +223,90 @@ impl SkillService { Ok(skills) } + /// 递归扫描目录查找 SKILL.md + /// + /// 规则: + /// 1. 如果当前目录存在 SKILL.md,则识别为技能,停止扫描其子目录(子目录视为功能文件夹) + /// 2. 如果当前目录不存在 SKILL.md,则递归扫描所有子目录 + fn scan_dir_recursive( + &self, + current_dir: &Path, + base_dir: &Path, + repo: &SkillRepo, + skills: &mut Vec, + ) -> Result<()> { + // 检查当前目录是否包含 SKILL.md + let skill_md = current_dir.join("SKILL.md"); + + if skill_md.exists() { + // 发现技能!获取相对路径作为目录名 + let directory = if current_dir == base_dir { + // 根目录的 SKILL.md,使用仓库名 + repo.name.clone() + } else { + // 子目录的 SKILL.md,使用相对路径 + current_dir + .strip_prefix(base_dir) + .unwrap_or(current_dir) + .to_string_lossy() + .to_string() + }; + + if let Ok(skill) = self.build_skill_from_metadata(&skill_md, &directory, repo) { + skills.push(skill); + } + + // 停止扫描此目录的子目录(同级目录都是功能文件夹) + return Ok(()); + } + + // 未发现 SKILL.md,继续递归扫描所有子目录 + for entry in fs::read_dir(current_dir)? { + let entry = entry?; + let path = entry.path(); + + // 只处理目录 + if path.is_dir() { + self.scan_dir_recursive(&path, base_dir, repo, skills)?; + } + } + + Ok(()) + } + + /// 从 SKILL.md 构建技能对象 + fn build_skill_from_metadata( + &self, + skill_md: &Path, + directory: &str, + repo: &SkillRepo, + ) -> Result { + let meta = self.parse_skill_metadata(skill_md)?; + + // 构建 README URL + let readme_path = if let Some(ref skills_path) = repo.skills_path { + format!("{}/{}", skills_path.trim_matches('/'), directory) + } else { + directory.to_string() + }; + + Ok(Skill { + key: format!("{}/{}:{}", repo.owner, repo.name, directory), + name: meta.name.unwrap_or_else(|| directory.to_string()), + description: meta.description.unwrap_or_default(), + directory: directory.to_string(), + readme_url: Some(format!( + "https://github.com/{}/{}/tree/{}/{}", + repo.owner, repo.name, repo.branch, readme_path + )), + installed: false, + repo_owner: Some(repo.owner.clone()), + repo_name: Some(repo.name.clone()), + repo_branch: Some(repo.branch.clone()), + skills_path: repo.skills_path.clone(), + }) + } + /// 解析技能元数据 fn parse_skill_metadata(&self, path: &Path) -> Result { let content = fs::read_to_string(path)?; @@ -302,25 +338,18 @@ impl SkillService { return Ok(()); } - for entry in fs::read_dir(&self.install_dir)? { - let entry = entry?; - let path = entry.path(); + // 收集所有本地技能 + let mut local_skills = Vec::new(); + self.scan_local_dir_recursive(&self.install_dir, &self.install_dir, &mut local_skills)?; - if !path.is_dir() { - continue; - } + // 处理找到的本地技能 + for local_skill in local_skills { + let directory = &local_skill.directory; - // 安全地获取目录名 - let Some(dir_name) = path.file_name() else { - log::warn!("Failed to get directory name from path: {path:?}"); - continue; - }; - let directory = dir_name.to_string_lossy().to_string(); - - // 更新已安装状态 + // 更新已安装状态(匹配远程技能) let mut found = false; for skill in skills.iter_mut() { - if skill.directory.eq_ignore_ascii_case(&directory) { + if skill.directory.eq_ignore_ascii_case(directory) { skill.installed = true; found = true; break; @@ -329,23 +358,69 @@ impl SkillService { // 添加本地独有的技能(仅当在仓库中未找到时) if !found { - let skill_md = path.join("SKILL.md"); - if skill_md.exists() { - if let Ok(meta) = self.parse_skill_metadata(&skill_md) { - skills.push(Skill { - key: format!("local:{directory}"), - name: meta.name.unwrap_or_else(|| directory.clone()), - description: meta.description.unwrap_or_default(), - directory: directory.clone(), - readme_url: None, - installed: true, - repo_owner: None, - repo_name: None, - repo_branch: None, - skills_path: None, - }); - } - } + skills.push(local_skill); + } + } + + Ok(()) + } + + /// 递归扫描本地目录查找 SKILL.md + fn scan_local_dir_recursive( + &self, + current_dir: &Path, + base_dir: &Path, + skills: &mut Vec, + ) -> Result<()> { + // 检查当前目录是否包含 SKILL.md + let skill_md = current_dir.join("SKILL.md"); + + if skill_md.exists() { + // 发现技能!获取相对路径作为目录名 + let directory = if current_dir == base_dir { + // 如果是 install_dir 本身,使用最后一段路径名 + current_dir + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string() + } else { + // 使用相对于 install_dir 的路径 + current_dir + .strip_prefix(base_dir) + .unwrap_or(current_dir) + .to_string_lossy() + .to_string() + }; + + // 解析元数据并创建本地技能对象 + if let Ok(meta) = self.parse_skill_metadata(&skill_md) { + skills.push(Skill { + key: format!("local:{directory}"), + name: meta.name.unwrap_or_else(|| directory.clone()), + description: meta.description.unwrap_or_default(), + directory: directory.clone(), + readme_url: None, + installed: true, + repo_owner: None, + repo_name: None, + repo_branch: None, + skills_path: None, + }); + } + + // 停止扫描此目录的子目录(同级目录都是功能文件夹) + return Ok(()); + } + + // 未发现 SKILL.md,继续递归扫描所有子目录 + for entry in fs::read_dir(current_dir)? { + let entry = entry?; + let path = entry.path(); + + // 只处理目录 + if path.is_dir() { + self.scan_local_dir_recursive(&path, base_dir, skills)?; } } diff --git a/src-tauri/tests/import_export_sync.rs b/src-tauri/tests/import_export_sync.rs index 7edf7378..9a853aeb 100644 --- a/src-tauri/tests/import_export_sync.rs +++ b/src-tauri/tests/import_export_sync.rs @@ -3,13 +3,15 @@ use std::fs; use std::path::PathBuf; use cc_switch_lib::{ - get_claude_settings_path, read_json_file, AppError, AppType, ConfigService, - MultiAppConfig, Provider, ProviderMeta, + get_claude_settings_path, read_json_file, AppError, AppType, ConfigService, MultiAppConfig, + Provider, ProviderMeta, }; #[path = "support.rs"] mod support; -use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex}; +use support::{ + create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex, +}; #[test] fn sync_claude_provider_writes_live_settings() { diff --git a/src-tauri/tests/mcp_commands.rs b/src-tauri/tests/mcp_commands.rs index 52158362..2474d03e 100644 --- a/src-tauri/tests/mcp_commands.rs +++ b/src-tauri/tests/mcp_commands.rs @@ -42,9 +42,13 @@ fn import_default_config_claude_persists_provider() { .expect("import default config succeeds"); // 验证内存状态 - let providers = state.db.get_all_providers(AppType::Claude.as_str()) + let providers = state + .db + .get_all_providers(AppType::Claude.as_str()) .expect("get all providers"); - let current_id = state.db.get_current_provider(AppType::Claude.as_str()) + let current_id = state + .db + .get_current_provider(AppType::Claude.as_str()) .expect("get current provider"); assert_eq!(current_id.as_deref(), Some("default")); let default_provider = providers.get("default").expect("default provider"); @@ -87,7 +91,9 @@ fn import_default_config_without_live_file_returns_error() { // 使用数据库架构,不再检查 config.json // 失败的导入不应该向数据库写入任何供应商 - let providers = state.db.get_all_providers(AppType::Claude.as_str()) + let providers = state + .db + .get_all_providers(AppType::Claude.as_str()) .expect("get all providers"); assert!( providers.is_empty(), @@ -125,8 +131,7 @@ fn import_mcp_from_claude_creates_config_and_enables_servers() { "import should report inserted or normalized entries" ); - let servers = state.db.get_all_mcp_servers() - .expect("get all mcp servers"); + let servers = state.db.get_all_mcp_servers().expect("get all mcp servers"); let entry = servers .get("echo") .expect("server imported into unified structure"); @@ -168,8 +173,7 @@ fn import_mcp_from_claude_invalid_json_preserves_state() { } // 使用数据库架构,检查 MCP 服务器未被写入 - let servers = state.db.get_all_mcp_servers() - .expect("get all mcp servers"); + let servers = state.db.get_all_mcp_servers().expect("get all mcp servers"); assert!( servers.is_empty(), "failed import should not persist any MCP servers to database" @@ -224,11 +228,8 @@ fn set_mcp_enabled_for_codex_writes_live_config() { McpService::toggle_app(&state, "codex-server", AppType::Codex, true) .expect("toggle_app should succeed"); - let servers = state.db.get_all_mcp_servers() - .expect("get all mcp servers"); - let entry = servers - .get("codex-server") - .expect("codex server exists"); + let servers = state.db.get_all_mcp_servers().expect("get all mcp servers"); + let entry = servers.get("codex-server").expect("codex server exists"); assert!( entry.apps.codex, "server should have Codex app enabled after toggle" diff --git a/src-tauri/tests/provider_commands.rs b/src-tauri/tests/provider_commands.rs index dec36c81..c0828f03 100644 --- a/src-tauri/tests/provider_commands.rs +++ b/src-tauri/tests/provider_commands.rs @@ -7,8 +7,8 @@ use cc_switch_lib::{ #[path = "support.rs"] mod support; -use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex}; use std::collections::HashMap; +use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex}; #[test] fn switch_provider_updates_codex_live_and_state() { @@ -104,16 +104,22 @@ command = "say" "config.toml should contain synced MCP servers" ); - let current_id = app_state.db.get_current_provider(AppType::Codex.as_str()) + let current_id = app_state + .db + .get_current_provider(AppType::Codex.as_str()) .expect("get current provider"); - assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated"); + assert_eq!( + current_id.as_deref(), + Some("new-provider"), + "current provider updated" + ); - let providers = app_state.db.get_all_providers(AppType::Codex.as_str()) + let providers = app_state + .db + .get_all_providers(AppType::Codex.as_str()) .expect("get all providers"); - let new_provider = providers - .get("new-provider") - .expect("new provider exists"); + let new_provider = providers.get("new-provider").expect("new provider exists"); let new_config_text = new_provider .settings_config .get("config") @@ -165,7 +171,9 @@ fn switch_provider_missing_provider_returns_error() { let err_str = err.to_string(); assert!( - err_str.contains("供应商不存在") || err_str.contains("Provider not found") || err_str.contains("missing-provider"), + err_str.contains("供应商不存在") + || err_str.contains("Provider not found") + || err_str.contains("missing-provider"), "error message should mention missing provider, got: {err_str}" ); } @@ -241,11 +249,19 @@ fn switch_provider_updates_claude_live_and_state() { "live settings.json should reflect new provider auth" ); - let current_id = app_state.db.get_current_provider(AppType::Claude.as_str()) + let current_id = app_state + .db + .get_current_provider(AppType::Claude.as_str()) .expect("get current provider"); - assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated"); + assert_eq!( + current_id.as_deref(), + Some("new-provider"), + "current provider updated" + ); - let providers = app_state.db.get_all_providers(AppType::Claude.as_str()) + let providers = app_state + .db + .get_all_providers(AppType::Claude.as_str()) .expect("get all providers"); let legacy_provider = providers @@ -258,9 +274,7 @@ fn switch_provider_updates_claude_live_and_state() { "previous provider should be backfilled with live config" ); - let new_provider = providers - .get("new-provider") - .expect("new provider exists"); + let new_provider = providers.get("new-provider").expect("new provider exists"); assert_eq!( new_provider .settings_config @@ -283,7 +297,9 @@ fn switch_provider_updates_claude_live_and_state() { ); // 验证当前供应商已更新 - let current_id = app_state.db.get_current_provider(AppType::Claude.as_str()) + let current_id = app_state + .db + .get_current_provider(AppType::Claude.as_str()) .expect("get current provider"); assert_eq!( current_id.as_deref(), @@ -328,7 +344,9 @@ fn switch_provider_codex_missing_auth_returns_error_and_keeps_state() { other => panic!("expected config error, got {other:?}"), } - let current_id = app_state.db.get_current_provider(AppType::Codex.as_str()) + let current_id = app_state + .db + .get_current_provider(AppType::Codex.as_str()) .expect("get current provider"); // 切换失败后,由于数据库操作是先设置再验证,current 可能已被设为 "invalid" // 但由于 live 配置写入失败,状态应该回滚 diff --git a/src-tauri/tests/provider_service.rs b/src-tauri/tests/provider_service.rs index 74534e17..54f0700a 100644 --- a/src-tauri/tests/provider_service.rs +++ b/src-tauri/tests/provider_service.rs @@ -1,13 +1,15 @@ use serde_json::json; use cc_switch_lib::{ - get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, - McpApps, McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService, + get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, McpApps, + McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService, }; #[path = "support.rs"] mod support; -use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex}; +use support::{ + create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex, +}; fn sanitize_provider_name(name: &str) -> String { name.chars() @@ -69,7 +71,10 @@ command = "say" } // 使用新的统一 MCP 结构(v3.7.0+) - let servers = initial_config.mcp.servers.get_or_insert_with(Default::default); + let servers = initial_config + .mcp + .servers + .get_or_insert_with(Default::default); servers.insert( "echo-server".into(), McpServer { @@ -111,16 +116,22 @@ command = "say" "config.toml should contain synced MCP servers" ); - let current_id = state.db.get_current_provider(AppType::Codex.as_str()) + let current_id = state + .db + .get_current_provider(AppType::Codex.as_str()) .expect("read current provider after switch"); - assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated"); + assert_eq!( + current_id.as_deref(), + Some("new-provider"), + "current provider updated" + ); - let providers = state.db.get_all_providers(AppType::Codex.as_str()) + let providers = state + .db + .get_all_providers(AppType::Codex.as_str()) .expect("read providers after switch"); - let new_provider = providers - .get("new-provider") - .expect("new provider exists"); + let new_provider = providers.get("new-provider").expect("new provider exists"); let new_config_text = new_provider .settings_config .get("config") @@ -385,11 +396,19 @@ fn provider_service_switch_claude_updates_live_and_state() { "live settings.json should reflect new provider auth" ); - let providers = state.db.get_all_providers(AppType::Claude.as_str()) + let providers = state + .db + .get_all_providers(AppType::Claude.as_str()) .expect("get all providers"); - let current_id = state.db.get_current_provider(AppType::Claude.as_str()) + let current_id = state + .db + .get_current_provider(AppType::Claude.as_str()) .expect("get current provider"); - assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated"); + assert_eq!( + current_id.as_deref(), + Some("new-provider"), + "current provider updated" + ); let legacy_provider = providers .get("old-provider") @@ -509,7 +528,9 @@ fn provider_service_delete_codex_removes_provider_and_files() { ProviderService::delete(&app_state, AppType::Codex, "to-delete") .expect("delete provider should succeed"); - let providers = app_state.db.get_all_providers(AppType::Codex.as_str()) + let providers = app_state + .db + .get_all_providers(AppType::Codex.as_str()) .expect("get all providers"); assert!( !providers.contains_key("to-delete"), @@ -567,7 +588,9 @@ fn provider_service_delete_claude_removes_provider_files() { ProviderService::delete(&app_state, AppType::Claude, "delete").expect("delete claude provider"); - let providers = app_state.db.get_all_providers(AppType::Claude.as_str()) + let providers = app_state + .db + .get_all_providers(AppType::Claude.as_str()) .expect("get all providers"); assert!( !providers.contains_key("delete"), @@ -608,15 +631,18 @@ fn provider_service_delete_current_provider_returns_error() { .expect_err("deleting current provider should fail"); match err { AppError::Localized { zh, .. } => assert!( - zh.contains("不能删除当前正在使用的供应商") || zh.contains("无法删除当前正在使用的供应商"), + zh.contains("不能删除当前正在使用的供应商") + || zh.contains("无法删除当前正在使用的供应商"), "unexpected message: {zh}" ), AppError::Config(msg) => assert!( - msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"), + msg.contains("不能删除当前正在使用的供应商") + || msg.contains("无法删除当前正在使用的供应商"), "unexpected message: {msg}" ), AppError::Message(msg) => assert!( - msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"), + msg.contains("不能删除当前正在使用的供应商") + || msg.contains("无法删除当前正在使用的供应商"), "unexpected message: {msg}" ), other => panic!("expected Config/Message error, got {other:?}"),