refactor llm
This commit is contained in:
@@ -11,7 +11,6 @@ from src.classes.world import World
|
||||
from src.classes.event import Event, NULL_EVENT
|
||||
from src.utils.llm import call_ai_action
|
||||
from src.classes.typings import ACTION_NAME_PARAMS_PAIRS
|
||||
from src.utils.config import CONFIG
|
||||
from src.classes.actions import ACTION_INFOS_STR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -20,8 +19,6 @@ if TYPE_CHECKING:
|
||||
class AI(ABC):
|
||||
"""
|
||||
抽象AI:统一采用批量接口。
|
||||
原先的 GroupAI(多个角色的AI)语义被保留并上移到此基类。
|
||||
子类需实现 _decide(world, avatars) 返回每个 Avatar 的 (action_name, action_params, thinking)。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -31,24 +28,14 @@ class AI(ABC):
|
||||
async def decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str, Event]]:
|
||||
"""
|
||||
决定做什么,同时生成对应的事件。
|
||||
一个 AI 支持批量生成多个 avatar 的动作。
|
||||
这对 LLM AI 节省时间和 token 非常有意义。
|
||||
由于底层 LLM 调用已接入全局任务池,此处直接并发执行所有任务即可。
|
||||
"""
|
||||
results = {}
|
||||
max_decide_num = CONFIG.ai.max_decide_num
|
||||
|
||||
# 使用 asyncio.gather 并行执行多个批次的决策
|
||||
tasks = []
|
||||
for i in range(0, len(avatars_to_decide), max_decide_num):
|
||||
tasks.append(self._decide(world, avatars_to_decide[i:i+max_decide_num]))
|
||||
|
||||
if tasks:
|
||||
batch_results_list = await asyncio.gather(*tasks)
|
||||
for batch_result in batch_results_list:
|
||||
results.update(batch_result)
|
||||
# 调用具体的决策逻辑
|
||||
results = await self._decide(world, avatars_to_decide)
|
||||
|
||||
for avatar, result in list(results.items()):
|
||||
action_name_params_pairs, avatar_thinking, short_term_objective = result # type: ignore
|
||||
# 补全 Event 字段
|
||||
for avatar in list(results.keys()):
|
||||
action_name_params_pairs, avatar_thinking, short_term_objective = results[avatar] # type: ignore
|
||||
# 不在决策阶段生成开始事件,提交阶段统一触发
|
||||
results[avatar] = (action_name_params_pairs, avatar_thinking, short_term_objective, NULL_EVENT)
|
||||
|
||||
@@ -57,18 +44,14 @@ class AI(ABC):
|
||||
class LLMAI(AI):
|
||||
"""
|
||||
LLM AI
|
||||
一些思考:
|
||||
AI动作应该分两类:
|
||||
1. 长期动作,比如要持续很长一段时间的行为
|
||||
2. 突发应对动作,比如突然有人要攻击NPC,这个时候的反应
|
||||
"""
|
||||
|
||||
async def _decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str]]:
|
||||
"""
|
||||
异步决策逻辑:通过LLM决定执行什么动作和参数
|
||||
改动:支持每个角色仅获取其已知区域的世界信息,并发调用 LLM。
|
||||
"""
|
||||
general_action_infos = ACTION_INFOS_STR
|
||||
|
||||
async def decide_one(avatar: Avatar):
|
||||
# 获取基于该角色已知区域的世界信息(包含距离计算)
|
||||
world_info = world.get_info(avatar=avatar, detailed=True)
|
||||
@@ -86,6 +69,7 @@ class LLMAI(AI):
|
||||
res = await call_ai_action(info)
|
||||
return avatar, res
|
||||
|
||||
# 直接并发所有任务
|
||||
tasks = [decide_one(avatar) for avatar in avatars_to_decide]
|
||||
results_list = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -96,20 +80,20 @@ class LLMAI(AI):
|
||||
|
||||
r = res[avatar.name]
|
||||
# 仅接受 action_name_params_pairs,不再支持单个 action_name/action_params
|
||||
raw_pairs = r["action_name_params_pairs"]
|
||||
raw_pairs = r.get("action_name_params_pairs", [])
|
||||
pairs: ACTION_NAME_PARAMS_PAIRS = []
|
||||
|
||||
for p in raw_pairs:
|
||||
if isinstance(p, list) and len(p) == 2:
|
||||
pairs.append((p[0], p[1]))
|
||||
elif isinstance(p, dict) and "action_name" in p and "action_params" in p:
|
||||
pairs.append((p["action_name"], p["action_params"]))
|
||||
else:
|
||||
# 跳过无法解析的项
|
||||
continue
|
||||
|
||||
# 至少有一个
|
||||
if not pairs:
|
||||
raise ValueError(f"LLM未返回有效的action_name_params_pairs: {r}")
|
||||
continue # Skip if no valid actions found
|
||||
|
||||
avatar_thinking = r.get("avatar_thinking", r.get("thinking", ""))
|
||||
short_term_objective = r.get("short_term_objective", "")
|
||||
@@ -117,4 +101,4 @@ class LLMAI(AI):
|
||||
|
||||
return results
|
||||
|
||||
llm_ai = LLMAI()
|
||||
llm_ai = LLMAI()
|
||||
|
||||
@@ -7,7 +7,6 @@ from src.classes.avatar.core import (
|
||||
Avatar,
|
||||
Gender,
|
||||
gender_strs,
|
||||
MAX_HISTORY_EVENTS,
|
||||
)
|
||||
|
||||
from src.classes.avatar.info_presenter import (
|
||||
@@ -23,7 +22,6 @@ __all__ = [
|
||||
"Avatar",
|
||||
"Gender",
|
||||
"gender_strs",
|
||||
"MAX_HISTORY_EVENTS",
|
||||
# 信息展示函数
|
||||
"get_avatar_info",
|
||||
"get_avatar_structured_info",
|
||||
|
||||
@@ -106,11 +106,6 @@ class ActionMixin:
|
||||
return start_event
|
||||
return None
|
||||
|
||||
def peek_next_plan(self: "Avatar") -> Optional[ActionPlan]:
|
||||
if not self.planned_actions:
|
||||
return None
|
||||
return self.planned_actions[0]
|
||||
|
||||
async def tick_action(self: "Avatar") -> List[Event]:
|
||||
"""
|
||||
推进当前动作一步;返回过程中由动作内部产生的事件(通过 add_event 收集)。
|
||||
|
||||
@@ -60,9 +60,6 @@ gender_strs = {
|
||||
Gender.FEMALE: "女",
|
||||
}
|
||||
|
||||
# 历史事件的最大数量
|
||||
MAX_HISTORY_EVENTS = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class Avatar(
|
||||
@@ -90,7 +87,6 @@ class Avatar(
|
||||
root: Root = field(default_factory=lambda: random.choice(list(Root)))
|
||||
personas: List[Persona] = field(default_factory=list)
|
||||
technique: Technique | None = None
|
||||
history_events: List[Event] = field(default_factory=list)
|
||||
_pending_events: List[Event] = field(default_factory=list)
|
||||
current_action: Optional[ActionInstance] = None
|
||||
planned_actions: List[ActionPlan] = field(default_factory=list)
|
||||
@@ -206,25 +202,6 @@ class Avatar(
|
||||
|
||||
# ========== 区域与位置 ==========
|
||||
|
||||
def is_in_region(self, region: Region | None) -> bool:
|
||||
current_region = self.tile.region
|
||||
if current_region is None:
|
||||
tile = self.world.map.get_tile(self.pos_x, self.pos_y)
|
||||
current_region = tile.region
|
||||
return current_region == region
|
||||
|
||||
def get_co_region_avatars(self, avatars: List["Avatar"]) -> List["Avatar"]:
|
||||
"""返回与自己处于同一区域的角色列表(不含自己)。"""
|
||||
if self.tile is None:
|
||||
return []
|
||||
same_region: list[Avatar] = []
|
||||
for other in avatars:
|
||||
if other is self or other.tile is None:
|
||||
continue
|
||||
if other.tile.region == self.tile.region:
|
||||
same_region.append(other)
|
||||
return same_region
|
||||
|
||||
def _init_known_regions(self):
|
||||
"""初始化已知区域:当前位置 + 宗门驻地"""
|
||||
if self.tile and self.tile.region:
|
||||
|
||||
@@ -59,19 +59,6 @@ class InventoryMixin:
|
||||
|
||||
return True
|
||||
|
||||
def has_item(self: "Avatar", item: "Item", quantity: int = 1) -> bool:
|
||||
"""
|
||||
检查是否拥有足够数量的物品
|
||||
|
||||
Args:
|
||||
item: 要检查的物品
|
||||
quantity: 需要的数量,默认为1
|
||||
|
||||
Returns:
|
||||
bool: 是否拥有足够数量的物品
|
||||
"""
|
||||
return item in self.items and self.items[item] >= quantity
|
||||
|
||||
def get_item_quantity(self: "Avatar", item: "Item") -> int:
|
||||
"""
|
||||
获取指定物品的数量
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, List, Tuple, Optional
|
||||
import asyncio
|
||||
|
||||
from src.classes.relation import (
|
||||
Relation,
|
||||
@@ -18,8 +19,6 @@ from src.utils.config import CONFIG
|
||||
if TYPE_CHECKING:
|
||||
from src.classes.avatar import Avatar
|
||||
|
||||
from src.utils.ai_batch import AITaskBatch
|
||||
|
||||
class RelationResolver:
|
||||
TEMPLATE_PATH = CONFIG.paths.templates / "relation_update.txt"
|
||||
|
||||
@@ -137,25 +136,10 @@ class RelationResolver:
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
events = []
|
||||
|
||||
# 使用 asyncio.gather 而不是 AITaskBatch.gather,因为 AITaskBatch 没有 gather 方法
|
||||
import asyncio
|
||||
tasks = []
|
||||
for a, b in pairs:
|
||||
# 创建协程任务但不立即 await
|
||||
tasks.append(RelationResolver.resolve_pair(a, b))
|
||||
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
|
||||
# 并发执行所有任务
|
||||
tasks = [RelationResolver.resolve_pair(a, b) for a, b in pairs]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# 收集结果
|
||||
for res in results:
|
||||
if res and isinstance(res, Event):
|
||||
events.append(res)
|
||||
|
||||
return events
|
||||
# 过滤掉 None 结果 (resolve_pair 失败或无变化时返回 None)
|
||||
return [res for res in results if res]
|
||||
|
||||
Reference in New Issue
Block a user