refactor llm config

This commit is contained in:
bridge
2025-12-30 22:20:30 +08:00
parent f539b21801
commit d55ada7d66
10 changed files with 235 additions and 69 deletions

View File

@@ -210,8 +210,57 @@ def serialize_phenomenon(phenomenon) -> Optional[dict]:
"effect_desc": effect_desc
}
def check_llm_connectivity() -> tuple[bool, str]:
"""
检查 LLM 连通性
Returns:
(是否成功, 错误信息)
"""
try:
from src.utils.llm.config import LLMMode, LLMConfig
normal_config = LLMConfig.from_mode(LLMMode.NORMAL)
fast_config = LLMConfig.from_mode(LLMMode.FAST)
# 检查配置是否完整
if not normal_config.api_key or not normal_config.base_url:
return False, "LLM 配置不完整:请填写 API Key 和 Base URL"
if not normal_config.model_name:
return False, "LLM 配置不完整:请填写智能模型名称"
# 判断是否需要测试两次
same_model = (normal_config.model_name == fast_config.model_name and
normal_config.base_url == fast_config.base_url and
normal_config.api_key == fast_config.api_key)
if same_model:
# 只测试一次
print(f"检测 LLM 连通性(单模型): {normal_config.model_name}")
success, error = test_connectivity(LLMMode.NORMAL, normal_config)
if not success:
return False, f"连接失败:{error}"
else:
# 测试两次
print(f"检测智能模型连通性: {normal_config.model_name}")
success, error = test_connectivity(LLMMode.NORMAL, normal_config)
if not success:
return False, f"智能模型连接失败:{error}"
print(f"检测快速模型连通性: {fast_config.model_name}")
success, error = test_connectivity(LLMMode.FAST, fast_config)
if not success:
return False, f"快速模型连接失败:{error}"
return True, ""
except Exception as e:
return False, f"连通性检测异常:{str(e)}"
def init_game():
"""初始化游戏世界,逻辑复用自 src/run/run.py"""
print("正在初始化游戏世界...")
game_map = load_cultivation_world_map()
world = World(map=game_map, month_stamp=create_month_stamp(Year(100), Month.JANUARY))
@@ -267,6 +316,24 @@ def init_game():
game_instance["world"] = world
game_instance["sim"] = sim
print("游戏世界初始化完成!")
# ===== LLM 连通性检测(在 simulator 运行前)=====
print("正在检测 LLM 连通性...")
success, error_msg = check_llm_connectivity()
if not success:
print(f"[警告] LLM 连通性检测失败: {error_msg}")
print("[警告] Simulator 已暂停,等待配置 LLM...")
game_instance["llm_check_failed"] = True
game_instance["llm_error_message"] = error_msg
game_instance["is_paused"] = True
print("等待前端连接并配置 LLM...")
else:
print("LLM 连通性检测通过 ✓")
game_instance["llm_check_failed"] = False
game_instance["llm_error_message"] = ""
game_instance["is_paused"] = False
# ===== LLM 检测结束 =====
async def game_loop():
"""后台自动运行游戏循环"""
@@ -455,6 +522,17 @@ print(f"Web dist path: {WEB_DIST_PATH}")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
# ===== 检查 LLM 状态并通知前端 =====
if game_instance.get("llm_check_failed", False):
error_msg = game_instance.get("llm_error_message", "LLM 连接失败")
await websocket.send_json({
"type": "llm_config_required",
"error": error_msg
})
print(f"已向客户端发送 LLM 配置要求: {error_msg}")
# ===== 检测结束 =====
try:
while True:
# 保持连接活跃,接收客户端指令(目前暂不处理复杂指令)
@@ -1024,17 +1102,22 @@ def test_llm_connection(req: TestConnectionRequest):
model_name=req.model_name
)
success = test_connectivity(config=config)
success, error_msg = test_connectivity(config=config)
if success:
return {"status": "ok", "message": "连接成功"}
else:
raise HTTPException(status_code=400, detail="连接失败")
# 返回 400 错误并附带详细的错误信息
raise HTTPException(status_code=400, detail=error_msg)
except HTTPException:
# 重新抛出 HTTPException
raise
except Exception as e:
# 其他未预期的错误
raise HTTPException(status_code=500, detail=f"测试出错: {str(e)}")
@app.post("/api/config/llm/save")
def save_llm_config(req: LLMConfigDTO):
async def save_llm_config(req: LLMConfigDTO):
"""保存 LLM 配置"""
try:
# 1. Update In-Memory Config (Partial update)
@@ -1075,6 +1158,24 @@ def save_llm_config(req: LLMConfigDTO):
OmegaConf.save(conf, local_config_path)
# ===== 如果之前 LLM 连接失败,现在恢复运行 =====
if game_instance.get("llm_check_failed", False):
print("检测到之前 LLM 连接失败,正在恢复 Simulator 运行...")
# 清除失败标志并恢复运行
game_instance["llm_check_failed"] = False
game_instance["llm_error_message"] = ""
game_instance["is_paused"] = False
print("Simulator 已恢复运行 ✓")
# 通知所有客户端刷新
await manager.broadcast({
"type": "game_reinitialized",
"message": "LLM 配置成功,游戏已恢复运行"
})
# ===== 恢复运行结束 =====
return {"status": "ok", "message": "配置已保存"}
except Exception as e:
import traceback
@@ -1113,22 +1214,6 @@ def api_save_game(req: SaveGameRequest):
if not world or not sim:
raise HTTPException(status_code=503, detail="Game not initialized")
# 这里的 existed_sects 需要从 world 或者 sim 中获取,目前简单起见,
# 我们可以遍历地图上的宗门总部,或者如果全局有保存最好。
# 由于 init_game 只有一次,我们需要从 world 中反推 active sects
# 但 save_game 签名里的 existed_sects 主要是为了记录 id。
# 实际上 world.map.regions 中包含了宗门总部信息。
# 或者更简单的:直接从 sects_by_id 取所有? 不太对。
# 让我们看看 save_game 实现:它主要是存 id。
# 我们可以传入空列表,如果在 load 时能容忍的话。
# 实际上 load_game 里existed_sects = [sects_by_id[sid] for sid in existed_sect_ids]
# 所以 save 时如果不传load 时就拿不到。
# 临时方案:遍历所有宗门,如果它有领地或者有人,就算存在。
# 或者更粗暴CONFIG.game.sect_num 如果没变,可以不管。
# 最好是 world 对象上能挂载 existed_sects。
# 暂时方案:传入所有宗门作为 existed_sects (全集),虽然有点浪费,但不丢数据。
# 更好的方案:修改 init_game把 existed_sects 挂载到 world 上。
# 尝试从 world 属性获取(如果以后添加了)
existed_sects = getattr(world, "existed_sects", [])
if not existed_sects:

