diff --git a/src/classes/ai.py b/src/classes/ai.py index 8a0e1c3..4fa7e80 100644 --- a/src/classes/ai.py +++ b/src/classes/ai.py @@ -9,9 +9,10 @@ import asyncio from src.classes.world import World from src.classes.event import Event, NULL_EVENT -from src.utils.llm import call_ai_action +from src.utils.llm import call_llm_with_task_name from src.classes.typings import ACTION_NAME_PARAMS_PAIRS from src.classes.actions import ACTION_INFOS_STR +from src.utils.config import CONFIG if TYPE_CHECKING: from src.classes.avatar import Avatar @@ -66,7 +67,8 @@ class LLMAI(AI): "world_info": world_info, "general_action_infos": general_action_infos, } - res = await call_ai_action(info) + template_path = CONFIG.paths.templates / "ai.txt" + res = await call_llm_with_task_name("action_decision", template_path, info) return avatar, res # 直接并发所有任务 diff --git a/src/classes/long_term_objective.py b/src/classes/long_term_objective.py index 2742d12..102d009 100644 --- a/src/classes/long_term_objective.py +++ b/src/classes/long_term_objective.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from src.classes.event import Event from src.utils.config import CONFIG -from src.utils.llm import call_llm_with_template, LLMMode +from src.utils.llm import call_llm_with_task_name from src.run.log import get_logger from src.classes.actions import ACTION_INFOS_STR @@ -92,8 +92,8 @@ async def generate_long_term_objective(avatar: "Avatar") -> Optional[LongTermObj "general_action_infos": ACTION_INFOS_STR, } - # 调用LLM并自动解析JSON(使用fast模型) - response_data = await call_llm_with_template(template_path, infos, LLMMode.NORMAL) + # 调用LLM并自动解析JSON(使用配置的模型模式) + response_data = await call_llm_with_task_name("long_term_objective", template_path, infos) content = response_data.get("long_term_objective", "").strip() diff --git a/src/classes/mutual_action/mutual_action.py b/src/classes/mutual_action/mutual_action.py index ec46f97..2c69492 100644 --- a/src/classes/mutual_action/mutual_action.py +++ b/src/classes/mutual_action/mutual_action.py @@ -5,9 +5,8 @@ from typing import TYPE_CHECKING import asyncio from src.classes.action.action import DefineAction, ActualActionMixin, LLMAction -from src.classes.tile import get_avatar_distance from src.classes.event import Event -from src.utils.llm import call_llm_with_template, LLMMode +from src.utils.llm import call_llm_with_task_name from src.utils.config import CONFIG from src.classes.relation import relation_display_names, Relation from src.classes.relations import get_possible_new_relations @@ -86,7 +85,7 @@ class MutualAction(DefineAction, LLMAction, ActualActionMixin, TargetingMixin): async def _call_llm_feedback(self, infos: dict) -> dict: """异步调用 LLM 获取反馈""" template_path = self._get_template_path() - return await call_llm_with_template(template_path, infos, LLMMode.FAST) + return await call_llm_with_task_name("interaction_feedback", template_path, infos) def _set_target_immediate_action(self, target_avatar: "Avatar", action_name: str, action_params: dict) -> None: """ diff --git a/src/classes/nickname.py b/src/classes/nickname.py index 85b7209..f227127 100644 --- a/src/classes/nickname.py +++ b/src/classes/nickname.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from src.classes.event import Event from src.utils.config import CONFIG -from src.utils.llm import call_llm_with_template, LLMMode +from src.utils.llm import call_llm_with_task_name from src.run.log import get_logger from src.classes.nickname_data import Nickname @@ -75,7 +75,7 @@ async def generate_nickname(avatar: "Avatar") -> Optional[dict]: } # 调用LLM并自动解析JSON - response_data = await call_llm_with_template(template_path, infos, LLMMode.NORMAL) + response_data = await call_llm_with_task_name("nickname", template_path, infos) nickname = response_data.get("nickname", "").strip() thinking = response_data.get("thinking", "") diff --git a/src/classes/relation_resolver.py b/src/classes/relation_resolver.py index 9ceafb4..65a14d8 100644 --- a/src/classes/relation_resolver.py +++ b/src/classes/relation_resolver.py @@ -13,7 +13,7 @@ from src.classes.relations import ( ) from src.classes.calendar import get_date_str from src.classes.event import Event -from src.utils.llm import call_llm_with_template, LLMMode +from src.utils.llm import call_llm_with_task_name from src.utils.config import CONFIG if TYPE_CHECKING: @@ -64,7 +64,7 @@ class RelationResolver: """ infos = RelationResolver._build_prompt_data(avatar_a, avatar_b) - result = await call_llm_with_template(RelationResolver.TEMPLATE_PATH, infos, mode=LLMMode.FAST) + result = await call_llm_with_task_name("relation_resolver", RelationResolver.TEMPLATE_PATH, infos) changed = result.get("changed", False) if not changed: diff --git a/src/classes/single_choice.py b/src/classes/single_choice.py index 6004dca..2390442 100644 --- a/src/classes/single_choice.py +++ b/src/classes/single_choice.py @@ -1,5 +1,5 @@ from typing import Any, Dict, List, TYPE_CHECKING -from src.utils.llm import call_llm_with_template +from src.utils.llm import call_llm_with_task_name from src.utils.config import CONFIG import json @@ -33,7 +33,8 @@ async def make_decision( # 3. 调用 AI template_path = CONFIG.paths.templates / "single_choice.txt" - result = await call_llm_with_template( + result = await call_llm_with_task_name( + "single_choice", template_path, infos={ "avatar_infos": avatar_infos, diff --git a/src/classes/story_teller.py b/src/classes/story_teller.py index d50ebac..1b8f216 100644 --- a/src/classes/story_teller.py +++ b/src/classes/story_teller.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from src.classes.avatar import Avatar from src.utils.config import CONFIG -from src.utils.llm import call_llm_with_template, LLMMode +from src.utils.llm import call_llm_with_task_name from src.classes.relations import ( process_relation_changes, get_relation_change_context @@ -118,7 +118,7 @@ class StoryTeller: infos = StoryTeller._build_template_data(event, res, avatar_infos, prompt, *actors) # 移除了 try-except 块,允许异常向上冒泡,以便 Fail Fast - data = await call_llm_with_template(template_path, infos, LLMMode.FAST) + data = await call_llm_with_task_name("story_teller", template_path, infos) story = data.get("story", "").strip() if story: diff --git a/src/utils/llm/__init__.py b/src/utils/llm/__init__.py index a4087b0..da00015 100644 --- a/src/utils/llm/__init__.py +++ b/src/utils/llm/__init__.py @@ -7,16 +7,17 @@ LLM 调用模块 - call_llm_with_template: 使用模板调用(最常用) """ -from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action -from .config import LLMMode +from .client import call_llm, call_llm_json, call_llm_with_template, call_llm_with_task_name +from .config import LLMMode, get_task_mode from .exceptions import LLMError, ParseError, ConfigError __all__ = [ "call_llm", "call_llm_json", "call_llm_with_template", - "call_ai_action", + "call_llm_with_task_name", "LLMMode", + "get_task_mode", "LLMError", "ParseError", "ConfigError", diff --git a/src/utils/llm/client.py b/src/utils/llm/client.py index 54b4c0f..15fb917 100644 --- a/src/utils/llm/client.py +++ b/src/utils/llm/client.py @@ -9,7 +9,7 @@ from typing import Optional from src.run.log import log_llm_call from src.utils.config import CONFIG -from .config import LLMMode, LLMConfig +from .config import LLMMode, LLMConfig, get_task_mode from .parser import parse_json from .prompt import build_prompt, load_template from .exceptions import LLMError, ParseError @@ -139,7 +139,24 @@ async def call_llm_with_template( return await call_llm_json(prompt, mode, max_retries) -async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict: - """AI 行动决策专用接口""" - template_path = CONFIG.paths.templates / "ai.txt" - return await call_llm_with_template(template_path, infos, mode) +async def call_llm_with_task_name( + task_name: str, + template_path: Path | str, + infos: dict, + max_retries: int | None = None +) -> dict: + """ + 根据任务名称自动选择 LLM 模式并调用 + + Args: + task_name: 任务名称,用于在 config.yml 中查找对应的模式 + template_path: 模板路径 + infos: 模板参数 + max_retries: 最大重试次数 + + Returns: + dict: LLM 返回的 JSON 数据 + """ + mode = get_task_mode(task_name) + return await call_llm_with_template(template_path, infos, mode, max_retries) + diff --git a/src/utils/llm/config.py b/src/utils/llm/config.py index 6b406b1..eff2499 100644 --- a/src/utils/llm/config.py +++ b/src/utils/llm/config.py @@ -46,3 +46,19 @@ class LLMConfig: base_url=CONFIG.llm.base_url ) + +def get_task_mode(task_name: str) -> LLMMode: + """ + 获取指定任务的 LLM 调用模式 + + Args: + task_name: 任务名称 (配置在 llm.default_modes 下的 key) + + Returns: + LLMMode: 对应的模式,如果未配置则默认返回 NORMAL + """ + from src.utils.config import CONFIG + + # 获取配置的模式字符串,默认 normal + mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal") + return LLMMode(mode_str) diff --git a/static/config.yml b/static/config.yml index 7dab267..7a65a59 100644 --- a/static/config.yml +++ b/static/config.yml @@ -7,6 +7,14 @@ llm: model_name: "openai/qwen-plus" fast_model_name: "openai/qwen-flash" base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + default_modes: + action_decision: "normal" + long_term_objective: "normal" + nickname: "normal" + single_choice: "normal" + relation_resolver: "fast" + story_teller: "fast" + interaction_feedback: "fast" paths: templates: static/templates/