refactor llm config

This commit is contained in:
bridge
2025-12-30 22:20:30 +08:00
parent f539b21801
commit d55ada7d66
10 changed files with 235 additions and 69 deletions

View File

@@ -1,9 +1,7 @@
"""
配置管理模块
使用OmegaConf读取config.yml和local_config.yml
local_config.yml的优先级更高
"""
from pathlib import Path
from omegaconf import OmegaConf
@@ -34,10 +32,11 @@ def load_config():
config = OmegaConf.merge(base_config, local_config)
# 把paths下的所有值pathlib化
for key, value in config.paths.items():
config.paths[key] = Path(value)
if hasattr(config, "paths"):
for key, value in config.paths.items():
config.paths[key] = Path(value)
return config
# 导出配置对象
CONFIG = load_config()
CONFIG = load_config()

View File

@@ -173,7 +173,7 @@ async def call_llm_with_task_name(
return await call_llm_with_template(template_path, infos, mode, max_retries)
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> bool:
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> tuple[bool, str]:
"""
测试 LLM 服务连通性 (同步版本)
@@ -182,7 +182,7 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
config: 直接使用该配置进行测试
Returns:
bool: 连接成功返回 True失败返回 False
tuple[bool, str]: (是否成功, 错误信息),成功时错误信息为空字符串
"""
try:
if config is None:
@@ -199,7 +199,22 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
else:
# 直接调用 requests 实现
_call_with_requests(config, "test")
return True
return True, ""
except Exception as e:
print(f"Connectivity test failed: {e}")
return False
error_msg = str(e)
print(f"Connectivity test failed: {error_msg}")
# 解析常见错误并提供友好提示
if "401" in error_msg or "invalid_api_key" in error_msg or "Incorrect API key" in error_msg:
return False, "API Key 无效,请检查您的密钥是否正确"
elif "403" in error_msg or "Forbidden" in error_msg:
return False, "访问被拒绝,请检查您的权限或配额"
elif "404" in error_msg:
return False, "服务地址不存在,请检查 Base URL 是否正确"
elif "timeout" in error_msg.lower():
return False, "连接超时,请检查网络连接或服务地址"
elif "Connection" in error_msg or "connect" in error_msg.lower():
return False, "无法连接到服务器,请检查 Base URL 和网络"
else:
# 返回原始错误信息
return False, error_msg

View File

@@ -2,13 +2,13 @@
from enum import Enum
from dataclasses import dataclass
import os
from src.utils.config import CONFIG
class LLMMode(str, Enum):
    """LLM invocation mode.

    NORMAL selects the full model, FAST selects the lightweight model,
    and DEFAULT defers to the per-task configuration (see
    ``get_task_mode``'s handling of ``llm.default_modes``).
    """
    NORMAL = "normal"
    FAST = "fast"
    DEFAULT = "default"
@dataclass(frozen=True)
@@ -21,7 +21,7 @@ class LLMConfig:
@classmethod
def from_mode(cls, mode: LLMMode) -> 'LLMConfig':
    """Build an LLMConfig from the global CONFIG for the given mode.

    Args:
        mode: LLM invocation mode. FAST selects ``llm.fast_model_name``;
            any other mode (NORMAL / DEFAULT) selects ``llm.model_name``.

    Returns:
        LLMConfig: Configuration populated from ``CONFIG.llm``; missing
        keys fall back to empty strings rather than raising.
    """
    # Read credentials/endpoint from CONFIG, tolerating missing keys.
    api_key = getattr(CONFIG.llm, "key", "")
    base_url = getattr(CONFIG.llm, "base_url", "")
    # Choose the model by mode; no redundant pre-initialization —
    # both branches assign model_name exactly once.
    if mode == LLMMode.FAST:
        model_name = getattr(CONFIG.llm, "fast_model_name", "")
    else:
        # NORMAL and DEFAULT both fall back to the primary model.
        model_name = getattr(CONFIG.llm, "model_name", "")
    return cls(
        model_name=model_name,
        api_key=api_key,
        base_url=base_url,
    )
def get_task_mode(task_name: str) -> LLMMode:
    """Resolve the LLM mode to use for a given task.

    The global ``llm.mode`` setting takes precedence: "normal" or "fast"
    force that mode for every task. Otherwise ("default" or unset), the
    per-task ``llm.default_modes`` mapping is consulted; a task mapped to
    "fast" uses FAST, anything else falls back to NORMAL.

    Args:
        task_name: Task key looked up under ``llm.default_modes``.

    Returns:
        LLMMode: The resolved mode (NORMAL when nothing is configured).
    """
    # Global override read from CONFIG; defaults to "default" when unset.
    global_mode = getattr(CONFIG.llm, "mode", "default").lower()
    if global_mode == "normal":
        return LLMMode.NORMAL
    if global_mode == "fast":
        return LLMMode.FAST
    # "default" mode: consult the fine-grained per-task configuration.
    default_modes = getattr(CONFIG.llm, "default_modes", {})
    if default_modes and task_name in default_modes:
        # Any configured value other than "fast" is treated as NORMAL.
        if default_modes[task_name].lower() == "fast":
            return LLMMode.FAST
        return LLMMode.NORMAL
    # No per-task configuration: default to NORMAL.
    return LLMMode.NORMAL