Compare commits

...

1 Commits

Author SHA1 Message Date
YoVinchen
38bac36960 feat(skill): implement recursive scanning for skill repositories
Add recursive directory scanning to discover SKILL.md files in nested
directories. When a SKILL.md is found, treat sibling directories as
functional folders rather than separate skills.
2025-11-28 11:43:31 +08:00
12 changed files with 300 additions and 190 deletions

View File

@@ -153,12 +153,14 @@ impl Database {
tx: &rusqlite::Transaction<'_>,
config: &MultiAppConfig,
) -> Result<(), AppError> {
let migrate_app_prompts =
|prompts_map: &std::collections::HashMap<String, crate::prompt::Prompt>,
app_type: &str|
-> Result<(), AppError> {
for (id, prompt) in prompts_map {
tx.execute(
let migrate_app_prompts = |prompts_map: &std::collections::HashMap<
String,
crate::prompt::Prompt,
>,
app_type: &str|
-> Result<(), AppError> {
for (id, prompt) in prompts_map {
tx.execute(
"INSERT OR REPLACE INTO prompts (
id, app_type, name, content, description, enabled, created_at, updated_at
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
@@ -174,9 +176,9 @@ impl Database {
],
)
.map_err(|e| AppError::Database(format!("Migrate prompt failed: {e}")))?;
}
Ok(())
};
}
Ok(())
};
migrate_app_prompts(&config.prompts.claude.prompts, "claude")?;
migrate_app_prompts(&config.prompts.codex.prompts, "codex")?;

View File

@@ -226,13 +226,13 @@ impl Database {
Self::add_column_if_missing(conn, "skills", "installed_at", "INTEGER NOT NULL DEFAULT 0")?;
// skill_repos 表
Self::add_column_if_missing(conn, "skill_repos", "branch", "TEXT NOT NULL DEFAULT 'main'")?;
Self::add_column_if_missing(
conn,
"skill_repos",
"enabled",
"BOOLEAN NOT NULL DEFAULT 1",
"branch",
"TEXT NOT NULL DEFAULT 'main'",
)?;
Self::add_column_if_missing(conn, "skill_repos", "enabled", "BOOLEAN NOT NULL DEFAULT 1")?;
Self::add_column_if_missing(conn, "skill_repos", "skills_path", "TEXT")?;
Ok(())
@@ -247,9 +247,7 @@ impl Database {
pub(crate) fn set_user_version(conn: &Connection, version: i32) -> Result<(), AppError> {
if version < 0 {
return Err(AppError::Database(
"user_version 不能为负数".to_string(),
));
return Err(AppError::Database("user_version 不能为负数".to_string()));
}
let sql = format!("PRAGMA user_version = {version};");
conn.execute(&sql, [])
@@ -261,10 +259,7 @@ impl Database {
if s.is_empty() {
return Err(AppError::Database(format!("{kind} 不能为空")));
}
if !s
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_')
{
if !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
return Err(AppError::Database(format!(
"非法{kind}: {s},仅允许字母、数字和下划线"
)));
@@ -292,7 +287,11 @@ impl Database {
Ok(false)
}
pub(crate) fn has_column(conn: &Connection, table: &str, column: &str) -> Result<bool, AppError> {
pub(crate) fn has_column(
conn: &Connection,
table: &str,
column: &str,
) -> Result<bool, AppError> {
Self::validate_identifier(table, "表名")?;
Self::validate_identifier(column, "列名")?;

View File

@@ -108,8 +108,8 @@ fn migration_rejects_future_version() {
Database::create_tables_on_conn(&conn).expect("create tables");
Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version");
let err = Database::apply_schema_migrations_on_conn(&conn)
.expect_err("should reject higher version");
let err =
Database::apply_schema_migrations_on_conn(&conn).expect_err("should reject higher version");
assert!(
err.to_string().contains("数据库版本过新"),
"unexpected error: {err}"
@@ -168,10 +168,7 @@ fn migration_aligns_column_defaults_and_types() {
let is_current = get_column_info(&conn, "providers", "is_current");
assert_eq!(is_current.r#type, "BOOLEAN");
assert_eq!(is_current.notnull, 1);
assert_eq!(
normalize_default(&is_current.default).as_deref(),
Some("0")
);
assert_eq!(normalize_default(&is_current.default).as_deref(), Some("0"));
let tags = get_column_info(&conn, "mcp_servers", "tags");
assert_eq!(tags.r#type, "TEXT");
@@ -181,10 +178,7 @@ fn migration_aligns_column_defaults_and_types() {
let enabled = get_column_info(&conn, "prompts", "enabled");
assert_eq!(enabled.r#type, "BOOLEAN");
assert_eq!(enabled.notnull, 1);
assert_eq!(
normalize_default(&enabled.default).as_deref(),
Some("1")
);
assert_eq!(normalize_default(&enabled.default).as_deref(), Some("1"));
let installed_at = get_column_info(&conn, "skills", "installed_at");
assert_eq!(installed_at.r#type, "INTEGER");

View File

@@ -307,8 +307,7 @@ fn test_import_prompt_allows_space_in_base64_content() {
let db = Arc::new(Database::memory().expect("create memory db"));
let state = AppState::new(db.clone());
let prompt_id =
import_prompt_from_deeplink(&state, request.clone()).expect("import prompt");
let prompt_id = import_prompt_from_deeplink(&state, request.clone()).expect("import prompt");
let prompts = state.db.get_prompts("codex").expect("get prompts");
let prompt = prompts.get(&prompt_id).expect("prompt saved");

View File

@@ -223,11 +223,9 @@ fn create_tray_menu(
let providers = app_state.db.get_all_providers(app_type_str)?;
// 使用有效的当前供应商 ID验证存在性自动清理失效 ID
let current_id = crate::settings::get_effective_current_provider(
&app_state.db,
&section.app_type,
)?
.unwrap_or_default();
let current_id =
crate::settings::get_effective_current_provider(&app_state.db, &section.app_type)?
.unwrap_or_default();
let manager = crate::provider::ProviderManager {
providers,
@@ -1033,21 +1031,19 @@ fn show_migration_error_dialog(app: &tauri::AppHandle, error: &str) -> bool {
let message = if is_chinese_locale() {
format!(
"从旧版本迁移配置时发生错误:\n\n{}\n\n\
"从旧版本迁移配置时发生错误:\n\n{error}\n\n\
您的数据尚未丢失,旧配置文件仍然保留。\n\
建议回退到旧版本 CC Switch 以保护数据。\n\n\
点击「重试」重新尝试迁移\n\
点击「退出」关闭程序(可回退版本后重新打开)",
error
点击「退出」关闭程序(可回退版本后重新打开)"
)
} else {
format!(
"An error occurred while migrating configuration:\n\n{}\n\n\
"An error occurred while migrating configuration:\n\n{error}\n\n\
Your data is NOT lost - the old config file is still preserved.\n\
Consider rolling back to an older CC Switch version.\n\n\
Click 'Retry' to attempt migration again\n\
Click 'Exit' to close the program",
error
Click 'Exit' to close the program"
)
};

View File

@@ -100,12 +100,13 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
write_json_file(&path, &provider.settings_config)?;
}
AppType::Codex => {
let obj = provider.settings_config.as_object().ok_or_else(|| {
AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string())
})?;
let auth = obj.get("auth").ok_or_else(|| {
AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string())
})?;
let obj = provider
.settings_config
.as_object()
.ok_or_else(|| AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string()))?;
let auth = obj
.get("auth")
.ok_or_else(|| AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string()))?;
let config_str = obj.get("config").and_then(|v| v.as_str()).ok_or_else(|| {
AppError::Config("Codex 供应商配置缺少 'config' 字段或不是字符串".to_string())
})?;
@@ -113,8 +114,7 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
let auth_path = get_codex_auth_path();
write_json_file(&auth_path, auth)?;
let config_path = get_codex_config_path();
std::fs::write(&config_path, config_str)
.map_err(|e| AppError::io(&config_path, e))?;
std::fs::write(&config_path, config_str).map_err(|e| AppError::io(&config_path, e))?;
}
AppType::Gemini => {
// Delegate to write_gemini_live which handles env file writing correctly
@@ -132,11 +132,11 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
pub fn sync_current_to_live(state: &AppState) -> Result<(), AppError> {
for app_type in [AppType::Claude, AppType::Codex, AppType::Gemini] {
// Use validated effective current provider
let current_id = match crate::settings::get_effective_current_provider(&state.db, &app_type)?
{
Some(id) => id,
None => continue,
};
let current_id =
match crate::settings::get_effective_current_provider(&state.db, &app_type)? {
Some(id) => id,
None => continue,
};
let providers = state.db.get_all_providers(app_type.as_str())?;
if let Some(provider) = providers.get(&current_id) {

View File

@@ -168,9 +168,7 @@ pub(crate) fn validate_usage_script(script: &UsageScript) -> Result<(), AppError
if interval > 1440 {
return Err(AppError::localized(
"usage_script.interval_too_large",
format!(
"自动查询间隔不能超过 1440 分钟24小时当前值: {interval}"
),
format!("自动查询间隔不能超过 1440 分钟24小时当前值: {interval}"),
format!(
"Auto query interval cannot exceed 1440 minutes (24 hours), current: {interval}"
),

View File

@@ -214,56 +214,8 @@ impl SkillService {
temp_dir.clone()
};
// 遍历目标目录
for entry in fs::read_dir(&scan_dir)? {
let entry = entry?;
let path = entry.path();
if !path.is_dir() {
continue;
}
let skill_md = path.join("SKILL.md");
if !skill_md.exists() {
continue;
}
// 解析技能元数据
match self.parse_skill_metadata(&skill_md) {
Ok(meta) => {
// 安全地获取目录名
let Some(dir_name) = path.file_name() else {
log::warn!("Failed to get directory name from path: {path:?}");
continue;
};
let directory = dir_name.to_string_lossy().to_string();
// 构建 README URL考虑 skillsPath
let readme_path = if let Some(ref skills_path) = repo.skills_path {
format!("{}/{}", skills_path.trim_matches('/'), directory)
} else {
directory.clone()
};
skills.push(Skill {
key: format!("{}/{}:{}", repo.owner, repo.name, directory),
name: meta.name.unwrap_or_else(|| directory.clone()),
description: meta.description.unwrap_or_default(),
directory,
readme_url: Some(format!(
"https://github.com/{}/{}/tree/{}/{}",
repo.owner, repo.name, repo.branch, readme_path
)),
installed: false,
repo_owner: Some(repo.owner.clone()),
repo_name: Some(repo.name.clone()),
repo_branch: Some(repo.branch.clone()),
skills_path: repo.skills_path.clone(),
});
}
Err(e) => log::warn!("解析 {} 元数据失败: {}", skill_md.display(), e),
}
}
// 递归扫描目录查找所有技能
self.scan_dir_recursive(&scan_dir, &scan_dir, repo, &mut skills)?;
// 清理临时目录
let _ = fs::remove_dir_all(&temp_dir);
@@ -271,6 +223,90 @@ impl SkillService {
Ok(skills)
}
/// 递归扫描目录查找 SKILL.md
///
/// 规则:
/// 1. 如果当前目录存在 SKILL.md则识别为技能停止扫描其子目录子目录视为功能文件夹
/// 2. 如果当前目录不存在 SKILL.md则递归扫描所有子目录
fn scan_dir_recursive(
&self,
current_dir: &Path,
base_dir: &Path,
repo: &SkillRepo,
skills: &mut Vec<Skill>,
) -> Result<()> {
// 检查当前目录是否包含 SKILL.md
let skill_md = current_dir.join("SKILL.md");
if skill_md.exists() {
// 发现技能!获取相对路径作为目录名
let directory = if current_dir == base_dir {
// 根目录的 SKILL.md使用仓库名
repo.name.clone()
} else {
// 子目录的 SKILL.md使用相对路径
current_dir
.strip_prefix(base_dir)
.unwrap_or(current_dir)
.to_string_lossy()
.to_string()
};
if let Ok(skill) = self.build_skill_from_metadata(&skill_md, &directory, repo) {
skills.push(skill);
}
// 停止扫描此目录的子目录(同级目录都是功能文件夹)
return Ok(());
}
// 未发现 SKILL.md继续递归扫描所有子目录
for entry in fs::read_dir(current_dir)? {
let entry = entry?;
let path = entry.path();
// 只处理目录
if path.is_dir() {
self.scan_dir_recursive(&path, base_dir, repo, skills)?;
}
}
Ok(())
}
/// 从 SKILL.md 构建技能对象
fn build_skill_from_metadata(
&self,
skill_md: &Path,
directory: &str,
repo: &SkillRepo,
) -> Result<Skill> {
let meta = self.parse_skill_metadata(skill_md)?;
// 构建 README URL
let readme_path = if let Some(ref skills_path) = repo.skills_path {
format!("{}/{}", skills_path.trim_matches('/'), directory)
} else {
directory.to_string()
};
Ok(Skill {
key: format!("{}/{}:{}", repo.owner, repo.name, directory),
name: meta.name.unwrap_or_else(|| directory.to_string()),
description: meta.description.unwrap_or_default(),
directory: directory.to_string(),
readme_url: Some(format!(
"https://github.com/{}/{}/tree/{}/{}",
repo.owner, repo.name, repo.branch, readme_path
)),
installed: false,
repo_owner: Some(repo.owner.clone()),
repo_name: Some(repo.name.clone()),
repo_branch: Some(repo.branch.clone()),
skills_path: repo.skills_path.clone(),
})
}
/// 解析技能元数据
fn parse_skill_metadata(&self, path: &Path) -> Result<SkillMetadata> {
let content = fs::read_to_string(path)?;
@@ -302,25 +338,18 @@ impl SkillService {
return Ok(());
}
for entry in fs::read_dir(&self.install_dir)? {
let entry = entry?;
let path = entry.path();
// 收集所有本地技能
let mut local_skills = Vec::new();
self.scan_local_dir_recursive(&self.install_dir, &self.install_dir, &mut local_skills)?;
if !path.is_dir() {
continue;
}
// 处理找到的本地技能
for local_skill in local_skills {
let directory = &local_skill.directory;
// 安全地获取目录名
let Some(dir_name) = path.file_name() else {
log::warn!("Failed to get directory name from path: {path:?}");
continue;
};
let directory = dir_name.to_string_lossy().to_string();
// 更新已安装状态
// 更新已安装状态(匹配远程技能)
let mut found = false;
for skill in skills.iter_mut() {
if skill.directory.eq_ignore_ascii_case(&directory) {
if skill.directory.eq_ignore_ascii_case(directory) {
skill.installed = true;
found = true;
break;
@@ -329,23 +358,69 @@ impl SkillService {
// 添加本地独有的技能(仅当在仓库中未找到时)
if !found {
let skill_md = path.join("SKILL.md");
if skill_md.exists() {
if let Ok(meta) = self.parse_skill_metadata(&skill_md) {
skills.push(Skill {
key: format!("local:{directory}"),
name: meta.name.unwrap_or_else(|| directory.clone()),
description: meta.description.unwrap_or_default(),
directory: directory.clone(),
readme_url: None,
installed: true,
repo_owner: None,
repo_name: None,
repo_branch: None,
skills_path: None,
});
}
}
skills.push(local_skill);
}
}
Ok(())
}
/// 递归扫描本地目录查找 SKILL.md
fn scan_local_dir_recursive(
&self,
current_dir: &Path,
base_dir: &Path,
skills: &mut Vec<Skill>,
) -> Result<()> {
// 检查当前目录是否包含 SKILL.md
let skill_md = current_dir.join("SKILL.md");
if skill_md.exists() {
// 发现技能!获取相对路径作为目录名
let directory = if current_dir == base_dir {
// 如果是 install_dir 本身,使用最后一段路径名
current_dir
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
} else {
// 使用相对于 install_dir 的路径
current_dir
.strip_prefix(base_dir)
.unwrap_or(current_dir)
.to_string_lossy()
.to_string()
};
// 解析元数据并创建本地技能对象
if let Ok(meta) = self.parse_skill_metadata(&skill_md) {
skills.push(Skill {
key: format!("local:{directory}"),
name: meta.name.unwrap_or_else(|| directory.clone()),
description: meta.description.unwrap_or_default(),
directory: directory.clone(),
readme_url: None,
installed: true,
repo_owner: None,
repo_name: None,
repo_branch: None,
skills_path: None,
});
}
// 停止扫描此目录的子目录(同级目录都是功能文件夹)
return Ok(());
}
// 未发现 SKILL.md继续递归扫描所有子目录
for entry in fs::read_dir(current_dir)? {
let entry = entry?;
let path = entry.path();
// 只处理目录
if path.is_dir() {
self.scan_local_dir_recursive(&path, base_dir, skills)?;
}
}

View File

@@ -3,13 +3,15 @@ use std::fs;
use std::path::PathBuf;
use cc_switch_lib::{
get_claude_settings_path, read_json_file, AppError, AppType, ConfigService,
MultiAppConfig, Provider, ProviderMeta,
get_claude_settings_path, read_json_file, AppError, AppType, ConfigService, MultiAppConfig,
Provider, ProviderMeta,
};
#[path = "support.rs"]
mod support;
use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
use support::{
create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex,
};
#[test]
fn sync_claude_provider_writes_live_settings() {

View File

@@ -42,9 +42,13 @@ fn import_default_config_claude_persists_provider() {
.expect("import default config succeeds");
// 验证内存状态
let providers = state.db.get_all_providers(AppType::Claude.as_str())
let providers = state
.db
.get_all_providers(AppType::Claude.as_str())
.expect("get all providers");
let current_id = state.db.get_current_provider(AppType::Claude.as_str())
let current_id = state
.db
.get_current_provider(AppType::Claude.as_str())
.expect("get current provider");
assert_eq!(current_id.as_deref(), Some("default"));
let default_provider = providers.get("default").expect("default provider");
@@ -87,7 +91,9 @@ fn import_default_config_without_live_file_returns_error() {
// 使用数据库架构,不再检查 config.json
// 失败的导入不应该向数据库写入任何供应商
let providers = state.db.get_all_providers(AppType::Claude.as_str())
let providers = state
.db
.get_all_providers(AppType::Claude.as_str())
.expect("get all providers");
assert!(
providers.is_empty(),
@@ -125,8 +131,7 @@ fn import_mcp_from_claude_creates_config_and_enables_servers() {
"import should report inserted or normalized entries"
);
let servers = state.db.get_all_mcp_servers()
.expect("get all mcp servers");
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
let entry = servers
.get("echo")
.expect("server imported into unified structure");
@@ -168,8 +173,7 @@ fn import_mcp_from_claude_invalid_json_preserves_state() {
}
// 使用数据库架构,检查 MCP 服务器未被写入
let servers = state.db.get_all_mcp_servers()
.expect("get all mcp servers");
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
assert!(
servers.is_empty(),
"failed import should not persist any MCP servers to database"
@@ -224,11 +228,8 @@ fn set_mcp_enabled_for_codex_writes_live_config() {
McpService::toggle_app(&state, "codex-server", AppType::Codex, true)
.expect("toggle_app should succeed");
let servers = state.db.get_all_mcp_servers()
.expect("get all mcp servers");
let entry = servers
.get("codex-server")
.expect("codex server exists");
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
let entry = servers.get("codex-server").expect("codex server exists");
assert!(
entry.apps.codex,
"server should have Codex app enabled after toggle"

View File

@@ -7,8 +7,8 @@ use cc_switch_lib::{
#[path = "support.rs"]
mod support;
use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
use std::collections::HashMap;
use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
#[test]
fn switch_provider_updates_codex_live_and_state() {
@@ -104,16 +104,22 @@ command = "say"
"config.toml should contain synced MCP servers"
);
let current_id = app_state.db.get_current_provider(AppType::Codex.as_str())
let current_id = app_state
.db
.get_current_provider(AppType::Codex.as_str())
.expect("get current provider");
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
assert_eq!(
current_id.as_deref(),
Some("new-provider"),
"current provider updated"
);
let providers = app_state.db.get_all_providers(AppType::Codex.as_str())
let providers = app_state
.db
.get_all_providers(AppType::Codex.as_str())
.expect("get all providers");
let new_provider = providers
.get("new-provider")
.expect("new provider exists");
let new_provider = providers.get("new-provider").expect("new provider exists");
let new_config_text = new_provider
.settings_config
.get("config")
@@ -165,7 +171,9 @@ fn switch_provider_missing_provider_returns_error() {
let err_str = err.to_string();
assert!(
err_str.contains("供应商不存在") || err_str.contains("Provider not found") || err_str.contains("missing-provider"),
err_str.contains("供应商不存在")
|| err_str.contains("Provider not found")
|| err_str.contains("missing-provider"),
"error message should mention missing provider, got: {err_str}"
);
}
@@ -241,11 +249,19 @@ fn switch_provider_updates_claude_live_and_state() {
"live settings.json should reflect new provider auth"
);
let current_id = app_state.db.get_current_provider(AppType::Claude.as_str())
let current_id = app_state
.db
.get_current_provider(AppType::Claude.as_str())
.expect("get current provider");
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
assert_eq!(
current_id.as_deref(),
Some("new-provider"),
"current provider updated"
);
let providers = app_state.db.get_all_providers(AppType::Claude.as_str())
let providers = app_state
.db
.get_all_providers(AppType::Claude.as_str())
.expect("get all providers");
let legacy_provider = providers
@@ -258,9 +274,7 @@ fn switch_provider_updates_claude_live_and_state() {
"previous provider should be backfilled with live config"
);
let new_provider = providers
.get("new-provider")
.expect("new provider exists");
let new_provider = providers.get("new-provider").expect("new provider exists");
assert_eq!(
new_provider
.settings_config
@@ -283,7 +297,9 @@ fn switch_provider_updates_claude_live_and_state() {
);
// 验证当前供应商已更新
let current_id = app_state.db.get_current_provider(AppType::Claude.as_str())
let current_id = app_state
.db
.get_current_provider(AppType::Claude.as_str())
.expect("get current provider");
assert_eq!(
current_id.as_deref(),
@@ -328,7 +344,9 @@ fn switch_provider_codex_missing_auth_returns_error_and_keeps_state() {
other => panic!("expected config error, got {other:?}"),
}
let current_id = app_state.db.get_current_provider(AppType::Codex.as_str())
let current_id = app_state
.db
.get_current_provider(AppType::Codex.as_str())
.expect("get current provider");
// 切换失败后由于数据库操作是先设置再验证current 可能已被设为 "invalid"
// 但由于 live 配置写入失败,状态应该回滚

View File

@@ -1,13 +1,15 @@
use serde_json::json;
use cc_switch_lib::{
get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType,
McpApps, McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService,
get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, McpApps,
McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService,
};
#[path = "support.rs"]
mod support;
use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
use support::{
create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex,
};
fn sanitize_provider_name(name: &str) -> String {
name.chars()
@@ -69,7 +71,10 @@ command = "say"
}
// 使用新的统一 MCP 结构v3.7.0+
let servers = initial_config.mcp.servers.get_or_insert_with(Default::default);
let servers = initial_config
.mcp
.servers
.get_or_insert_with(Default::default);
servers.insert(
"echo-server".into(),
McpServer {
@@ -111,16 +116,22 @@ command = "say"
"config.toml should contain synced MCP servers"
);
let current_id = state.db.get_current_provider(AppType::Codex.as_str())
let current_id = state
.db
.get_current_provider(AppType::Codex.as_str())
.expect("read current provider after switch");
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
assert_eq!(
current_id.as_deref(),
Some("new-provider"),
"current provider updated"
);
let providers = state.db.get_all_providers(AppType::Codex.as_str())
let providers = state
.db
.get_all_providers(AppType::Codex.as_str())
.expect("read providers after switch");
let new_provider = providers
.get("new-provider")
.expect("new provider exists");
let new_provider = providers.get("new-provider").expect("new provider exists");
let new_config_text = new_provider
.settings_config
.get("config")
@@ -385,11 +396,19 @@ fn provider_service_switch_claude_updates_live_and_state() {
"live settings.json should reflect new provider auth"
);
let providers = state.db.get_all_providers(AppType::Claude.as_str())
let providers = state
.db
.get_all_providers(AppType::Claude.as_str())
.expect("get all providers");
let current_id = state.db.get_current_provider(AppType::Claude.as_str())
let current_id = state
.db
.get_current_provider(AppType::Claude.as_str())
.expect("get current provider");
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
assert_eq!(
current_id.as_deref(),
Some("new-provider"),
"current provider updated"
);
let legacy_provider = providers
.get("old-provider")
@@ -509,7 +528,9 @@ fn provider_service_delete_codex_removes_provider_and_files() {
ProviderService::delete(&app_state, AppType::Codex, "to-delete")
.expect("delete provider should succeed");
let providers = app_state.db.get_all_providers(AppType::Codex.as_str())
let providers = app_state
.db
.get_all_providers(AppType::Codex.as_str())
.expect("get all providers");
assert!(
!providers.contains_key("to-delete"),
@@ -567,7 +588,9 @@ fn provider_service_delete_claude_removes_provider_files() {
ProviderService::delete(&app_state, AppType::Claude, "delete").expect("delete claude provider");
let providers = app_state.db.get_all_providers(AppType::Claude.as_str())
let providers = app_state
.db
.get_all_providers(AppType::Claude.as_str())
.expect("get all providers");
assert!(
!providers.contains_key("delete"),
@@ -608,15 +631,18 @@ fn provider_service_delete_current_provider_returns_error() {
.expect_err("deleting current provider should fail");
match err {
AppError::Localized { zh, .. } => assert!(
zh.contains("不能删除当前正在使用的供应商") || zh.contains("无法删除当前正在使用的供应商"),
zh.contains("不能删除当前正在使用的供应商")
|| zh.contains("无法删除当前正在使用的供应商"),
"unexpected message: {zh}"
),
AppError::Config(msg) => assert!(
msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"),
msg.contains("不能删除当前正在使用的供应商")
|| msg.contains("无法删除当前正在使用的供应商"),
"unexpected message: {msg}"
),
AppError::Message(msg) => assert!(
msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"),
msg.contains("不能删除当前正在使用的供应商")
|| msg.contains("无法删除当前正在使用的供应商"),
"unexpected message: {msg}"
),
other => panic!("expected Config/Message error, got {other:?}"),