View File

@@ -1,9 +1,7 @@
"""
配置管理模块
使用OmegaConf读取config.yml和local_config.yml
local_config.yml的优先级更高
"""
from pathlib import Path
from omegaconf import OmegaConf
@@ -34,10 +32,11 @@ def load_config():
config = OmegaConf.merge(base_config, local_config)
# 把paths下的所有值pathlib化
for key, value in config.paths.items():
config.paths[key] = Path(value)
if hasattr(config, "paths"):
for key, value in config.paths.items():
config.paths[key] = Path(value)
return config
# 导出配置对象
CONFIG = load_config()
CONFIG = load_config()

View File

@@ -173,7 +173,7 @@ async def call_llm_with_task_name(
return await call_llm_with_template(template_path, infos, mode, max_retries)
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> bool:
def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig] = None) -> tuple[bool, str]:
"""
测试 LLM 服务连通性 (同步版本)
@@ -182,7 +182,7 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
config: 直接使用该配置进行测试
Returns:
bool: 连接成功返回 True失败返回 False
tuple[bool, str]: (是否成功, 错误信息),成功时错误信息为空字符串
"""
try:
if config is None:
@@ -199,7 +199,22 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
else:
# 直接调用 requests 实现
_call_with_requests(config, "test")
return True
return True, ""
except Exception as e:
print(f"Connectivity test failed: {e}")
return False
error_msg = str(e)
print(f"Connectivity test failed: {error_msg}")
# 解析常见错误并提供友好提示
if "401" in error_msg or "invalid_api_key" in error_msg or "Incorrect API key" in error_msg:
return False, "API Key 无效,请检查您的密钥是否正确"
elif "403" in error_msg or "Forbidden" in error_msg:
return False, "访问被拒绝,请检查您的权限或配额"
elif "404" in error_msg:
return False, "服务地址不存在,请检查 Base URL 是否正确"
elif "timeout" in error_msg.lower():
return False, "连接超时,请检查网络连接或服务地址"
elif "Connection" in error_msg or "connect" in error_msg.lower():
return False, "无法连接到服务器,请检查 Base URL 和网络"
else:
# 返回原始错误信息
return False, error_msg

