fix action bugs

This commit is contained in:
bridge
2025-10-04 21:27:17 +08:00
parent f109abbb08
commit 9b0be5b4c2
10 changed files with 128 additions and 31 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from src.classes.action import InstantAction
from src.classes.event import Event
from src.classes.battle import decide_battle
from src.classes.story_teller import StoryTeller
class Battle(InstantAction):
@@ -22,7 +23,7 @@ class Battle(InstantAction):
return
winner, loser, damage = decide_battle(self.avatar, target)
loser.hp.reduce(damage)
self._last_result = (winner.name, loser.name)
self._last_result = (winner.name, loser.name, damage)
def can_start(self, avatar_name: str | None = None) -> bool:
if avatar_name is None:
@@ -32,15 +33,31 @@ class Battle(InstantAction):
def start(self, avatar_name: str) -> Event:
target = self._get_target(avatar_name)
target_name = target.name if target is not None else avatar_name
return Event(self.world.month_stamp, f"{self.avatar.name}{target_name} 发起战斗")
event = Event(self.world.month_stamp, f"{self.avatar.name}{target_name} 发起战斗")
# 记录开始事件内容,供故事生成使用
self._start_event_content = event.content
return event
# InstantAction 已实现 step 完成
def finish(self, avatar_name: str) -> list[Event]:
res = self._last_result
if isinstance(res, tuple) and len(res) == 2:
winner, loser = res
return [Event(self.world.month_stamp, f"{winner} 战胜了 {loser}")]
if isinstance(res, tuple) and len(res) in (2, 3):
winner, loser = res[0], res[1]
damage = res[2] if len(res) == 3 else None
if damage is not None:
result_text = f"{winner} 战胜了 {loser},造成{damage}点伤害"
else:
result_text = f"{winner} 战胜了 {loser}"
result_event = Event(self.world.month_stamp, result_text)
# 生成战斗小故事:直接复用已生成的事件文本
target = self._get_target(avatar_name)
avatar_infos = StoryTeller.build_avatar_infos(self.avatar, target)
start_text = getattr(self, "_start_event_content", "") or result_event.content
story = StoryTeller.tell_story(avatar_infos, start_text, result_event.content)
story_event = Event(self.world.month_stamp, story)
return [result_event, story_event]
return []

View File

@@ -38,9 +38,6 @@ class MoveToAvatar(DefineAction, ActualActionMixin):
Move(self.avatar, self.world).execute(delta_x, delta_y)
def can_start(self, avatar_name: str | None = None) -> bool:
target = self._get_target(avatar_name)
if target is None:
return False
return True
def start(self, avatar_name: str) -> Event:

View File

@@ -18,10 +18,8 @@ class ActionRegistry:
def register(cls, action_cls: type, *, actual: bool) -> None:
name = action_cls.__name__
cls._name_to_cls[name] = action_cls
cls._name_to_cls[name.lower()] = action_cls # 大小写别名
if actual:
cls._actual_name_to_cls[name] = action_cls
cls._actual_name_to_cls[name.lower()] = action_cls # 大小写别名
@classmethod
def get(cls, name: str) -> type:
@@ -29,11 +27,25 @@ class ActionRegistry:
@classmethod
def all(cls) -> Iterable[type]:
return cls._name_to_cls.values()
# 去重保持稳定顺序
seen = set()
ordered: list[type] = []
for t in cls._name_to_cls.values():
if t not in seen:
seen.add(t)
ordered.append(t)
return ordered
@classmethod
def all_actual(cls) -> Iterable[type]:
return cls._actual_name_to_cls.values()
# 去重保持稳定顺序
seen = set()
ordered: list[type] = []
for t in cls._actual_name_to_cls.values():
if t not in seen:
seen.add(t)
ordered.append(t)
return ordered
def register_action(*, actual: bool = True) -> Callable[[type], type]:

View File

