diff --git a/src/classes/action/action.py b/src/classes/action/action.py index 90a462b..3016ef1 100644 --- a/src/classes/action/action.py +++ b/src/classes/action/action.py @@ -168,7 +168,7 @@ class ActualActionMixin(): ... @abstractmethod - def finish(self, **params) -> list[Event]: + async def finish(self, **params) -> list[Event]: return [] diff --git a/src/classes/action/battle.py b/src/classes/action/battle.py index e7415e4..2799ae7 100644 --- a/src/classes/action/battle.py +++ b/src/classes/action/battle.py @@ -71,7 +71,7 @@ class Battle(InstantAction): # InstantAction 已实现 step 完成 - def finish(self, avatar_name: str) -> list[Event]: + async def finish(self, avatar_name: str) -> list[Event]: res = self._last_result if not (isinstance(res, tuple) and len(res) == 4): return [] @@ -87,10 +87,10 @@ class Battle(InstantAction): pass result_event = Event(self.world.month_stamp, result_text, related_avatars=rel_ids, is_major=True) - # 生成战斗小故事(同步调用,与其他动作保持一致) + # 生成战斗小故事 target = self._get_target(avatar_name) start_text = self._start_event_content if hasattr(self, '_start_event_content') else result_event.content - story = StoryTeller.tell_story(start_text, result_event.content, self.avatar, target, prompt=self.STORY_PROMPT) + story = await StoryTeller.tell_story(start_text, result_event.content, self.avatar, target, prompt=self.STORY_PROMPT) story_event = Event(self.world.month_stamp, story, related_avatars=rel_ids, is_story=True) return [result_event, story_event] diff --git a/src/classes/action/breakthrough.py b/src/classes/action/breakthrough.py index 5830a59..3235697 100644 --- a/src/classes/action/breakthrough.py +++ b/src/classes/action/breakthrough.py @@ -118,7 +118,7 @@ class Breakthrough(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: if not self._last_result: return [] result_ok = self._last_result[0] == "success" @@ -139,7 +139,7 @@ class Breakthrough(TimedAction): # 故事参与者:本体 +(可选)相关角色 prompt = TribulationSelector.get_story_prompt(str(calamity)) - story = StoryTeller.tell_story(core_text, ("突破成功" if result_ok else "突破失败"), self.avatar, self._calamity_other, prompt=prompt) + story = await StoryTeller.tell_story(core_text, ("突破成功" if result_ok else "突破失败"), self.avatar, self._calamity_other, prompt=prompt) events.append(Event(self.world.month_stamp, story, related_avatars=rel_ids, is_story=True)) return events diff --git a/src/classes/action/catch.py b/src/classes/action/catch.py index adc9784..2da2340 100644 --- a/src/classes/action/catch.py +++ b/src/classes/action/catch.py @@ -85,7 +85,7 @@ class Catch(TimedAction): region = self.avatar.tile.region return Event(self.world.month_stamp, f"{self.avatar.name} 在 {region.name} 尝试御兽", related_avatars=[self.avatar.id]) - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: res = self._caught_result if not (isinstance(res, tuple) and len(res) == 3): return [] diff --git a/src/classes/action/cultivate.py b/src/classes/action/cultivate.py index 0679b21..45dd931 100644 --- a/src/classes/action/cultivate.py +++ b/src/classes/action/cultivate.py @@ -71,7 +71,7 @@ class Cultivate(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/devour_mortals.py b/src/classes/action/devour_mortals.py index eb26dc4..a87f7d1 100644 --- a/src/classes/action/devour_mortals.py +++ b/src/classes/action/devour_mortals.py @@ -32,7 +32,7 @@ class DevourMortals(TimedAction): def start(self) -> Event: return Event(self.world.month_stamp, f"{self.avatar.name} 在城镇开始吞噬凡人", related_avatars=[self.avatar.id]) - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/escape.py b/src/classes/action/escape.py index 9582124..4ff11e8 100644 --- a/src/classes/action/escape.py +++ b/src/classes/action/escape.py @@ -75,7 +75,7 @@ class Escape(InstantAction): # InstantAction 已实现 step 完成 - def finish(self, avatar_name: str) -> list[Event]: + async def finish(self, avatar_name: str) -> list[Event]: return [] diff --git a/src/classes/action/harvest.py b/src/classes/action/harvest.py index a062d61..50762be 100644 --- a/src/classes/action/harvest.py +++ b/src/classes/action/harvest.py @@ -65,7 +65,7 @@ class Harvest(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/help_mortals.py b/src/classes/action/help_mortals.py index a4e4e1d..f10dcf6 100644 --- a/src/classes/action/help_mortals.py +++ b/src/classes/action/help_mortals.py @@ -43,7 +43,7 @@ class HelpMortals(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/hunt.py b/src/classes/action/hunt.py index ae01780..3adac99 100644 --- a/src/classes/action/hunt.py +++ b/src/classes/action/hunt.py @@ -65,7 +65,7 @@ class Hunt(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/move_away_from_avatar.py b/src/classes/action/move_away_from_avatar.py index adab989..55fab74 100644 --- a/src/classes/action/move_away_from_avatar.py +++ b/src/classes/action/move_away_from_avatar.py @@ -67,7 +67,7 @@ class MoveAwayFromAvatar(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self, avatar_name: str) -> list[Event]: + async def finish(self, avatar_name: str) -> list[Event]: return [] diff --git a/src/classes/action/move_away_from_region.py b/src/classes/action/move_away_from_region.py index a093e3f..b13b1f8 100644 --- a/src/classes/action/move_away_from_region.py +++ b/src/classes/action/move_away_from_region.py @@ -47,7 +47,7 @@ class MoveAwayFromRegion(InstantAction): # InstantAction 已实现 step 完成 - def finish(self, region: str) -> list[Event]: + async def finish(self, region: str) -> list[Event]: return [] diff --git a/src/classes/action/move_to_avatar.py b/src/classes/action/move_to_avatar.py index a04119b..056f4a6 100644 --- a/src/classes/action/move_to_avatar.py +++ b/src/classes/action/move_to_avatar.py @@ -59,7 +59,7 @@ class MoveToAvatar(DefineAction, ActualActionMixin): done = self.avatar.tile == target.tile return ActionResult(status=(ActionStatus.COMPLETED if done else ActionStatus.RUNNING), events=[]) - def finish(self, avatar_name: str) -> list[Event]: + async def finish(self, avatar_name: str) -> list[Event]: return [] diff --git a/src/classes/action/move_to_region.py b/src/classes/action/move_to_region.py index 7871617..dd8db89 100644 --- a/src/classes/action/move_to_region.py +++ b/src/classes/action/move_to_region.py @@ -52,7 +52,7 @@ class MoveToRegion(DefineAction, ActualActionMixin): done = self.avatar.is_in_region(r) or ((self.avatar.pos_x, self.avatar.pos_y) in getattr(r, "cors", ())) return ActionResult(status=(ActionStatus.COMPLETED if done else ActionStatus.RUNNING), events=[]) - def finish(self, region: Region | str) -> list[Event]: + async def finish(self, region: Region | str) -> list[Event]: return [] diff --git a/src/classes/action/nurture_weapon.py b/src/classes/action/nurture_weapon.py index 315b802..33cb01c 100644 --- a/src/classes/action/nurture_weapon.py +++ b/src/classes/action/nurture_weapon.py @@ -63,7 +63,7 @@ class NurtureWeapon(TimedAction): related_avatars=[self.avatar.id] ) - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: weapon_name = self.avatar.weapon.name if self.avatar.weapon else "兵器" proficiency = self.avatar.weapon_proficiency # 注意:升华事件已经在_execute中添加,这里只添加完成事件 diff --git a/src/classes/action/play.py b/src/classes/action/play.py index 8d94966..f057715 100644 --- a/src/classes/action/play.py +++ b/src/classes/action/play.py @@ -31,7 +31,7 @@ class Play(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/plunder_mortals.py b/src/classes/action/plunder_mortals.py index 38c9a01..191a519 100644 --- a/src/classes/action/plunder_mortals.py +++ b/src/classes/action/plunder_mortals.py @@ -48,7 +48,7 @@ class PlunderMortals(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: return [] diff --git a/src/classes/action/self_heal.py b/src/classes/action/self_heal.py index 0668d86..ddabece 100644 --- a/src/classes/action/self_heal.py +++ b/src/classes/action/self_heal.py @@ -59,7 +59,7 @@ class SelfHeal(TimedAction): # TimedAction 已统一 step 逻辑 - def finish(self) -> list[Event]: + async def finish(self) -> list[Event]: healed_total = int(getattr(self, "_healed_total", 0)) # 统一用一次事件简要反馈 return [Event(self.world.month_stamp, f"{self.avatar.name} 疗伤完成,HP已回满(本次恢复{healed_total}点,当前HP {self.avatar.hp})", related_avatars=[self.avatar.id])] diff --git a/src/classes/action/sold.py b/src/classes/action/sold.py index 36506b8..f1f9e85 100644 --- a/src/classes/action/sold.py +++ b/src/classes/action/sold.py @@ -79,7 +79,7 @@ class SellItems(InstantAction): # InstantAction 已实现 step 完成 - def finish(self, item_name: str) -> list[Event]: + async def finish(self, item_name: str) -> list[Event]: return [] diff --git a/src/classes/action/switch_weapon.py b/src/classes/action/switch_weapon.py index 6472a8e..cf2dc50 100644 --- a/src/classes/action/switch_weapon.py +++ b/src/classes/action/switch_weapon.py @@ -75,6 +75,6 @@ class SwitchWeapon(InstantAction): related_avatars=[self.avatar.id] ) - def finish(self, weapon_type_name: str) -> list[Event]: + async def finish(self, weapon_type_name: str) -> list[Event]: return [] diff --git a/src/classes/ai.py b/src/classes/ai.py index c9c8ccd..4a55041 100644 --- a/src/classes/ai.py +++ b/src/classes/ai.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from src.classes.world import World from src.classes.event import Event, NULL_EVENT -from src.utils.llm import get_ai_prompt_and_call_llm_async +from src.utils.llm import call_ai_action from src.classes.typings import ACTION_NAME_PARAMS_PAIRS from src.utils.config import CONFIG from src.classes.actions import ACTION_INFOS_STR @@ -70,7 +70,7 @@ class LLMAI(AI): "global_info": global_info, "general_action_infos": general_action_infos, } - res = await get_ai_prompt_and_call_llm_async(info) + res = await call_ai_action(info) results: dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str]] = {} for avatar in avatars_to_decide: r = res[avatar.name] diff --git a/src/classes/avatar.py b/src/classes/avatar.py index 31fa72c..3bfc23e 100644 --- a/src/classes/avatar.py +++ b/src/classes/avatar.py @@ -365,7 +365,7 @@ class Avatar(AvatarSaveMixin, AvatarLoadMixin): result: ActionResult = action.step(**params_for_step) if result.status == ActionStatus.COMPLETED: params_for_finish = filter_kwargs_for_callable(action.finish, params) - finish_events = action.finish(**params_for_finish) + finish_events = await action.finish(**params_for_finish) # 仅当当前动作仍然是刚才执行的那个实例时才清空 # 若在 step() 内部通过"抢占"机制切换了动作(如 Escape 失败立即切到 Battle),不要清空新动作 if self.current_action is action_instance_before: diff --git a/src/classes/fortune.py b/src/classes/fortune.py index e88a4ef..0b40ff8 100644 --- a/src/classes/fortune.py +++ b/src/classes/fortune.py @@ -478,7 +478,7 @@ async def try_trigger_fortune(avatar: Avatar) -> list[Event]: base_event = Event(month_at_finish, event_text, related_avatars=related_avatars, is_major=True) # 生成故事事件 - story = await StoryTeller.tell_story_async(event_text, res_text, *actors_for_story, prompt=story_prompt) + story = await StoryTeller.tell_story(event_text, res_text, *actors_for_story, prompt=story_prompt) story_event = Event(month_at_finish, story, related_avatars=related_avatars, is_story=True) # 返回基础事件和故事事件 diff --git a/src/classes/long_term_objective.py b/src/classes/long_term_objective.py index ca07b2a..f5781a3 100644 --- a/src/classes/long_term_objective.py +++ b/src/classes/long_term_objective.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from src.classes.event import Event from src.utils.config import CONFIG -from src.utils.llm import get_prompt_and_call_llm_async +from src.utils.llm import call_llm_with_template, LLMMode from src.run.log import get_logger logger = get_logger().logger @@ -91,7 +91,7 @@ async def generate_long_term_objective(avatar: "Avatar") -> Optional[LongTermObj } # 调用LLM并自动解析JSON(使用fast模型) - response_data = await get_prompt_and_call_llm_async(template_path, infos, mode="fast") + response_data = await call_llm_with_template(template_path, infos, LLMMode.FAST) content = response_data.get("long_term_objective", "").strip() diff --git a/src/classes/mutual_action/dual_cultivation.py b/src/classes/mutual_action/dual_cultivation.py index 2d50876..142d91b 100644 --- a/src/classes/mutual_action/dual_cultivation.py +++ b/src/classes/mutual_action/dual_cultivation.py @@ -96,7 +96,7 @@ class DualCultivation(MutualAction): initiator.cultivation_progress.add_exp(exp_gain) self._dual_exp_gain = exp_gain - def finish(self, target_avatar: "Avatar|str") -> list[Event]: + async def finish(self, target_avatar: "Avatar|str") -> list[Event]: target = self._get_target_avatar(target_avatar) events: list[Event] = [] success = self._dual_cultivation_success @@ -111,7 +111,7 @@ class DualCultivation(MutualAction): # 生成恋爱/双修小故事 start_text = self._start_event_content or result_event.content - story = StoryTeller.tell_story(start_text, result_event.content, self.avatar, target, prompt=self.STORY_PROMPT) + story = await StoryTeller.tell_story(start_text, result_event.content, self.avatar, target, prompt=self.STORY_PROMPT) story_event = Event(self.world.month_stamp, story, related_avatars=[self.avatar.id, target.id], is_story=True) events.append(story_event) else: diff --git a/src/classes/mutual_action/gift_spirit_stone.py b/src/classes/mutual_action/gift_spirit_stone.py index 2080ad4..a011ef1 100644 --- a/src/classes/mutual_action/gift_spirit_stone.py +++ b/src/classes/mutual_action/gift_spirit_stone.py @@ -79,7 +79,7 @@ class GiftSpiritStone(MutualAction): # 目标获得灵石 target.magic_stone += self.GIFT_AMOUNT - def finish(self, target_avatar: "Avatar|str") -> list[Event]: + async def finish(self, target_avatar: "Avatar|str") -> list[Event]: target = self._get_target_avatar(target_avatar) events: list[Event] = [] success = self._gift_success @@ -98,7 +98,7 @@ class GiftSpiritStone(MutualAction): # 生成赠送小故事 from src.classes.story_teller import StoryTeller start_text = self._start_event_content or result_event.content - story = StoryTeller.tell_story( + story = await StoryTeller.tell_story( start_text, result_text, self.avatar, diff --git a/src/classes/mutual_action/impart.py b/src/classes/mutual_action/impart.py index 439eaf3..8a6f330 100644 --- a/src/classes/mutual_action/impart.py +++ b/src/classes/mutual_action/impart.py @@ -90,7 +90,7 @@ class Impart(MutualAction): target.cultivation_progress.add_exp(exp_gain) self._impart_exp_gain = exp_gain - def finish(self, target_avatar: "Avatar|str") -> list[Event]: + async def finish(self, target_avatar: "Avatar|str") -> list[Event]: target = self._get_target_avatar(target_avatar) events: list[Event] = [] success = self._impart_success @@ -110,7 +110,7 @@ class Impart(MutualAction): # 生成师徒传道小故事 from src.classes.story_teller import StoryTeller start_text = self._start_event_content or result_event.content - story = StoryTeller.tell_story( + story = await StoryTeller.tell_story( start_text, result_text, self.avatar, diff --git a/src/classes/mutual_action/mutual_action.py b/src/classes/mutual_action/mutual_action.py index ada3a4f..3030c9b 100644 --- a/src/classes/mutual_action/mutual_action.py +++ b/src/classes/mutual_action/mutual_action.py @@ -7,7 +7,7 @@ import asyncio from src.classes.action.action import DefineAction, ActualActionMixin, LLMAction from src.classes.tile import get_avatar_distance from src.classes.event import Event -from src.utils.llm import get_prompt_and_call_llm, get_prompt_and_call_llm_async +from src.utils.llm import call_llm_with_template, LLMMode from src.utils.config import CONFIG from src.classes.relation import relation_display_names, Relation from src.classes.relations import get_possible_new_relations @@ -81,17 +81,10 @@ class MutualAction(DefineAction, LLMAction, TargetingMixin): "feedback_actions": feedback_actions, } - def _call_llm_feedback(self, infos: dict) -> dict: - """ - 兼容保留:同步调用(不在事件循环内使用)。 - """ + async def _call_llm_feedback(self, infos: dict) -> dict: + """异步调用 LLM 获取反馈""" template_path = self._get_template_path() - res = get_prompt_and_call_llm(template_path, infos, mode="fast") - return res - - async def _call_llm_feedback_async(self, infos: dict) -> dict: - template_path = self._get_template_path() - return await get_prompt_and_call_llm_async(template_path, infos, mode="fast") + return await call_llm_with_template(template_path, infos, LLMMode.FAST) def _set_target_immediate_action(self, target_avatar: "Avatar", action_name: str, action_params: dict) -> None: """ @@ -132,16 +125,14 @@ class MutualAction(DefineAction, LLMAction, TargetingMixin): return self.find_avatar_by_name(target_avatar) return target_avatar - def _execute(self, target_avatar: "Avatar|str") -> None: - """ - 保留同步实现(不在事件循环内使用)。 - """ + async def _execute(self, target_avatar: "Avatar|str") -> None: + """异步执行互动动作""" target_avatar = self._get_target_avatar(target_avatar) if target_avatar is None: return infos = self._build_prompt_infos(target_avatar) - res = self._call_llm_feedback(infos) + res = await self._call_llm_feedback(infos) r = res.get(infos["avatar_name_2"], {}) thinking = r.get("thinking", "") feedback = r.get("feedback", "") @@ -209,12 +200,8 @@ class MutualAction(DefineAction, LLMAction, TargetingMixin): # 若无任务,创建异步任务 if self._feedback_task is None and self._feedback_cached is None: infos = self._build_prompt_infos(target) - try: - loop = asyncio.get_running_loop() - self._feedback_task = loop.create_task(self._call_llm_feedback_async(infos)) - except RuntimeError: - # 无运行中的事件循环时,退化为同步调用(如离线批处理) - self._feedback_cached = self._call_llm_feedback(infos) + loop = asyncio.get_running_loop() + self._feedback_task = loop.create_task(self._call_llm_feedback(infos)) # 若任务已完成,消费结果 if self._feedback_task is not None and self._feedback_task.done(): @@ -238,7 +225,7 @@ class MutualAction(DefineAction, LLMAction, TargetingMixin): return ActionResult(status=ActionStatus.RUNNING, events=[]) - def finish(self, target_avatar: "Avatar|str") -> list[Event]: + async def finish(self, target_avatar: "Avatar|str") -> list[Event]: """ 完成互动动作,事件已在 step 中处理,无需额外事件 """ diff --git a/src/classes/nickname.py b/src/classes/nickname.py index df072d5..21ebb3b 100644 --- a/src/classes/nickname.py +++ b/src/classes/nickname.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from src.classes.event import Event from src.utils.config import CONFIG -from src.utils.llm import get_prompt_and_call_llm_async +from src.utils.llm import call_llm_with_template, LLMMode from src.run.log import get_logger logger = get_logger().logger @@ -73,7 +73,7 @@ async def generate_nickname(avatar: "Avatar") -> Optional[str]: } # 调用LLM并自动解析JSON - response_data = await get_prompt_and_call_llm_async(template_path, infos, mode="fast") + response_data = await call_llm_with_template(template_path, infos, LLMMode.FAST) nickname = response_data.get("nickname", "").strip() thinking = response_data.get("thinking", "") diff --git a/src/classes/story_teller.py b/src/classes/story_teller.py index 06b9f64..8dffc13 100644 --- a/src/classes/story_teller.py +++ b/src/classes/story_teller.py @@ -3,8 +3,11 @@ from __future__ import annotations from typing import Dict, TYPE_CHECKING import random +if TYPE_CHECKING: + from src.classes.avatar import Avatar + from src.utils.config import CONFIG -from src.utils.llm import get_prompt_and_call_llm, get_prompt_and_call_llm_async +from src.utils.llm import call_llm_with_template, LLMMode story_styles = [ "平淡叙述:语句克制、少修饰、像旁观者记录。", @@ -67,32 +70,7 @@ class StoryTeller: return f"{event}。{res}。{style}" @staticmethod - def tell_story(event: str, res: str, *actors: "Avatar", prompt: str = "") -> str: - """ - 生成小故事(同步版本)。 - 基于 `static/templates/story.txt` 模板,失败时返回降级文案。 - - Args: - event: 事件描述 - res: 结果描述 - *actors: 参与的角色(1-2个) - prompt: 可选的故事提示词 - """ - avatar_infos = StoryTeller._build_avatar_infos(*actors) - infos = StoryTeller._build_template_data(event, res, avatar_infos, prompt) - - try: - data = get_prompt_and_call_llm(StoryTeller.TEMPLATE_PATH, infos, mode="fast") - story = data.get("story", "").strip() - if story: - return story - except Exception: - pass - - return StoryTeller._make_fallback_story(event, res, infos["style"]) - - @staticmethod - async def tell_story_async(event: str, res: str, *actors: "Avatar", prompt: str = "") -> str: + async def tell_story(event: str, res: str, *actors: "Avatar", prompt: str = "") -> str: """ 生成小故事(异步版本)。 基于 `static/templates/story.txt` 模板,失败时返回降级文案。 @@ -107,8 +85,8 @@ class StoryTeller: infos = StoryTeller._build_template_data(event, res, avatar_infos, prompt) try: - data = await get_prompt_and_call_llm_async(StoryTeller.TEMPLATE_PATH, infos, mode="fast") - story = str(data.get("story", "")).strip() + data = await call_llm_with_template(StoryTeller.TEMPLATE_PATH, infos, LLMMode.FAST) + story = data.get("story", "").strip() if story: return story except Exception: @@ -116,4 +94,5 @@ class StoryTeller: return StoryTeller._make_fallback_story(event, res, infos["style"]) + __all__ = ["StoryTeller"] \ No newline at end of file diff --git a/src/utils/llm.py b/src/utils/llm.py deleted file mode 100644 index f1e9fef..0000000 --- a/src/utils/llm.py +++ /dev/null @@ -1,226 +0,0 @@ -from litellm import completion -from pathlib import Path -import asyncio -import re -import json5 -import os - -from src.utils.config import CONFIG -from src.utils.io import read_txt -from src.run.log import log_llm_call -from src.utils.strings import intentify_prompt_infos - -def get_prompt(template: str, infos: dict) -> str: - """ - 根据模板,获取提示词 - """ - # 将 dict/list 等结构化对象转为 JSON 字符串 - # 策略: - # - avatar_infos: 不包装 intent(模板里已经说明是 dict[Name, info]) - # - general_action_infos: 强制包装 intent 以凸显语义 - # - 其他容器类型:默认包装 intent - processed_infos = intentify_prompt_infos(infos) - return template.format(**processed_infos) - - -def call_llm(prompt: str, mode="normal") -> str: - """ - 调用LLM - - Args: - prompt: 输入的提示词 - Returns: - str: LLM返回的结果 - """ - # 从配置中获取模型信息 - if mode == "normal": - model_name = CONFIG.llm.model_name - elif mode == "fast": - model_name = CONFIG.llm.fast_model_name - else: - raise ValueError(f"Invalid mode: {mode}") - # API Key 优先从环境变量读取,其次 fallback 到配置文件 - api_key = os.getenv("QWEN_API_KEY") or CONFIG.llm.key - base_url = CONFIG.llm.base_url - # 调用litellm的completion函数 - response = completion( - model=model_name, - messages=[{"role": "user", "content": prompt}], - api_key=api_key, - base_url=base_url, - ) - - # 返回生成的内容 - result = response.choices[0].message.content - log_llm_call(model_name, prompt, result) # 记录日志 - return result - -async def call_llm_async(prompt: str, mode="normal") -> str: - """ - 异步调用LLM - - Args: - prompt: 输入的提示词 - Returns: - str: LLM返回的结果 - """ - # 使用asyncio.to_thread包装同步调用 - result = await asyncio.to_thread(call_llm, prompt, mode) - return result - -def _extract_code_blocks(text: str): - """ - 提取所有markdown代码块,返回 (lang, content) 列表。 - """ - pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL) - blocks = [] - for lang, content in pattern.findall(text): - blocks.append((lang.strip().lower(), content.strip())) - return blocks - - -def _find_first_balanced_json_object(text: str): - """ - 在整段文本中扫描并返回首个平衡的花括号 {...} 片段(忽略字符串中的括号)。 - 找到则返回子串,否则返回None。 - """ - depth = 0 - start_index = None - in_string = False - string_char = '' - escape = False - for idx, ch in enumerate(text): - if in_string: - if escape: - escape = False - continue - if ch == '\\': - escape = True - continue - if ch == string_char: - in_string = False - continue - if ch in ('"', "'"): - in_string = True - string_char = ch - continue - if ch == '{': - if depth == 0: - start_index = idx - depth += 1 - continue - if ch == '}': - if depth > 0: - depth -= 1 - if depth == 0 and start_index is not None: - return text[start_index:idx + 1] - return None - - -def parse_llm_response(res: str) -> dict: - """ - 仅针对 JSON 的稳健解析: - 1) 优先解析 ```json/json5``` 或未标注语言的代码块 - 2) 自由文本中定位首个平衡的 {...} - 3) 整体 json5 兜底 - 最终返回字典;否则抛错。 - """ - res = (res or '').strip() - if not res: - return {} - - # 1) 优先解析代码块(仅 json/json5/未标注语言) - for lang, block in _extract_code_blocks(res): - if lang and lang not in ("json", "json5"): - continue - # 先在块内找平衡对象 - span = _find_first_balanced_json_object(block) - candidates = [span] if span else [block] - for cand in candidates: - if not cand: - continue - try: - obj = json5.loads(cand) - if isinstance(obj, dict): - return obj - except Exception: - continue - - # 2) 扫描全文首个平衡的JSON对象 - json_span = _find_first_balanced_json_object(res) - if json_span: - try: - obj = json5.loads(json_span) - if isinstance(obj, dict): - return obj - except Exception: - pass - - # 3) 整体 json5 兜底 - obj = json5.loads(res) - return obj - - -def call_and_parse_llm(prompt: str, mode: str = "normal") -> dict: - """ - 将 LLM 调用与解析合并,并在解析失败时按配置重试。 - 成功返回 dict,超过重试次数仍失败则抛错。 - """ - max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0)) - last_err: Exception | None = None - for _ in range(1 + max_retries): - res = call_llm(prompt, mode) - try: - return parse_llm_response(res) - except Exception as e: - last_err = e - continue - raise ValueError(f"LLM响应解析失败,已重试 {max_retries} 次") from last_err - - -async def call_and_parse_llm_async(prompt: str, mode: str = "normal") -> dict: - """ - 异步版本:将 LLM 调用与解析合并,并在解析失败时按配置重试。 - 成功返回 dict,超过重试次数仍失败则抛错。 - """ - max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0)) - last_err: Exception | None = None - for _ in range(1 + max_retries): - res = await call_llm_async(prompt, mode) - try: - return parse_llm_response(res) - except Exception as e: - last_err = e - continue - raise ValueError(f"LLM响应解析失败,已重试 {max_retries} 次") from last_err - - -def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> dict: - """ - 根据模板,获取提示词,并调用LLM - """ - template = read_txt(template_path) - prompt = get_prompt(template, infos) - return call_and_parse_llm(prompt, mode) - -async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> dict: - """ - 异步版本:根据模板,获取提示词,并调用LLM - """ - template = read_txt(template_path) - prompt = get_prompt(template, infos) - return await call_and_parse_llm_async(prompt, mode) - -def get_ai_prompt_and_call_llm(infos: dict, mode="normal") -> dict: - """ - 根据模板,获取提示词,并调用LLM - """ - template_path = CONFIG.paths.templates / "ai.txt" - return get_prompt_and_call_llm(template_path, infos, mode) - -async def get_ai_prompt_and_call_llm_async(infos: dict, mode="normal") -> dict: - """ - 异步版本:根据模板,获取提示词,并调用LLM - """ - template_path = CONFIG.paths.templates / "ai.txt" - return await get_prompt_and_call_llm_async(template_path, infos, mode) \ No newline at end of file diff --git a/src/utils/llm/__init__.py b/src/utils/llm/__init__.py new file mode 100644 index 0000000..a4087b0 --- /dev/null +++ b/src/utils/llm/__init__.py @@ -0,0 +1,24 @@ +""" +LLM 调用模块 + +提供三个核心 API: +- call_llm: 基础调用,返回原始文本 +- call_llm_json: 调用并解析为 JSON +- call_llm_with_template: 使用模板调用(最常用) +""" + +from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action +from .config import LLMMode +from .exceptions import LLMError, ParseError, ConfigError + +__all__ = [ + "call_llm", + "call_llm_json", + "call_llm_with_template", + "call_ai_action", + "LLMMode", + "LLMError", + "ParseError", + "ConfigError", +] + diff --git a/src/utils/llm/client.py b/src/utils/llm/client.py new file mode 100644 index 0000000..6e5b350 --- /dev/null +++ b/src/utils/llm/client.py @@ -0,0 +1,129 @@ +"""LLM 客户端核心调用逻辑""" + +from pathlib import Path +from litellm import completion + +from .config import LLMMode, LLMConfig +from .parser import parse_json +from .prompt import build_prompt, load_template +from .exceptions import LLMError, ParseError +from src.run.log import log_llm_call + + +async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str: + """ + 最基础的 LLM 调用,返回原始文本 + + Args: + prompt: 输入提示词 + mode: 调用模式 + + Returns: + str: LLM 返回的原始文本 + """ + import asyncio + + # 获取配置 + config = LLMConfig.from_mode(mode) + + # 调用 litellm(包装为异步) + def _call(): + response = completion( + model=config.model_name, + messages=[{"role": "user", "content": prompt}], + api_key=config.api_key, + base_url=config.base_url, + ) + return response.choices[0].message.content + + result = await asyncio.to_thread(_call) + + # 记录日志 + log_llm_call(config.model_name, prompt, result) + + return result + + +async def call_llm_json( + prompt: str, + mode: LLMMode = LLMMode.NORMAL, + max_retries: int | None = None +) -> dict: + """ + 调用 LLM 并解析为 JSON,内置重试机制 + + Args: + prompt: 输入提示词 + mode: 调用模式 + max_retries: 最大重试次数,None 则从配置读取 + + Returns: + dict: 解析后的 JSON 对象 + + Raises: + LLMError: 解析失败且重试次数用尽时抛出 + """ + if max_retries is None: + from src.utils.config import CONFIG + max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0)) + + last_error = None + for attempt in range(max_retries + 1): + response = await call_llm(prompt, mode) + try: + return parse_json(response) + except ParseError as e: + last_error = e + if attempt < max_retries: + continue # 继续重试 + # 最后一次失败,抛出详细错误 + raise LLMError( + f"解析失败(重试 {max_retries} 次后)", + cause=last_error + ) from last_error + + # 不应该到这里,但为了类型检查 + raise LLMError("未知错误") + + +async def call_llm_with_template( + template_path: Path | str, + infos: dict, + mode: LLMMode = LLMMode.NORMAL, + max_retries: int | None = None +) -> dict: + """ + 使用模板调用 LLM(最常用的高级接口) + + Args: + template_path: 模板文件路径 + infos: 要填充的信息字典 + mode: 调用模式 + max_retries: 最大重试次数,None 则从配置读取 + + Returns: + dict: 解析后的 JSON 对象 + """ + template = load_template(template_path) + prompt = build_prompt(template, infos) + return await call_llm_json(prompt, mode, max_retries) + + +async def call_ai_action( + infos: dict, + mode: LLMMode = LLMMode.NORMAL +) -> dict: + """ + AI 行动决策专用接口 + + Args: + infos: 行动决策所需信息 + mode: 调用模式 + + Returns: + dict: AI 行动决策结果 + """ + from src.utils.config import CONFIG + template_path = CONFIG.paths.templates / "ai.txt" + return await call_llm_with_template(template_path, infos, mode) + diff --git a/src/utils/llm/config.py b/src/utils/llm/config.py new file mode 100644 index 0000000..6b406b1 --- /dev/null +++ b/src/utils/llm/config.py @@ -0,0 +1,48 @@ +"""LLM 配置管理""" + +from enum import Enum +from dataclasses import dataclass +import os + + +class LLMMode(str, Enum): + """LLM 调用模式""" + NORMAL = "normal" + FAST = "fast" + + +@dataclass(frozen=True) +class LLMConfig: + """LLM 配置数据类""" + model_name: str + api_key: str + base_url: str + + @classmethod + def from_mode(cls, mode: LLMMode) -> 'LLMConfig': + """ + 根据模式创建配置 + + Args: + mode: LLM 调用模式 + + Returns: + LLMConfig: 配置对象 + """ + from src.utils.config import CONFIG + + # 根据模式选择模型 + model_name = ( + CONFIG.llm.model_name if mode == LLMMode.NORMAL + else CONFIG.llm.fast_model_name + ) + + # API Key 优先从环境变量读取 + api_key = os.getenv("QWEN_API_KEY") or CONFIG.llm.key + + return cls( + model_name=model_name, + api_key=api_key, + base_url=CONFIG.llm.base_url + ) + diff --git a/src/utils/llm/exceptions.py b/src/utils/llm/exceptions.py new file mode 100644 index 0000000..3d4f468 --- /dev/null +++ b/src/utils/llm/exceptions.py @@ -0,0 +1,24 @@ +"""LLM 相关异常定义""" + + +class LLMError(Exception): + """LLM 相关错误的基类""" + + def __init__(self, message: str, *, cause: Exception | None = None, **context): + super().__init__(message) + self.cause = cause + self.context = context + + +class ParseError(LLMError): + """JSON 解析失败""" + + def __init__(self, message: str, *, raw_text: str = ""): + super().__init__(message, raw_text=raw_text) + self.raw_text = raw_text + + +class ConfigError(LLMError): + """配置错误""" + pass + diff --git a/src/utils/llm/parser.py b/src/utils/llm/parser.py new file mode 100644 index 0000000..56542bb --- /dev/null +++ b/src/utils/llm/parser.py @@ -0,0 +1,185 @@ +"""JSON 解析逻辑""" + +import re +import json5 +from .exceptions import ParseError + + +def parse_json(text: str) -> dict: + """ + 主解析入口,依次尝试多种策略 + + Args: + text: 待解析的文本 + + Returns: + dict: 解析结果 + + Raises: + ParseError: 所有策略均失败时抛出 + """ + text = (text or '').strip() + if not text: + return {} + + strategies = [ + try_parse_code_blocks, + try_parse_balanced_json, + try_parse_whole_text, + ] + + errors = [] + for strategy in strategies: + result = strategy(text) + if result is not None: + return result + errors.append(f"{strategy.__name__}") + + # 抛出详细错误 + raise ParseError( + f"所有解析策略均失败: {', '.join(errors)}", + raw_text=text[:500] # 只保留前 500 字符避免日志过长 + ) + + +def try_parse_code_blocks(text: str) -> dict | None: + """ + 尝试从代码块解析 JSON + + Args: + text: 待解析的文本 + + Returns: + dict | None: 解析成功返回字典,失败返回 None + """ + blocks = _extract_code_blocks(text) + + # 只处理 json/json5 或未标注语言的代码块 + for lang, block in blocks: + if lang and lang not in ("json", "json5"): + continue + + # 先在块内找平衡对象 + span = _find_balanced_json_object(block) + candidates = [span] if span else [block] + + for cand in candidates: + if not cand: + continue + try: + obj = json5.loads(cand) + if isinstance(obj, dict): + return obj + except Exception: + continue + + return None + + +def try_parse_balanced_json(text: str) -> dict | None: + """ + 尝试提取并解析平衡的 JSON 对象 + + Args: + text: 待解析的文本 + + Returns: + dict | None: 解析成功返回字典,失败返回 None + """ + json_span = _find_balanced_json_object(text) + if json_span: + try: + obj = json5.loads(json_span) + if isinstance(obj, dict): + return obj + except Exception: + pass + + return None + + +def try_parse_whole_text(text: str) -> dict | None: + """ + 尝试整体解析为 JSON + + Args: + text: 待解析的文本 + + Returns: + dict | None: 解析成功返回字典,失败返回 None + """ + try: + obj = json5.loads(text) + if isinstance(obj, dict): + return obj + except Exception: + pass + + return None + + +def _extract_code_blocks(text: str) -> list[tuple[str, str]]: + """ + 提取所有 markdown 代码块 + + Args: + text: 待提取的文本 + + Returns: + list[tuple[str, str]]: (语言, 内容) 元组列表 + """ + pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL) + blocks = [] + for lang, content in pattern.findall(text): + blocks.append((lang.strip().lower(), content.strip())) + return blocks + + +def _find_balanced_json_object(text: str) -> str | None: + """ + 在文本中扫描并返回首个平衡的花括号 {...} 片段 + 忽略字符串中的括号 + + Args: + text: 待扫描的文本 + + Returns: + str | None: 找到则返回子串,否则返回 None + """ + depth = 0 + start_index = None + in_string = False + string_char = '' + escape = False + + for idx, ch in enumerate(text): + if in_string: + if escape: + escape = False + continue + if ch == '\\': + escape = True + continue + if ch == string_char: + in_string = False + continue + + if ch in ('"', "'"): + in_string = True + string_char = ch + continue + + if ch == '{': + if depth == 0: + start_index = idx + depth += 1 + continue + + if ch == '}': + if depth > 0: + depth -= 1 + if depth == 0 and start_index is not None: + return text[start_index:idx + 1] + + return None + diff --git a/src/utils/llm/prompt.py b/src/utils/llm/prompt.py new file mode 100644 index 0000000..c50450e --- /dev/null +++ b/src/utils/llm/prompt.py @@ -0,0 +1,34 @@ +"""提示词处理""" + +from pathlib import Path +from src.utils.strings import intentify_prompt_infos + + +def build_prompt(template: str, infos: dict) -> str: + """ + 根据模板构建提示词 + + Args: + template: 提示词模板 + infos: 要填充的信息字典 + + Returns: + str: 构建好的提示词 + """ + processed = intentify_prompt_infos(infos) + return template.format(**processed) + + +def load_template(path: Path | str) -> str: + """ + 加载模板文件 + + Args: + path: 模板文件路径 + + Returns: + str: 模板内容 + """ + path = Path(path) + return path.read_text(encoding="utf-8") +