mirror of
https://github.com/farion1231/cc-switch.git
synced 2026-04-24 10:12:46 +08:00
Compare commits
1 Commits
fix/openco
...
fix/recurs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38bac36960 |
@@ -153,12 +153,14 @@ impl Database {
|
||||
tx: &rusqlite::Transaction<'_>,
|
||||
config: &MultiAppConfig,
|
||||
) -> Result<(), AppError> {
|
||||
let migrate_app_prompts =
|
||||
|prompts_map: &std::collections::HashMap<String, crate::prompt::Prompt>,
|
||||
app_type: &str|
|
||||
-> Result<(), AppError> {
|
||||
for (id, prompt) in prompts_map {
|
||||
tx.execute(
|
||||
let migrate_app_prompts = |prompts_map: &std::collections::HashMap<
|
||||
String,
|
||||
crate::prompt::Prompt,
|
||||
>,
|
||||
app_type: &str|
|
||||
-> Result<(), AppError> {
|
||||
for (id, prompt) in prompts_map {
|
||||
tx.execute(
|
||||
"INSERT OR REPLACE INTO prompts (
|
||||
id, app_type, name, content, description, enabled, created_at, updated_at
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
||||
@@ -174,9 +176,9 @@ impl Database {
|
||||
],
|
||||
)
|
||||
.map_err(|e| AppError::Database(format!("Migrate prompt failed: {e}")))?;
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
|
||||
migrate_app_prompts(&config.prompts.claude.prompts, "claude")?;
|
||||
migrate_app_prompts(&config.prompts.codex.prompts, "codex")?;
|
||||
|
||||
@@ -226,13 +226,13 @@ impl Database {
|
||||
Self::add_column_if_missing(conn, "skills", "installed_at", "INTEGER NOT NULL DEFAULT 0")?;
|
||||
|
||||
// skill_repos 表
|
||||
Self::add_column_if_missing(conn, "skill_repos", "branch", "TEXT NOT NULL DEFAULT 'main'")?;
|
||||
Self::add_column_if_missing(
|
||||
conn,
|
||||
"skill_repos",
|
||||
"enabled",
|
||||
"BOOLEAN NOT NULL DEFAULT 1",
|
||||
"branch",
|
||||
"TEXT NOT NULL DEFAULT 'main'",
|
||||
)?;
|
||||
Self::add_column_if_missing(conn, "skill_repos", "enabled", "BOOLEAN NOT NULL DEFAULT 1")?;
|
||||
Self::add_column_if_missing(conn, "skill_repos", "skills_path", "TEXT")?;
|
||||
|
||||
Ok(())
|
||||
@@ -247,9 +247,7 @@ impl Database {
|
||||
|
||||
pub(crate) fn set_user_version(conn: &Connection, version: i32) -> Result<(), AppError> {
|
||||
if version < 0 {
|
||||
return Err(AppError::Database(
|
||||
"user_version 不能为负数".to_string(),
|
||||
));
|
||||
return Err(AppError::Database("user_version 不能为负数".to_string()));
|
||||
}
|
||||
let sql = format!("PRAGMA user_version = {version};");
|
||||
conn.execute(&sql, [])
|
||||
@@ -261,10 +259,7 @@ impl Database {
|
||||
if s.is_empty() {
|
||||
return Err(AppError::Database(format!("{kind} 不能为空")));
|
||||
}
|
||||
if !s
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_')
|
||||
{
|
||||
if !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
|
||||
return Err(AppError::Database(format!(
|
||||
"非法{kind}: {s},仅允许字母、数字和下划线"
|
||||
)));
|
||||
@@ -292,7 +287,11 @@ impl Database {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
pub(crate) fn has_column(conn: &Connection, table: &str, column: &str) -> Result<bool, AppError> {
|
||||
pub(crate) fn has_column(
|
||||
conn: &Connection,
|
||||
table: &str,
|
||||
column: &str,
|
||||
) -> Result<bool, AppError> {
|
||||
Self::validate_identifier(table, "表名")?;
|
||||
Self::validate_identifier(column, "列名")?;
|
||||
|
||||
|
||||
@@ -108,8 +108,8 @@ fn migration_rejects_future_version() {
|
||||
Database::create_tables_on_conn(&conn).expect("create tables");
|
||||
Database::set_user_version(&conn, SCHEMA_VERSION + 1).expect("set future version");
|
||||
|
||||
let err = Database::apply_schema_migrations_on_conn(&conn)
|
||||
.expect_err("should reject higher version");
|
||||
let err =
|
||||
Database::apply_schema_migrations_on_conn(&conn).expect_err("should reject higher version");
|
||||
assert!(
|
||||
err.to_string().contains("数据库版本过新"),
|
||||
"unexpected error: {err}"
|
||||
@@ -168,10 +168,7 @@ fn migration_aligns_column_defaults_and_types() {
|
||||
let is_current = get_column_info(&conn, "providers", "is_current");
|
||||
assert_eq!(is_current.r#type, "BOOLEAN");
|
||||
assert_eq!(is_current.notnull, 1);
|
||||
assert_eq!(
|
||||
normalize_default(&is_current.default).as_deref(),
|
||||
Some("0")
|
||||
);
|
||||
assert_eq!(normalize_default(&is_current.default).as_deref(), Some("0"));
|
||||
|
||||
let tags = get_column_info(&conn, "mcp_servers", "tags");
|
||||
assert_eq!(tags.r#type, "TEXT");
|
||||
@@ -181,10 +178,7 @@ fn migration_aligns_column_defaults_and_types() {
|
||||
let enabled = get_column_info(&conn, "prompts", "enabled");
|
||||
assert_eq!(enabled.r#type, "BOOLEAN");
|
||||
assert_eq!(enabled.notnull, 1);
|
||||
assert_eq!(
|
||||
normalize_default(&enabled.default).as_deref(),
|
||||
Some("1")
|
||||
);
|
||||
assert_eq!(normalize_default(&enabled.default).as_deref(), Some("1"));
|
||||
|
||||
let installed_at = get_column_info(&conn, "skills", "installed_at");
|
||||
assert_eq!(installed_at.r#type, "INTEGER");
|
||||
|
||||
@@ -307,8 +307,7 @@ fn test_import_prompt_allows_space_in_base64_content() {
|
||||
let db = Arc::new(Database::memory().expect("create memory db"));
|
||||
let state = AppState::new(db.clone());
|
||||
|
||||
let prompt_id =
|
||||
import_prompt_from_deeplink(&state, request.clone()).expect("import prompt");
|
||||
let prompt_id = import_prompt_from_deeplink(&state, request.clone()).expect("import prompt");
|
||||
|
||||
let prompts = state.db.get_prompts("codex").expect("get prompts");
|
||||
let prompt = prompts.get(&prompt_id).expect("prompt saved");
|
||||
|
||||
@@ -223,11 +223,9 @@ fn create_tray_menu(
|
||||
let providers = app_state.db.get_all_providers(app_type_str)?;
|
||||
|
||||
// 使用有效的当前供应商 ID(验证存在性,自动清理失效 ID)
|
||||
let current_id = crate::settings::get_effective_current_provider(
|
||||
&app_state.db,
|
||||
§ion.app_type,
|
||||
)?
|
||||
.unwrap_or_default();
|
||||
let current_id =
|
||||
crate::settings::get_effective_current_provider(&app_state.db, §ion.app_type)?
|
||||
.unwrap_or_default();
|
||||
|
||||
let manager = crate::provider::ProviderManager {
|
||||
providers,
|
||||
@@ -1033,21 +1031,19 @@ fn show_migration_error_dialog(app: &tauri::AppHandle, error: &str) -> bool {
|
||||
|
||||
let message = if is_chinese_locale() {
|
||||
format!(
|
||||
"从旧版本迁移配置时发生错误:\n\n{}\n\n\
|
||||
"从旧版本迁移配置时发生错误:\n\n{error}\n\n\
|
||||
您的数据尚未丢失,旧配置文件仍然保留。\n\
|
||||
建议回退到旧版本 CC Switch 以保护数据。\n\n\
|
||||
点击「重试」重新尝试迁移\n\
|
||||
点击「退出」关闭程序(可回退版本后重新打开)",
|
||||
error
|
||||
点击「退出」关闭程序(可回退版本后重新打开)"
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"An error occurred while migrating configuration:\n\n{}\n\n\
|
||||
"An error occurred while migrating configuration:\n\n{error}\n\n\
|
||||
Your data is NOT lost - the old config file is still preserved.\n\
|
||||
Consider rolling back to an older CC Switch version.\n\n\
|
||||
Click 'Retry' to attempt migration again\n\
|
||||
Click 'Exit' to close the program",
|
||||
error
|
||||
Click 'Exit' to close the program"
|
||||
)
|
||||
};
|
||||
|
||||
|
||||
@@ -100,12 +100,13 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
|
||||
write_json_file(&path, &provider.settings_config)?;
|
||||
}
|
||||
AppType::Codex => {
|
||||
let obj = provider.settings_config.as_object().ok_or_else(|| {
|
||||
AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string())
|
||||
})?;
|
||||
let auth = obj.get("auth").ok_or_else(|| {
|
||||
AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string())
|
||||
})?;
|
||||
let obj = provider
|
||||
.settings_config
|
||||
.as_object()
|
||||
.ok_or_else(|| AppError::Config("Codex 供应商配置必须是 JSON 对象".to_string()))?;
|
||||
let auth = obj
|
||||
.get("auth")
|
||||
.ok_or_else(|| AppError::Config("Codex 供应商配置缺少 'auth' 字段".to_string()))?;
|
||||
let config_str = obj.get("config").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
AppError::Config("Codex 供应商配置缺少 'config' 字段或不是字符串".to_string())
|
||||
})?;
|
||||
@@ -113,8 +114,7 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
|
||||
let auth_path = get_codex_auth_path();
|
||||
write_json_file(&auth_path, auth)?;
|
||||
let config_path = get_codex_config_path();
|
||||
std::fs::write(&config_path, config_str)
|
||||
.map_err(|e| AppError::io(&config_path, e))?;
|
||||
std::fs::write(&config_path, config_str).map_err(|e| AppError::io(&config_path, e))?;
|
||||
}
|
||||
AppType::Gemini => {
|
||||
// Delegate to write_gemini_live which handles env file writing correctly
|
||||
@@ -132,11 +132,11 @@ pub(crate) fn write_live_snapshot(app_type: &AppType, provider: &Provider) -> Re
|
||||
pub fn sync_current_to_live(state: &AppState) -> Result<(), AppError> {
|
||||
for app_type in [AppType::Claude, AppType::Codex, AppType::Gemini] {
|
||||
// Use validated effective current provider
|
||||
let current_id = match crate::settings::get_effective_current_provider(&state.db, &app_type)?
|
||||
{
|
||||
Some(id) => id,
|
||||
None => continue,
|
||||
};
|
||||
let current_id =
|
||||
match crate::settings::get_effective_current_provider(&state.db, &app_type)? {
|
||||
Some(id) => id,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let providers = state.db.get_all_providers(app_type.as_str())?;
|
||||
if let Some(provider) = providers.get(¤t_id) {
|
||||
|
||||
@@ -168,9 +168,7 @@ pub(crate) fn validate_usage_script(script: &UsageScript) -> Result<(), AppError
|
||||
if interval > 1440 {
|
||||
return Err(AppError::localized(
|
||||
"usage_script.interval_too_large",
|
||||
format!(
|
||||
"自动查询间隔不能超过 1440 分钟(24小时),当前值: {interval}"
|
||||
),
|
||||
format!("自动查询间隔不能超过 1440 分钟(24小时),当前值: {interval}"),
|
||||
format!(
|
||||
"Auto query interval cannot exceed 1440 minutes (24 hours), current: {interval}"
|
||||
),
|
||||
|
||||
@@ -214,56 +214,8 @@ impl SkillService {
|
||||
temp_dir.clone()
|
||||
};
|
||||
|
||||
// 遍历目标目录
|
||||
for entry in fs::read_dir(&scan_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let skill_md = path.join("SKILL.md");
|
||||
if !skill_md.exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 解析技能元数据
|
||||
match self.parse_skill_metadata(&skill_md) {
|
||||
Ok(meta) => {
|
||||
// 安全地获取目录名
|
||||
let Some(dir_name) = path.file_name() else {
|
||||
log::warn!("Failed to get directory name from path: {path:?}");
|
||||
continue;
|
||||
};
|
||||
let directory = dir_name.to_string_lossy().to_string();
|
||||
|
||||
// 构建 README URL(考虑 skillsPath)
|
||||
let readme_path = if let Some(ref skills_path) = repo.skills_path {
|
||||
format!("{}/{}", skills_path.trim_matches('/'), directory)
|
||||
} else {
|
||||
directory.clone()
|
||||
};
|
||||
|
||||
skills.push(Skill {
|
||||
key: format!("{}/{}:{}", repo.owner, repo.name, directory),
|
||||
name: meta.name.unwrap_or_else(|| directory.clone()),
|
||||
description: meta.description.unwrap_or_default(),
|
||||
directory,
|
||||
readme_url: Some(format!(
|
||||
"https://github.com/{}/{}/tree/{}/{}",
|
||||
repo.owner, repo.name, repo.branch, readme_path
|
||||
)),
|
||||
installed: false,
|
||||
repo_owner: Some(repo.owner.clone()),
|
||||
repo_name: Some(repo.name.clone()),
|
||||
repo_branch: Some(repo.branch.clone()),
|
||||
skills_path: repo.skills_path.clone(),
|
||||
});
|
||||
}
|
||||
Err(e) => log::warn!("解析 {} 元数据失败: {}", skill_md.display(), e),
|
||||
}
|
||||
}
|
||||
// 递归扫描目录查找所有技能
|
||||
self.scan_dir_recursive(&scan_dir, &scan_dir, repo, &mut skills)?;
|
||||
|
||||
// 清理临时目录
|
||||
let _ = fs::remove_dir_all(&temp_dir);
|
||||
@@ -271,6 +223,90 @@ impl SkillService {
|
||||
Ok(skills)
|
||||
}
|
||||
|
||||
/// 递归扫描目录查找 SKILL.md
|
||||
///
|
||||
/// 规则:
|
||||
/// 1. 如果当前目录存在 SKILL.md,则识别为技能,停止扫描其子目录(子目录视为功能文件夹)
|
||||
/// 2. 如果当前目录不存在 SKILL.md,则递归扫描所有子目录
|
||||
fn scan_dir_recursive(
|
||||
&self,
|
||||
current_dir: &Path,
|
||||
base_dir: &Path,
|
||||
repo: &SkillRepo,
|
||||
skills: &mut Vec<Skill>,
|
||||
) -> Result<()> {
|
||||
// 检查当前目录是否包含 SKILL.md
|
||||
let skill_md = current_dir.join("SKILL.md");
|
||||
|
||||
if skill_md.exists() {
|
||||
// 发现技能!获取相对路径作为目录名
|
||||
let directory = if current_dir == base_dir {
|
||||
// 根目录的 SKILL.md,使用仓库名
|
||||
repo.name.clone()
|
||||
} else {
|
||||
// 子目录的 SKILL.md,使用相对路径
|
||||
current_dir
|
||||
.strip_prefix(base_dir)
|
||||
.unwrap_or(current_dir)
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
};
|
||||
|
||||
if let Ok(skill) = self.build_skill_from_metadata(&skill_md, &directory, repo) {
|
||||
skills.push(skill);
|
||||
}
|
||||
|
||||
// 停止扫描此目录的子目录(同级目录都是功能文件夹)
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 未发现 SKILL.md,继续递归扫描所有子目录
|
||||
for entry in fs::read_dir(current_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
// 只处理目录
|
||||
if path.is_dir() {
|
||||
self.scan_dir_recursive(&path, base_dir, repo, skills)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从 SKILL.md 构建技能对象
|
||||
fn build_skill_from_metadata(
|
||||
&self,
|
||||
skill_md: &Path,
|
||||
directory: &str,
|
||||
repo: &SkillRepo,
|
||||
) -> Result<Skill> {
|
||||
let meta = self.parse_skill_metadata(skill_md)?;
|
||||
|
||||
// 构建 README URL
|
||||
let readme_path = if let Some(ref skills_path) = repo.skills_path {
|
||||
format!("{}/{}", skills_path.trim_matches('/'), directory)
|
||||
} else {
|
||||
directory.to_string()
|
||||
};
|
||||
|
||||
Ok(Skill {
|
||||
key: format!("{}/{}:{}", repo.owner, repo.name, directory),
|
||||
name: meta.name.unwrap_or_else(|| directory.to_string()),
|
||||
description: meta.description.unwrap_or_default(),
|
||||
directory: directory.to_string(),
|
||||
readme_url: Some(format!(
|
||||
"https://github.com/{}/{}/tree/{}/{}",
|
||||
repo.owner, repo.name, repo.branch, readme_path
|
||||
)),
|
||||
installed: false,
|
||||
repo_owner: Some(repo.owner.clone()),
|
||||
repo_name: Some(repo.name.clone()),
|
||||
repo_branch: Some(repo.branch.clone()),
|
||||
skills_path: repo.skills_path.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 解析技能元数据
|
||||
fn parse_skill_metadata(&self, path: &Path) -> Result<SkillMetadata> {
|
||||
let content = fs::read_to_string(path)?;
|
||||
@@ -302,25 +338,18 @@ impl SkillService {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for entry in fs::read_dir(&self.install_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
// 收集所有本地技能
|
||||
let mut local_skills = Vec::new();
|
||||
self.scan_local_dir_recursive(&self.install_dir, &self.install_dir, &mut local_skills)?;
|
||||
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
// 处理找到的本地技能
|
||||
for local_skill in local_skills {
|
||||
let directory = &local_skill.directory;
|
||||
|
||||
// 安全地获取目录名
|
||||
let Some(dir_name) = path.file_name() else {
|
||||
log::warn!("Failed to get directory name from path: {path:?}");
|
||||
continue;
|
||||
};
|
||||
let directory = dir_name.to_string_lossy().to_string();
|
||||
|
||||
// 更新已安装状态
|
||||
// 更新已安装状态(匹配远程技能)
|
||||
let mut found = false;
|
||||
for skill in skills.iter_mut() {
|
||||
if skill.directory.eq_ignore_ascii_case(&directory) {
|
||||
if skill.directory.eq_ignore_ascii_case(directory) {
|
||||
skill.installed = true;
|
||||
found = true;
|
||||
break;
|
||||
@@ -329,23 +358,69 @@ impl SkillService {
|
||||
|
||||
// 添加本地独有的技能(仅当在仓库中未找到时)
|
||||
if !found {
|
||||
let skill_md = path.join("SKILL.md");
|
||||
if skill_md.exists() {
|
||||
if let Ok(meta) = self.parse_skill_metadata(&skill_md) {
|
||||
skills.push(Skill {
|
||||
key: format!("local:{directory}"),
|
||||
name: meta.name.unwrap_or_else(|| directory.clone()),
|
||||
description: meta.description.unwrap_or_default(),
|
||||
directory: directory.clone(),
|
||||
readme_url: None,
|
||||
installed: true,
|
||||
repo_owner: None,
|
||||
repo_name: None,
|
||||
repo_branch: None,
|
||||
skills_path: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
skills.push(local_skill);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 递归扫描本地目录查找 SKILL.md
|
||||
fn scan_local_dir_recursive(
|
||||
&self,
|
||||
current_dir: &Path,
|
||||
base_dir: &Path,
|
||||
skills: &mut Vec<Skill>,
|
||||
) -> Result<()> {
|
||||
// 检查当前目录是否包含 SKILL.md
|
||||
let skill_md = current_dir.join("SKILL.md");
|
||||
|
||||
if skill_md.exists() {
|
||||
// 发现技能!获取相对路径作为目录名
|
||||
let directory = if current_dir == base_dir {
|
||||
// 如果是 install_dir 本身,使用最后一段路径名
|
||||
current_dir
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
} else {
|
||||
// 使用相对于 install_dir 的路径
|
||||
current_dir
|
||||
.strip_prefix(base_dir)
|
||||
.unwrap_or(current_dir)
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
};
|
||||
|
||||
// 解析元数据并创建本地技能对象
|
||||
if let Ok(meta) = self.parse_skill_metadata(&skill_md) {
|
||||
skills.push(Skill {
|
||||
key: format!("local:{directory}"),
|
||||
name: meta.name.unwrap_or_else(|| directory.clone()),
|
||||
description: meta.description.unwrap_or_default(),
|
||||
directory: directory.clone(),
|
||||
readme_url: None,
|
||||
installed: true,
|
||||
repo_owner: None,
|
||||
repo_name: None,
|
||||
repo_branch: None,
|
||||
skills_path: None,
|
||||
});
|
||||
}
|
||||
|
||||
// 停止扫描此目录的子目录(同级目录都是功能文件夹)
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 未发现 SKILL.md,继续递归扫描所有子目录
|
||||
for entry in fs::read_dir(current_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
// 只处理目录
|
||||
if path.is_dir() {
|
||||
self.scan_local_dir_recursive(&path, base_dir, skills)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,13 +3,15 @@ use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use cc_switch_lib::{
|
||||
get_claude_settings_path, read_json_file, AppError, AppType, ConfigService,
|
||||
MultiAppConfig, Provider, ProviderMeta,
|
||||
get_claude_settings_path, read_json_file, AppError, AppType, ConfigService, MultiAppConfig,
|
||||
Provider, ProviderMeta,
|
||||
};
|
||||
|
||||
#[path = "support.rs"]
|
||||
mod support;
|
||||
use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
|
||||
use support::{
|
||||
create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn sync_claude_provider_writes_live_settings() {
|
||||
|
||||
@@ -42,9 +42,13 @@ fn import_default_config_claude_persists_provider() {
|
||||
.expect("import default config succeeds");
|
||||
|
||||
// 验证内存状态
|
||||
let providers = state.db.get_all_providers(AppType::Claude.as_str())
|
||||
let providers = state
|
||||
.db
|
||||
.get_all_providers(AppType::Claude.as_str())
|
||||
.expect("get all providers");
|
||||
let current_id = state.db.get_current_provider(AppType::Claude.as_str())
|
||||
let current_id = state
|
||||
.db
|
||||
.get_current_provider(AppType::Claude.as_str())
|
||||
.expect("get current provider");
|
||||
assert_eq!(current_id.as_deref(), Some("default"));
|
||||
let default_provider = providers.get("default").expect("default provider");
|
||||
@@ -87,7 +91,9 @@ fn import_default_config_without_live_file_returns_error() {
|
||||
|
||||
// 使用数据库架构,不再检查 config.json
|
||||
// 失败的导入不应该向数据库写入任何供应商
|
||||
let providers = state.db.get_all_providers(AppType::Claude.as_str())
|
||||
let providers = state
|
||||
.db
|
||||
.get_all_providers(AppType::Claude.as_str())
|
||||
.expect("get all providers");
|
||||
assert!(
|
||||
providers.is_empty(),
|
||||
@@ -125,8 +131,7 @@ fn import_mcp_from_claude_creates_config_and_enables_servers() {
|
||||
"import should report inserted or normalized entries"
|
||||
);
|
||||
|
||||
let servers = state.db.get_all_mcp_servers()
|
||||
.expect("get all mcp servers");
|
||||
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
|
||||
let entry = servers
|
||||
.get("echo")
|
||||
.expect("server imported into unified structure");
|
||||
@@ -168,8 +173,7 @@ fn import_mcp_from_claude_invalid_json_preserves_state() {
|
||||
}
|
||||
|
||||
// 使用数据库架构,检查 MCP 服务器未被写入
|
||||
let servers = state.db.get_all_mcp_servers()
|
||||
.expect("get all mcp servers");
|
||||
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
|
||||
assert!(
|
||||
servers.is_empty(),
|
||||
"failed import should not persist any MCP servers to database"
|
||||
@@ -224,11 +228,8 @@ fn set_mcp_enabled_for_codex_writes_live_config() {
|
||||
McpService::toggle_app(&state, "codex-server", AppType::Codex, true)
|
||||
.expect("toggle_app should succeed");
|
||||
|
||||
let servers = state.db.get_all_mcp_servers()
|
||||
.expect("get all mcp servers");
|
||||
let entry = servers
|
||||
.get("codex-server")
|
||||
.expect("codex server exists");
|
||||
let servers = state.db.get_all_mcp_servers().expect("get all mcp servers");
|
||||
let entry = servers.get("codex-server").expect("codex server exists");
|
||||
assert!(
|
||||
entry.apps.codex,
|
||||
"server should have Codex app enabled after toggle"
|
||||
|
||||
@@ -7,8 +7,8 @@ use cc_switch_lib::{
|
||||
|
||||
#[path = "support.rs"]
|
||||
mod support;
|
||||
use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
|
||||
use std::collections::HashMap;
|
||||
use support::{create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
|
||||
|
||||
#[test]
|
||||
fn switch_provider_updates_codex_live_and_state() {
|
||||
@@ -104,16 +104,22 @@ command = "say"
|
||||
"config.toml should contain synced MCP servers"
|
||||
);
|
||||
|
||||
let current_id = app_state.db.get_current_provider(AppType::Codex.as_str())
|
||||
let current_id = app_state
|
||||
.db
|
||||
.get_current_provider(AppType::Codex.as_str())
|
||||
.expect("get current provider");
|
||||
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
|
||||
assert_eq!(
|
||||
current_id.as_deref(),
|
||||
Some("new-provider"),
|
||||
"current provider updated"
|
||||
);
|
||||
|
||||
let providers = app_state.db.get_all_providers(AppType::Codex.as_str())
|
||||
let providers = app_state
|
||||
.db
|
||||
.get_all_providers(AppType::Codex.as_str())
|
||||
.expect("get all providers");
|
||||
|
||||
let new_provider = providers
|
||||
.get("new-provider")
|
||||
.expect("new provider exists");
|
||||
let new_provider = providers.get("new-provider").expect("new provider exists");
|
||||
let new_config_text = new_provider
|
||||
.settings_config
|
||||
.get("config")
|
||||
@@ -165,7 +171,9 @@ fn switch_provider_missing_provider_returns_error() {
|
||||
|
||||
let err_str = err.to_string();
|
||||
assert!(
|
||||
err_str.contains("供应商不存在") || err_str.contains("Provider not found") || err_str.contains("missing-provider"),
|
||||
err_str.contains("供应商不存在")
|
||||
|| err_str.contains("Provider not found")
|
||||
|| err_str.contains("missing-provider"),
|
||||
"error message should mention missing provider, got: {err_str}"
|
||||
);
|
||||
}
|
||||
@@ -241,11 +249,19 @@ fn switch_provider_updates_claude_live_and_state() {
|
||||
"live settings.json should reflect new provider auth"
|
||||
);
|
||||
|
||||
let current_id = app_state.db.get_current_provider(AppType::Claude.as_str())
|
||||
let current_id = app_state
|
||||
.db
|
||||
.get_current_provider(AppType::Claude.as_str())
|
||||
.expect("get current provider");
|
||||
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
|
||||
assert_eq!(
|
||||
current_id.as_deref(),
|
||||
Some("new-provider"),
|
||||
"current provider updated"
|
||||
);
|
||||
|
||||
let providers = app_state.db.get_all_providers(AppType::Claude.as_str())
|
||||
let providers = app_state
|
||||
.db
|
||||
.get_all_providers(AppType::Claude.as_str())
|
||||
.expect("get all providers");
|
||||
|
||||
let legacy_provider = providers
|
||||
@@ -258,9 +274,7 @@ fn switch_provider_updates_claude_live_and_state() {
|
||||
"previous provider should be backfilled with live config"
|
||||
);
|
||||
|
||||
let new_provider = providers
|
||||
.get("new-provider")
|
||||
.expect("new provider exists");
|
||||
let new_provider = providers.get("new-provider").expect("new provider exists");
|
||||
assert_eq!(
|
||||
new_provider
|
||||
.settings_config
|
||||
@@ -283,7 +297,9 @@ fn switch_provider_updates_claude_live_and_state() {
|
||||
);
|
||||
|
||||
// 验证当前供应商已更新
|
||||
let current_id = app_state.db.get_current_provider(AppType::Claude.as_str())
|
||||
let current_id = app_state
|
||||
.db
|
||||
.get_current_provider(AppType::Claude.as_str())
|
||||
.expect("get current provider");
|
||||
assert_eq!(
|
||||
current_id.as_deref(),
|
||||
@@ -328,7 +344,9 @@ fn switch_provider_codex_missing_auth_returns_error_and_keeps_state() {
|
||||
other => panic!("expected config error, got {other:?}"),
|
||||
}
|
||||
|
||||
let current_id = app_state.db.get_current_provider(AppType::Codex.as_str())
|
||||
let current_id = app_state
|
||||
.db
|
||||
.get_current_provider(AppType::Codex.as_str())
|
||||
.expect("get current provider");
|
||||
// 切换失败后,由于数据库操作是先设置再验证,current 可能已被设为 "invalid"
|
||||
// 但由于 live 配置写入失败,状态应该回滚
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
use serde_json::json;
|
||||
|
||||
use cc_switch_lib::{
|
||||
get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType,
|
||||
McpApps, McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService,
|
||||
get_claude_settings_path, read_json_file, write_codex_live_atomic, AppError, AppType, McpApps,
|
||||
McpServer, MultiAppConfig, Provider, ProviderMeta, ProviderService,
|
||||
};
|
||||
|
||||
#[path = "support.rs"]
|
||||
mod support;
|
||||
use support::{create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex};
|
||||
use support::{
|
||||
create_test_state, create_test_state_with_config, ensure_test_home, reset_test_fs, test_mutex,
|
||||
};
|
||||
|
||||
fn sanitize_provider_name(name: &str) -> String {
|
||||
name.chars()
|
||||
@@ -69,7 +71,10 @@ command = "say"
|
||||
}
|
||||
|
||||
// 使用新的统一 MCP 结构(v3.7.0+)
|
||||
let servers = initial_config.mcp.servers.get_or_insert_with(Default::default);
|
||||
let servers = initial_config
|
||||
.mcp
|
||||
.servers
|
||||
.get_or_insert_with(Default::default);
|
||||
servers.insert(
|
||||
"echo-server".into(),
|
||||
McpServer {
|
||||
@@ -111,16 +116,22 @@ command = "say"
|
||||
"config.toml should contain synced MCP servers"
|
||||
);
|
||||
|
||||
let current_id = state.db.get_current_provider(AppType::Codex.as_str())
|
||||
let current_id = state
|
||||
.db
|
||||
.get_current_provider(AppType::Codex.as_str())
|
||||
.expect("read current provider after switch");
|
||||
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
|
||||
assert_eq!(
|
||||
current_id.as_deref(),
|
||||
Some("new-provider"),
|
||||
"current provider updated"
|
||||
);
|
||||
|
||||
let providers = state.db.get_all_providers(AppType::Codex.as_str())
|
||||
let providers = state
|
||||
.db
|
||||
.get_all_providers(AppType::Codex.as_str())
|
||||
.expect("read providers after switch");
|
||||
|
||||
let new_provider = providers
|
||||
.get("new-provider")
|
||||
.expect("new provider exists");
|
||||
let new_provider = providers.get("new-provider").expect("new provider exists");
|
||||
let new_config_text = new_provider
|
||||
.settings_config
|
||||
.get("config")
|
||||
@@ -385,11 +396,19 @@ fn provider_service_switch_claude_updates_live_and_state() {
|
||||
"live settings.json should reflect new provider auth"
|
||||
);
|
||||
|
||||
let providers = state.db.get_all_providers(AppType::Claude.as_str())
|
||||
let providers = state
|
||||
.db
|
||||
.get_all_providers(AppType::Claude.as_str())
|
||||
.expect("get all providers");
|
||||
let current_id = state.db.get_current_provider(AppType::Claude.as_str())
|
||||
let current_id = state
|
||||
.db
|
||||
.get_current_provider(AppType::Claude.as_str())
|
||||
.expect("get current provider");
|
||||
assert_eq!(current_id.as_deref(), Some("new-provider"), "current provider updated");
|
||||
assert_eq!(
|
||||
current_id.as_deref(),
|
||||
Some("new-provider"),
|
||||
"current provider updated"
|
||||
);
|
||||
|
||||
let legacy_provider = providers
|
||||
.get("old-provider")
|
||||
@@ -509,7 +528,9 @@ fn provider_service_delete_codex_removes_provider_and_files() {
|
||||
ProviderService::delete(&app_state, AppType::Codex, "to-delete")
|
||||
.expect("delete provider should succeed");
|
||||
|
||||
let providers = app_state.db.get_all_providers(AppType::Codex.as_str())
|
||||
let providers = app_state
|
||||
.db
|
||||
.get_all_providers(AppType::Codex.as_str())
|
||||
.expect("get all providers");
|
||||
assert!(
|
||||
!providers.contains_key("to-delete"),
|
||||
@@ -567,7 +588,9 @@ fn provider_service_delete_claude_removes_provider_files() {
|
||||
|
||||
ProviderService::delete(&app_state, AppType::Claude, "delete").expect("delete claude provider");
|
||||
|
||||
let providers = app_state.db.get_all_providers(AppType::Claude.as_str())
|
||||
let providers = app_state
|
||||
.db
|
||||
.get_all_providers(AppType::Claude.as_str())
|
||||
.expect("get all providers");
|
||||
assert!(
|
||||
!providers.contains_key("delete"),
|
||||
@@ -608,15 +631,18 @@ fn provider_service_delete_current_provider_returns_error() {
|
||||
.expect_err("deleting current provider should fail");
|
||||
match err {
|
||||
AppError::Localized { zh, .. } => assert!(
|
||||
zh.contains("不能删除当前正在使用的供应商") || zh.contains("无法删除当前正在使用的供应商"),
|
||||
zh.contains("不能删除当前正在使用的供应商")
|
||||
|| zh.contains("无法删除当前正在使用的供应商"),
|
||||
"unexpected message: {zh}"
|
||||
),
|
||||
AppError::Config(msg) => assert!(
|
||||
msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"),
|
||||
msg.contains("不能删除当前正在使用的供应商")
|
||||
|| msg.contains("无法删除当前正在使用的供应商"),
|
||||
"unexpected message: {msg}"
|
||||
),
|
||||
AppError::Message(msg) => assert!(
|
||||
msg.contains("不能删除当前正在使用的供应商") || msg.contains("无法删除当前正在使用的供应商"),
|
||||
msg.contains("不能删除当前正在使用的供应商")
|
||||
|| msg.contains("无法删除当前正在使用的供应商"),
|
||||
"unexpected message: {msg}"
|
||||
),
|
||||
other => panic!("expected Config/Message error, got {other:?}"),
|
||||
|
||||
Reference in New Issue
Block a user