View File

@@ -2,13 +2,13 @@
from enum import Enum
from dataclasses import dataclass
import os
from src.utils.config import CONFIG
class LLMMode(str, Enum):
"""LLM 调用模式"""
NORMAL = "normal"
FAST = "fast"
DEFAULT = "default"
@dataclass(frozen=True)
@@ -21,7 +21,7 @@ class LLMConfig:
@classmethod
def from_mode(cls, mode: LLMMode) -> 'LLMConfig':
"""
根据模式创建配置
根据模式创建配置,从 CONFIG 读取
Args:
mode: LLM 调用模式
@@ -29,36 +29,46 @@ class LLMConfig:
Returns:
LLMConfig: 配置对象
"""
from src.utils.config import CONFIG
# 从 CONFIG 读取配置
api_key = getattr(CONFIG.llm, "key", "")
base_url = getattr(CONFIG.llm, "base_url", "")
# 根据模式选择模型
model_name = (
CONFIG.llm.model_name if mode == LLMMode.NORMAL
else CONFIG.llm.fast_model_name
)
# API Key 优先从环境变量读取
api_key = CONFIG.llm.key
model_name = ""
if mode == LLMMode.FAST:
model_name = getattr(CONFIG.llm, "fast_model_name", "")
else:
# NORMAL or DEFAULT fallback
model_name = getattr(CONFIG.llm, "model_name", "")
return cls(
model_name=model_name,
api_key=api_key,
base_url=CONFIG.llm.base_url
base_url=base_url
)
def get_task_mode(task_name: str) -> LLMMode:
"""
获取指定任务的 LLM 调用模式
Args:
task_name: 任务名称 (配置在 llm.default_modes 下的 key)
Returns:
LLMMode: 对应的模式,如果未配置则默认返回 NORMAL
根据任务名称获取 LLM 模式
"""
from src.utils.config import CONFIG
# 从 CONFIG 读取全局模式
global_mode = getattr(CONFIG.llm, "mode", "default").lower()
# 获取配置的模式字符串,默认 normal
mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal")
return LLMMode(mode_str)
if global_mode == "normal":
return LLMMode.NORMAL
elif global_mode == "fast":
return LLMMode.FAST
# Default 模式:根据 task_name 从细粒度配置中获取
# 如果配置了 default_modes则根据任务名称返回对应模式
default_modes = getattr(CONFIG.llm, "default_modes", {})
if default_modes and task_name in default_modes:
task_mode = default_modes[task_name].lower()
if task_mode == "fast":
return LLMMode.FAST
else:
return LLMMode.NORMAL
# 如果没有配置,默认返回 NORMAL
return LLMMode.NORMAL