update llm
This commit is contained in:
@@ -7,16 +7,17 @@ LLM 调用模块
|
||||
- call_llm_with_template: 使用模板调用(最常用)
|
||||
"""
|
||||
|
||||
from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action
|
||||
from .config import LLMMode
|
||||
from .client import call_llm, call_llm_json, call_llm_with_template, call_llm_with_task_name
|
||||
from .config import LLMMode, get_task_mode
|
||||
from .exceptions import LLMError, ParseError, ConfigError
|
||||
|
||||
__all__ = [
|
||||
"call_llm",
|
||||
"call_llm_json",
|
||||
"call_llm_with_template",
|
||||
"call_ai_action",
|
||||
"call_llm_with_task_name",
|
||||
"LLMMode",
|
||||
"get_task_mode",
|
||||
"LLMError",
|
||||
"ParseError",
|
||||
"ConfigError",
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional
|
||||
|
||||
from src.run.log import log_llm_call
|
||||
from src.utils.config import CONFIG
|
||||
from .config import LLMMode, LLMConfig
|
||||
from .config import LLMMode, LLMConfig, get_task_mode
|
||||
from .parser import parse_json
|
||||
from .prompt import build_prompt, load_template
|
||||
from .exceptions import LLMError, ParseError
|
||||
@@ -139,7 +139,24 @@ async def call_llm_with_template(
|
||||
return await call_llm_json(prompt, mode, max_retries)
|
||||
|
||||
|
||||
async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict:
|
||||
"""AI 行动决策专用接口"""
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
return await call_llm_with_template(template_path, infos, mode)
|
||||
async def call_llm_with_task_name(
|
||||
task_name: str,
|
||||
template_path: Path | str,
|
||||
infos: dict,
|
||||
max_retries: int | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
根据任务名称自动选择 LLM 模式并调用
|
||||
|
||||
Args:
|
||||
task_name: 任务名称,用于在 config.yml 中查找对应的模式
|
||||
template_path: 模板路径
|
||||
infos: 模板参数
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
dict: LLM 返回的 JSON 数据
|
||||
"""
|
||||
mode = get_task_mode(task_name)
|
||||
return await call_llm_with_template(template_path, infos, mode, max_retries)
|
||||
|
||||
|
||||
@@ -46,3 +46,19 @@ class LLMConfig:
|
||||
base_url=CONFIG.llm.base_url
|
||||
)
|
||||
|
||||
|
||||
def get_task_mode(task_name: str) -> LLMMode:
|
||||
"""
|
||||
获取指定任务的 LLM 调用模式
|
||||
|
||||
Args:
|
||||
task_name: 任务名称 (配置在 llm.default_modes 下的 key)
|
||||
|
||||
Returns:
|
||||
LLMMode: 对应的模式,如果未配置则默认返回 NORMAL
|
||||
"""
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
# 获取配置的模式字符串,默认 normal
|
||||
mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal")
|
||||
return LLMMode(mode_str)
|
||||
|
||||
Reference in New Issue
Block a user