@@ -59,7 +59,8 @@ class MutualAction(DefineAction, LLMAction, TargetingMixin):
def _call_llm_feedback(self, infos: dict) -> dict:
template_path = self._get_template_path()
res = get_prompt_and_call_llm(template_path, infos)
# mutual用快速llm不需要复杂决策
res = get_prompt_and_call_llm(template_path, infos, mode="fast")
return res
def _set_target_immediate_action(self, target_avatar: "Avatar", action_name: str, action_params: dict) -> None:

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Dict
from src.utils.config import CONFIG
from src.utils.llm import get_prompt_and_call_llm
class StoryTeller:
"""
故事生成器:基于模板与 LLM将给定事件扩展为简短的小故事。
"""
@staticmethod
def build_avatar_infos(*avatars: "Avatar") -> Dict[str, str]:
"""
将若干角色信息组织为 {name: info} 映射,供故事模板使用。
优先使用 `get_prompt_info([])`,失败时退化为 `get_info()`。
"""
infos: Dict[str, str] = {}
for av in avatars:
try:
infos[av.name] = av.get_prompt_info([])
except Exception:
infos[av.name] = getattr(av, "name", "未知角色")
return infos
@staticmethod
def tell_story(avatar_infos: Dict[str, str], event: str, res: str) -> str:
"""
基于 `static/templates/story.txt` 模板生成小故事。
始终使用 fast 模式以提升速度。
失败时返回降级版文案,避免中断流程。
"""
template_path = CONFIG.paths.templates / "story.txt"
infos = {
"avatar_infos": avatar_infos,
"event": event,
"res": res,
}
try:
data = get_prompt_and_call_llm(template_path, infos, mode="fast")
story = str(data.get("story", "")).strip()
if story:
return story
except Exception:
return (res or event or "")
return (res or event or "")
__all__ = ["StoryTeller"]

View File

@@ -17,7 +17,7 @@ def get_prompt(template: str, infos: dict) -> str:
return prompt_template.format(**infos)
def call_llm(prompt: str) -> str:
def call_llm(prompt: str, mode="normal") -> str:
"""
调用LLM
@@ -27,7 +27,12 @@ def call_llm(prompt: str) -> str:
str: LLM返回的结果
"""
# 从配置中获取模型信息
model_name = CONFIG.llm.model_name
if mode == "normal":
model_name = CONFIG.llm.model_name
elif mode == "fast":
model_name = CONFIG.llm.fast_model_name
else:
raise ValueError(f"Invalid mode: {mode}")
api_key = CONFIG.llm.key
base_url = CONFIG.llm.base_url
# 调用litellm的completion函数
@@ -43,7 +48,7 @@ def call_llm(prompt: str) -> str:
log_llm_call(model_name, prompt, result) # 记录日志
return result
async def call_llm_async(prompt: str) -> str:
async def call_llm_async(prompt: str, mode="normal") -> str:
"""
异步调用LLM
@@ -53,7 +58,7 @@ async def call_llm_async(prompt: str) -> str:
str: LLM返回的结果
"""
# 使用asyncio.to_thread包装同步调用
result = await asyncio.to_thread(call_llm, prompt)
result = await asyncio.to_thread(call_llm, prompt, mode)
return result
def parse_llm_response(res: str) -> dict:
@@ -69,38 +74,36 @@ def parse_llm_response(res: str) -> dict:
return json5.loads(res)
def get_prompt_and_call_llm(template_path: Path, infos: dict) -> str:
def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> str:
"""
根据模板获取提示词并调用LLM
"""
template = read_txt(template_path)
prompt = get_prompt(template, infos)
res = call_llm(prompt)
res = call_llm(prompt, mode)
json_res = parse_llm_response(res)
return json_res
async def get_prompt_and_call_llm_async(template_path: Path, infos: dict) -> str:
async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> str:
"""
异步版本根据模板获取提示词并调用LLM
"""
template = read_txt(template_path)
prompt = get_prompt(template, infos)
res = await call_llm_async(prompt)
res = await call_llm_async(prompt, mode)
json_res = parse_llm_response(res)
# print(f"prompt = {prompt}")
# print(f"json_res = {json_res}")
return json_res
def get_ai_prompt_and_call_llm(infos: dict) -> dict:
def get_ai_prompt_and_call_llm(infos: dict, mode="normal") -> dict:
"""
根据模板获取提示词并调用LLM
"""
template_path = CONFIG.paths.templates / "ai.txt"
return get_prompt_and_call_llm(template_path, infos)
return get_prompt_and_call_llm(template_path, infos, mode)
async def get_ai_prompt_and_call_llm_async(infos: dict) -> dict:
async def get_ai_prompt_and_call_llm_async(infos: dict, mode="normal") -> dict:
"""
异步版本根据模板获取提示词并调用LLM
"""
template_path = CONFIG.paths.templates / "ai.txt"
return await get_prompt_and_call_llm_async(template_path, infos)
return await get_prompt_and_call_llm_async(template_path, infos, mode)