update llm

This commit is contained in:
bridge
2025-12-29 22:02:44 +08:00
parent 6a4059280a
commit c2cb8098ee
11 changed files with 68 additions and 24 deletions

View File

@@ -7,16 +7,17 @@ LLM 调用模块
- call_llm_with_template: 使用模板调用(最常用)
"""
from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action
from .config import LLMMode
from .client import call_llm, call_llm_json, call_llm_with_template, call_llm_with_task_name
from .config import LLMMode, get_task_mode
from .exceptions import LLMError, ParseError, ConfigError
__all__ = [
"call_llm",
"call_llm_json",
"call_llm_with_template",
"call_ai_action",
"call_llm_with_task_name",
"LLMMode",
"get_task_mode",
"LLMError",
"ParseError",
"ConfigError",

View File

@@ -9,7 +9,7 @@ from typing import Optional
from src.run.log import log_llm_call
from src.utils.config import CONFIG
from .config import LLMMode, LLMConfig
from .config import LLMMode, LLMConfig, get_task_mode
from .parser import parse_json
from .prompt import build_prompt, load_template
from .exceptions import LLMError, ParseError
@@ -139,7 +139,24 @@ async def call_llm_with_template(
return await call_llm_json(prompt, mode, max_retries)
async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict:
"""AI 行动决策专用接口"""
template_path = CONFIG.paths.templates / "ai.txt"
return await call_llm_with_template(template_path, infos, mode)
async def call_llm_with_task_name(
task_name: str,
template_path: Path | str,
infos: dict,
max_retries: int | None = None
) -> dict:
"""
根据任务名称自动选择 LLM 模式并调用
Args:
task_name: 任务名称,用于在 config.yml 中查找对应的模式
template_path: 模板路径
infos: 模板参数
max_retries: 最大重试次数
Returns:
dict: LLM 返回的 JSON 数据
"""
mode = get_task_mode(task_name)
return await call_llm_with_template(template_path, infos, mode, max_retries)

View File

@@ -46,3 +46,19 @@ class LLMConfig:
base_url=CONFIG.llm.base_url
)
def get_task_mode(task_name: str) -> LLMMode:
"""
获取指定任务的 LLM 调用模式
Args:
task_name: 任务名称 (配置在 llm.default_modes 下的 key)
Returns:
LLMMode: 对应的模式,如果未配置则默认返回 NORMAL
"""
from src.utils.config import CONFIG
# 获取配置的模式字符串,默认 normal
mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal")
return LLMMode(mode_str)