refactor llm config
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
"""
|
||||
配置管理模块
|
||||
使用OmegaConf读取config.yml和local_config.yml
|
||||
local_config.yml的优先级更高
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
@@ -34,10 +32,11 @@ def load_config():
|
||||
config = OmegaConf.merge(base_config, local_config)
|
||||
|
||||
# 把paths下的所有值pathlib化
|
||||
for key, value in config.paths.items():
|
||||
config.paths[key] = Path(value)
|
||||
if hasattr(config, "paths"):
|
||||
for key, value in config.paths.items():
|
||||
config.paths[key] = Path(value)
|
||||
|
||||
return config
|
||||
|
||||
# 导出配置对象
|
||||
CONFIG = load_config()
|
||||
CONFIG = load_config()
|
||||
|
||||
@@ -173,7 +173,7 @@ async def call_llm_with_task_name(
|
||||
return await call_llm_with_template(template_path, infos, mode, max_retries)
|
||||
|
||||
|
||||
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> bool:
|
||||
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> tuple[bool, str]:
|
||||
"""
|
||||
测试 LLM 服务连通性 (同步版本)
|
||||
|
||||
@@ -182,7 +182,7 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
|
||||
config: 直接使用该配置进行测试
|
||||
|
||||
Returns:
|
||||
bool: 连接成功返回 True,失败返回 False
|
||||
tuple[bool, str]: (是否成功, 错误信息),成功时错误信息为空字符串
|
||||
"""
|
||||
try:
|
||||
if config is None:
|
||||
@@ -199,7 +199,22 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
|
||||
else:
|
||||
# 直接调用 requests 实现
|
||||
_call_with_requests(config, "test")
|
||||
return True
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
print(f"Connectivity test failed: {e}")
|
||||
return False
|
||||
error_msg = str(e)
|
||||
print(f"Connectivity test failed: {error_msg}")
|
||||
|
||||
# 解析常见错误并提供友好提示
|
||||
if "401" in error_msg or "invalid_api_key" in error_msg or "Incorrect API key" in error_msg:
|
||||
return False, "API Key 无效,请检查您的密钥是否正确"
|
||||
elif "403" in error_msg or "Forbidden" in error_msg:
|
||||
return False, "访问被拒绝,请检查您的权限或配额"
|
||||
elif "404" in error_msg:
|
||||
return False, "服务地址不存在,请检查 Base URL 是否正确"
|
||||
elif "timeout" in error_msg.lower():
|
||||
return False, "连接超时,请检查网络连接或服务地址"
|
||||
elif "Connection" in error_msg or "connect" in error_msg.lower():
|
||||
return False, "无法连接到服务器,请检查 Base URL 和网络"
|
||||
else:
|
||||
# 返回原始错误信息
|
||||
return False, error_msg
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
class LLMMode(str, Enum):
|
||||
"""LLM 调用模式"""
|
||||
NORMAL = "normal"
|
||||
FAST = "fast"
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -21,7 +21,7 @@ class LLMConfig:
|
||||
@classmethod
|
||||
def from_mode(cls, mode: LLMMode) -> 'LLMConfig':
|
||||
"""
|
||||
根据模式创建配置
|
||||
根据模式创建配置,从 CONFIG 读取
|
||||
|
||||
Args:
|
||||
mode: LLM 调用模式
|
||||
@@ -29,36 +29,46 @@ class LLMConfig:
|
||||
Returns:
|
||||
LLMConfig: 配置对象
|
||||
"""
|
||||
from src.utils.config import CONFIG
|
||||
# 从 CONFIG 读取配置
|
||||
api_key = getattr(CONFIG.llm, "key", "")
|
||||
base_url = getattr(CONFIG.llm, "base_url", "")
|
||||
|
||||
# 根据模式选择模型
|
||||
model_name = (
|
||||
CONFIG.llm.model_name if mode == LLMMode.NORMAL
|
||||
else CONFIG.llm.fast_model_name
|
||||
)
|
||||
|
||||
# API Key 优先从环境变量读取
|
||||
api_key = CONFIG.llm.key
|
||||
model_name = ""
|
||||
if mode == LLMMode.FAST:
|
||||
model_name = getattr(CONFIG.llm, "fast_model_name", "")
|
||||
else:
|
||||
# NORMAL or DEFAULT fallback
|
||||
model_name = getattr(CONFIG.llm, "model_name", "")
|
||||
|
||||
return cls(
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=CONFIG.llm.base_url
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
|
||||
def get_task_mode(task_name: str) -> LLMMode:
|
||||
"""
|
||||
获取指定任务的 LLM 调用模式
|
||||
|
||||
Args:
|
||||
task_name: 任务名称 (配置在 llm.default_modes 下的 key)
|
||||
|
||||
Returns:
|
||||
LLMMode: 对应的模式,如果未配置则默认返回 NORMAL
|
||||
根据任务名称获取 LLM 模式
|
||||
"""
|
||||
from src.utils.config import CONFIG
|
||||
# 从 CONFIG 读取全局模式
|
||||
global_mode = getattr(CONFIG.llm, "mode", "default").lower()
|
||||
|
||||
# 获取配置的模式字符串,默认 normal
|
||||
mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal")
|
||||
return LLMMode(mode_str)
|
||||
if global_mode == "normal":
|
||||
return LLMMode.NORMAL
|
||||
elif global_mode == "fast":
|
||||
return LLMMode.FAST
|
||||
|
||||
# Default 模式:根据 task_name 从细粒度配置中获取
|
||||
# 如果配置了 default_modes,则根据任务名称返回对应模式
|
||||
default_modes = getattr(CONFIG.llm, "default_modes", {})
|
||||
if default_modes and task_name in default_modes:
|
||||
task_mode = default_modes[task_name].lower()
|
||||
if task_mode == "fast":
|
||||
return LLMMode.FAST
|
||||
else:
|
||||
return LLMMode.NORMAL
|
||||
|
||||
# 如果没有配置,默认返回 NORMAL
|
||||
return LLMMode.NORMAL
|
||||
|
||||
Reference in New Issue
Block a user