refactor llm

This commit is contained in:
bridge
2025-12-20 22:13:26 +08:00
parent e8489fcc25
commit 162ea8efe2
12 changed files with 122 additions and 422 deletions

View File

@@ -11,7 +11,6 @@ from src.classes.world import World
from src.classes.event import Event, NULL_EVENT
from src.utils.llm import call_ai_action
from src.classes.typings import ACTION_NAME_PARAMS_PAIRS
from src.utils.config import CONFIG
from src.classes.actions import ACTION_INFOS_STR
if TYPE_CHECKING:
@@ -20,8 +19,6 @@ if TYPE_CHECKING:
class AI(ABC):
"""
抽象AI统一采用批量接口。
原先的 GroupAI多个角色的AI语义被保留并上移到此基类。
子类需实现 _decide(world, avatars) 返回每个 Avatar 的 (action_name, action_params, thinking)。
"""
@abstractmethod
@@ -31,24 +28,14 @@ class AI(ABC):
async def decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str, Event]]:
"""
决定做什么,同时生成对应的事件。
一个 AI 支持批量生成多个 avatar 的动作
这对 LLM AI 节省时间和 token 非常有意义。
由于底层 LLM 调用已接入全局任务池,此处直接并发执行所有任务即可
"""
results = {}
max_decide_num = CONFIG.ai.max_decide_num
# 使用 asyncio.gather 并行执行多个批次的决策
tasks = []
for i in range(0, len(avatars_to_decide), max_decide_num):
tasks.append(self._decide(world, avatars_to_decide[i:i+max_decide_num]))
if tasks:
batch_results_list = await asyncio.gather(*tasks)
for batch_result in batch_results_list:
results.update(batch_result)
# 调用具体的决策逻辑
results = await self._decide(world, avatars_to_decide)
for avatar, result in list(results.items()):
action_name_params_pairs, avatar_thinking, short_term_objective = result # type: ignore
# 补全 Event 字段
for avatar in list(results.keys()):
action_name_params_pairs, avatar_thinking, short_term_objective = results[avatar] # type: ignore
# 不在决策阶段生成开始事件,提交阶段统一触发
results[avatar] = (action_name_params_pairs, avatar_thinking, short_term_objective, NULL_EVENT)
@@ -57,18 +44,14 @@ class AI(ABC):
class LLMAI(AI):
"""
LLM AI
一些思考:
AI动作应该分两类
1. 长期动作,比如要持续很长一段时间的行为
2. 突发应对动作比如突然有人要攻击NPC这个时候的反应
"""
async def _decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str]]:
"""
异步决策逻辑通过LLM决定执行什么动作和参数
改动:支持每个角色仅获取其已知区域的世界信息,并发调用 LLM。
"""
general_action_infos = ACTION_INFOS_STR
async def decide_one(avatar: Avatar):
# 获取基于该角色已知区域的世界信息(包含距离计算)
world_info = world.get_info(avatar=avatar, detailed=True)
@@ -86,6 +69,7 @@ class LLMAI(AI):
res = await call_ai_action(info)
return avatar, res
# 直接并发所有任务
tasks = [decide_one(avatar) for avatar in avatars_to_decide]
results_list = await asyncio.gather(*tasks)
@@ -96,20 +80,20 @@ class LLMAI(AI):
r = res[avatar.name]
# 仅接受 action_name_params_pairs不再支持单个 action_name/action_params
raw_pairs = r["action_name_params_pairs"]
raw_pairs = r.get("action_name_params_pairs", [])
pairs: ACTION_NAME_PARAMS_PAIRS = []
for p in raw_pairs:
if isinstance(p, list) and len(p) == 2:
pairs.append((p[0], p[1]))
elif isinstance(p, dict) and "action_name" in p and "action_params" in p:
pairs.append((p["action_name"], p["action_params"]))
else:
# 跳过无法解析的项
continue
# 至少有一个
if not pairs:
raise ValueError(f"LLM未返回有效的action_name_params_pairs: {r}")
continue # Skip if no valid actions found
avatar_thinking = r.get("avatar_thinking", r.get("thinking", ""))
short_term_objective = r.get("short_term_objective", "")
@@ -117,4 +101,4 @@ class LLMAI(AI):
return results
llm_ai = LLMAI()
llm_ai = LLMAI()

View File

@@ -7,7 +7,6 @@ from src.classes.avatar.core import (
Avatar,
Gender,
gender_strs,
MAX_HISTORY_EVENTS,
)
from src.classes.avatar.info_presenter import (
@@ -23,7 +22,6 @@ __all__ = [
"Avatar",
"Gender",
"gender_strs",
"MAX_HISTORY_EVENTS",
# 信息展示函数
"get_avatar_info",
"get_avatar_structured_info",

View File

@@ -106,11 +106,6 @@ class ActionMixin:
return start_event
return None
def peek_next_plan(self: "Avatar") -> Optional[ActionPlan]:
    """Return the next queued ActionPlan without consuming it, or None when empty."""
    return self.planned_actions[0] if self.planned_actions else None
async def tick_action(self: "Avatar") -> List[Event]:
"""
推进当前动作一步;返回过程中由动作内部产生的事件(通过 add_event 收集)。

View File

@@ -60,9 +60,6 @@ gender_strs = {
Gender.FEMALE: "",
}
# 历史事件的最大数量
MAX_HISTORY_EVENTS = 10
@dataclass
class Avatar(
@@ -90,7 +87,6 @@ class Avatar(
root: Root = field(default_factory=lambda: random.choice(list(Root)))
personas: List[Persona] = field(default_factory=list)
technique: Technique | None = None
history_events: List[Event] = field(default_factory=list)
_pending_events: List[Event] = field(default_factory=list)
current_action: Optional[ActionInstance] = None
planned_actions: List[ActionPlan] = field(default_factory=list)
@@ -206,25 +202,6 @@ class Avatar(
# ========== 区域与位置 ==========
def is_in_region(self, region: Region | None) -> bool:
    """Check whether this avatar currently stands in *region*."""
    found = self.tile.region
    if found is None:
        # Cached tile has no region; re-resolve the tile from the world map.
        found = self.world.map.get_tile(self.pos_x, self.pos_y).region
    return found == region
def get_co_region_avatars(self: "Avatar", avatars: List["Avatar"]) -> List["Avatar"]:
    """Return every other avatar that shares this avatar's region (self excluded)."""
    if self.tile is None:
        return []
    my_region = self.tile.region
    return [
        other
        for other in avatars
        if other is not self
        and other.tile is not None
        and other.tile.region == my_region
    ]
def _init_known_regions(self):
"""初始化已知区域:当前位置 + 宗门驻地"""
if self.tile and self.tile.region:

View File

@@ -59,19 +59,6 @@ class InventoryMixin:
return True
def has_item(self: "Avatar", item: "Item", quantity: int = 1) -> bool:
    """Return True when the inventory holds at least *quantity* of *item*.

    Args:
        item: The item to look up.
        quantity: Minimum count required (defaults to 1).

    Returns:
        bool: Whether the required amount is available.
    """
    # EAFP: a missing item simply means "not enough".
    try:
        return self.items[item] >= quantity
    except KeyError:
        return False
def get_item_quantity(self: "Avatar", item: "Item") -> int:
"""
获取指定物品的数量

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Tuple, Optional
import asyncio
from src.classes.relation import (
Relation,
@@ -18,8 +19,6 @@ from src.utils.config import CONFIG
if TYPE_CHECKING:
from src.classes.avatar import Avatar
from src.utils.ai_batch import AITaskBatch
class RelationResolver:
TEMPLATE_PATH = CONFIG.paths.templates / "relation_update.txt"
@@ -137,25 +136,10 @@ class RelationResolver:
"""
if not pairs:
return []
events = []
# 使用 asyncio.gather 而不是 AITaskBatch.gather因为 AITaskBatch 没有 gather 方法
import asyncio
tasks = []
for a, b in pairs:
# 创建协程任务但不立即 await
tasks.append(RelationResolver.resolve_pair(a, b))
if not tasks:
return []
# 并发执行所有任务
tasks = [RelationResolver.resolve_pair(a, b) for a, b in pairs]
results = await asyncio.gather(*tasks)
# 收集结果
for res in results:
if res and isinstance(res, Event):
events.append(res)
return events
# 过滤掉 None 结果 (resolve_pair 失败或无变化时返回 None)
return [res for res in results if res]

View File

@@ -207,9 +207,8 @@ class Simulator:
# 使用 gather 并行触发奇遇
tasks = [try_trigger_fortune(avatar) for avatar in living_avatars]
results = await asyncio.gather(*tasks)
for res in results:
if res:
events.extend(res)
events.extend([e for res in results if res for e in res])
return events

View File

@@ -1,53 +0,0 @@
"""
通用 AI 任务批处理器。
用于将串行的异步任务收集起来并行执行,优化 LLM 密集型场景的性能。
"""
import asyncio
from typing import Coroutine, Any, List
class AITaskBatch:
    """Batch collector that runs queued AI coroutines concurrently.

    Usage:
        ```python
        async with AITaskBatch() as batch:
            for item in items:
                batch.add(process_item(item))
        # when the with-block exits, all queued tasks have completed
        ```
    """

    def __init__(self) -> None:
        # Coroutines queued for the next run(), in insertion order.
        self.tasks: List[Coroutine[Any, Any, Any]] = []

    def add(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Queue a coroutine without starting it.

        The coroutine should handle its own result (e.g. mutate object
        state) or report it through variables captured by the caller.
        """
        self.tasks.append(coro)

    async def run(self) -> List[Any]:
        """Run every queued task concurrently and wait for all of them.

        Returns:
            Results in the same order the tasks were added.
        """
        if not self.tasks:
            return []
        # Detach the queue BEFORE awaiting: if gather() raises, the old code
        # left half-consumed coroutines queued for an accidental second run.
        pending, self.tasks = self.tasks, []
        results = await asyncio.gather(*pending)
        return list(results)

    async def __aenter__(self) -> "AITaskBatch":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type:
            # The with-block failed: skip execution, but close each queued
            # coroutine explicitly so Python does not emit
            # "coroutine was never awaited" and cleanup code can run.
            for coro in self.tasks:
                coro.close()
            self.tasks = []
            return  # returning None propagates the original exception
        await self.run()

View File

@@ -1,35 +1,35 @@
"""LLM 客户端核心调用逻辑"""
from pathlib import Path
import json
import urllib.request
import urllib.error
import asyncio
from pathlib import Path
from typing import Optional
from src.run.log import log_llm_call
from src.utils.config import CONFIG
from .config import LLMMode, LLMConfig
from .parser import parse_json
from .prompt import build_prompt, load_template
from .exceptions import LLMError, ParseError
from src.run.log import log_llm_call
try:
# 使用动态导入,避免 PyInstaller 静态分析将其作为依赖打包
import importlib
importlib.import_module("litellm")
has_litellm = True
import litellm
HAS_LITELLM = True
except ImportError:
has_litellm = False
HAS_LITELLM = False
def _call_with_litellm(config: LLMConfig, prompt: str) -> str:
    """Synchronously call the model through litellm and return the reply text."""
    # Dynamic import keeps PyInstaller's static analysis from bundling litellm.
    from importlib import import_module

    litellm_mod = import_module("litellm")
    reply = litellm_mod.completion(
        model=config.model_name,
        messages=[{"role": "user", "content": prompt}],
        api_key=config.api_key,
        base_url=config.base_url,
    )
    return reply.choices[0].message.content
# Module-level semaphore bounding concurrent LLM requests; created lazily.
_SEMAPHORE: Optional[asyncio.Semaphore] = None

def _get_semaphore() -> asyncio.Semaphore:
    """Return the shared request semaphore, creating it on first use."""
    global _SEMAPHORE
    if _SEMAPHORE is None:
        # Default to 10 concurrent requests when the config omits the field.
        limit = getattr(CONFIG.ai, "max_concurrent_requests", 10)
        _SEMAPHORE = asyncio.Semaphore(limit)
    return _SEMAPHORE
def _call_with_requests(config: LLMConfig, prompt: str) -> str:
@@ -39,17 +39,14 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
"Authorization": f"Bearer {config.api_key}"
}
# 处理模型名称:去除 'openai/' 前缀(针对 litellm 的兼容性配置)
model_name = config.model_name
if model_name.startswith("openai/"):
model_name = model_name.replace("openai/", "")
# 兼容 litellm 的 openai/ 前缀处理
model_name = config.model_name.replace("openai/", "")
data = {
"model": model_name,
"messages": [{"role": "user", "content": prompt}]
}
# 处理 URL
url = config.base_url
if not url:
raise ValueError("Base URL is required for requests mode")
@@ -57,9 +54,7 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
if "chat/completions" not in url:
url = url.rstrip("/")
if not url.endswith("/v1"):
# 尝试智能追加 v1如果用户没写
# 但有些服务可能不需要 v1这里保守起见如果没 v1 且没 chat/completions直接加 /chat/completions
# 假设用户配置的是类似 https://api.openai.com/v1
# 简单启发式:如果不是显式 v1 结尾,也加上
pass
url = f"{url}/chat/completions"
@@ -75,53 +70,37 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
result = json.loads(response.read().decode('utf-8'))
return result['choices'][0]['message']['content']
except urllib.error.HTTPError as e:
error_content = e.read().decode('utf-8')
raise Exception(f"LLM Request failed {e.code}: {error_content}")
raise Exception(f"LLM Request failed {e.code}: {e.read().decode('utf-8')}")
except Exception as e:
raise Exception(f"LLM Request failed: {str(e)}")
async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str:
"""
基础 LLM 调用,返回原始文本
Args:
prompt: 输入提示词
mode: 调用模式
Returns:
str: LLM 返回的原始文本
基础 LLM 调用,自动控制并发
"""
import asyncio
# 获取配置
config = LLMConfig.from_mode(mode)
semaphore = _get_semaphore()
# 调用逻辑
def _call():
# try:
# return _call_with_litellm(config, prompt)
# except ImportError:
# # 如果没有 litellm降级使用 requests
# return _call_with_requests(config, prompt)
try:
if has_litellm:
return _call_with_litellm(config, prompt)
else:
return _call_with_requests(config, prompt)
except Exception as e:
# litellm 可能抛出其他错误,如果仅仅是导入错误我们降级
# 如果是 litellm 内部错误(如 api key 错误),应该抛出
# 但为了稳健,如果 litellm 失败,是否尝试 request?
# 用户只说了 "没有的话(if no litellm)",通常指安装。
# 所以 catch ImportError 是对的。
raise e
async with semaphore:
if HAS_LITELLM:
try:
# 使用 litellm 原生异步接口
response = await litellm.acompletion(
model=config.model_name,
messages=[{"role": "user", "content": prompt}],
api_key=config.api_key,
base_url=config.base_url,
)
result = response.choices[0].message.content
except Exception as e:
# 再次抛出以便上层处理,或者记录日志
raise Exception(f"LiteLLM call failed: {str(e)}") from e
else:
# 降级到 requests (在线程池中运行)
result = await asyncio.to_thread(_call_with_requests, config, prompt)
result = await asyncio.to_thread(_call)
# 记录日志
log_llm_call(config.model_name, prompt, result)
return result
@@ -130,22 +109,8 @@ async def call_llm_json(
mode: LLMMode = LLMMode.NORMAL,
max_retries: int | None = None
) -> dict:
"""
调用 LLM 并解析为 JSON内置重试机制
Args:
prompt: 输入提示词
mode: 调用模式
max_retries: 最大重试次数None 则从配置读取
Returns:
dict: 解析后的 JSON 对象
Raises:
LLMError: 解析失败且重试次数用尽时抛出
"""
"""调用 LLM 并解析为 JSON带重试"""
if max_retries is None:
from src.utils.config import CONFIG
max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0))
last_error = None
@@ -156,14 +121,9 @@ async def call_llm_json(
except ParseError as e:
last_error = e
if attempt < max_retries:
continue # 继续重试
# 最后一次失败,抛出详细错误
raise LLMError(
f"解析失败(重试 {max_retries} 次后)",
cause=last_error
) from last_error
# 不应该到这里,但为了类型检查
continue
raise LLMError(f"解析失败(重试 {max_retries} 次后)", cause=last_error) from last_error
raise LLMError("未知错误")
@@ -173,37 +133,13 @@ async def call_llm_with_template(
mode: LLMMode = LLMMode.NORMAL,
max_retries: int | None = None
) -> dict:
"""
使用模板调用 LLM最常用的高级接口
Args:
template_path: 模板文件路径
infos: 要填充的信息字典
mode: 调用模式
max_retries: 最大重试次数None 则从配置读取
Returns:
dict: 解析后的 JSON 对象
"""
"""使用模板调用 LLM"""
template = load_template(template_path)
prompt = build_prompt(template, infos)
return await call_llm_json(prompt, mode, max_retries)
async def call_ai_action(
infos: dict,
mode: LLMMode = LLMMode.NORMAL
) -> dict:
"""
AI 行动决策专用接口
Args:
infos: 行动决策所需信息
mode: 调用模式
Returns:
dict: AI 行动决策结果
"""
from src.utils.config import CONFIG
async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict:
"""AI 行动决策专用接口"""
template_path = CONFIG.paths.templates / "ai.txt"
return await call_llm_with_template(template_path, infos, mode)

View File

@@ -7,107 +7,26 @@ from .exceptions import ParseError
def parse_json(text: str) -> dict:
"""
解析入口,依次尝试多种策略
Args:
text: 待解析的文本
Returns:
dict: 解析结果
Raises:
ParseError: 所有策略均失败时抛出
解析 JSON支持从 markdown 代码块提取或直接解析
"""
text = (text or '').strip()
if not text:
return {}
strategies = [
try_parse_code_blocks,
try_parse_balanced_json,
try_parse_whole_text,
]
errors = []
for strategy in strategies:
result = strategy(text)
if result is not None:
return result
errors.append(f"{strategy.__name__}")
# 抛出详细错误
raise ParseError(
f"所有解析策略均失败: {', '.join(errors)}",
raw_text=text[:500] # 只保留前 500 字符避免日志过长
)
def try_parse_code_blocks(text: str) -> dict | None:
"""
尝试从代码块解析 JSON
Args:
text: 待解析的文本
Returns:
dict | None: 解析成功返回字典,失败返回 None
"""
# 策略1: 尝试从 Markdown 代码块提取
# 优先匹配 json/json5 块,如果没有指定语言的块也尝试
blocks = _extract_code_blocks(text)
# 只处理 json/json5 或未标注语言的代码块
for lang, block in blocks:
if lang and lang not in ("json", "json5"):
continue
# 先在块内找平衡对象
span = _find_balanced_json_object(block)
candidates = [span] if span else [block]
for cand in candidates:
if not cand:
continue
for lang, content in blocks:
if not lang or lang in ("json", "json5"):
try:
obj = json5.loads(cand)
obj = json5.loads(content)
if isinstance(obj, dict):
return obj
except Exception:
continue
return None
def try_parse_balanced_json(text: str) -> dict | None:
    """Locate the first balanced ``{...}`` span in *text* and parse it as JSON5.

    Args:
        text: Text that may embed a JSON object.

    Returns:
        dict | None: The parsed object, or None when nothing parseable is found.
    """
    candidate = _find_balanced_json_object(text)
    if not candidate:
        return None
    try:
        parsed = json5.loads(candidate)
    except Exception:
        return None
    return parsed if isinstance(parsed, dict) else None
def try_parse_whole_text(text: str) -> dict | None:
"""
尝试整体解析为 JSON
Args:
text: 待解析的文本
Returns:
dict | None: 解析成功返回字典,失败返回 None
"""
# 策略2: 尝试整体解析
# 有时候 LLM 不会输出 markdown直接输出 json
try:
obj = json5.loads(text)
if isinstance(obj, dict):
@@ -115,71 +34,17 @@ def try_parse_whole_text(text: str) -> dict | None:
except Exception:
pass
return None
# 失败
raise ParseError(
"无法解析 JSON: 未找到有效的 JSON 对象或代码块",
raw_text=text[:500]
)
def _extract_code_blocks(text: str) -> list[tuple[str, str]]:
"""
提取所有 markdown 代码块
Args:
text: 待提取的文本
Returns:
list[tuple[str, str]]: (语言, 内容) 元组列表
"""
"""提取 markdown 代码块"""
pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL)
blocks = []
for lang, content in pattern.findall(text):
blocks.append((lang.strip().lower(), content.strip()))
return blocks
def _find_balanced_json_object(text: str) -> str | None:
"""
在文本中扫描并返回首个平衡的花括号 {...} 片段
忽略字符串中的括号
Args:
text: 待扫描的文本
Returns:
str | None: 找到则返回子串,否则返回 None
"""
depth = 0
start_index = None
in_string = False
string_char = ''
escape = False
for idx, ch in enumerate(text):
if in_string:
if escape:
escape = False
continue
if ch == '\\':
escape = True
continue
if ch == string_char:
in_string = False
continue
if ch in ('"', "'"):
in_string = True
string_char = ch
continue
if ch == '{':
if depth == 0:
start_index = idx
depth += 1
continue
if ch == '}':
if depth > 0:
depth -= 1
if depth == 0 and start_index is not None:
return text[start_index:idx + 1]
return None

View File

@@ -14,7 +14,7 @@ paths:
saves: assets/saves/
ai:
max_decide_num: 3
max_concurrent_requests: 10
max_parse_retries: 3
game:

View File

@@ -1,9 +1,10 @@
import pytest
import json
from unittest.mock import MagicMock, patch, AsyncMock
from pathlib import Path
from src.utils.llm.prompt import build_prompt
from src.utils.llm.parser import parse_json, try_parse_code_blocks, try_parse_balanced_json
from src.utils.llm.client import call_llm_json, LLMMode
from src.utils.llm.parser import parse_json
from src.utils.llm.client import call_llm_json, call_llm, LLMMode
from src.utils.llm.exceptions import ParseError, LLMError
# ================= Prompt Tests =================
@@ -58,21 +59,11 @@ def test_parse_code_block():
result = parse_json(text)
assert result == {"foo": "bar"}
def test_parse_nested_braces():
    """A nested object embedded in surrounding text is still extracted."""
    noisy = 'some text {"a": {"b": 1}} some more text'
    assert parse_json(noisy) == {"a": {"b": 1}}

def test_parse_fail():
    """Input with no JSON at all must raise ParseError."""
    with pytest.raises(ParseError):
        parse_json("Not a json")

def test_extract_from_text_with_noise():
    """A flat object surrounded by conversational noise is extracted."""
    chatty = 'Sure! Here is the JSON you requested: {"a": 1} Hope this helps.'
    assert parse_json(chatty) == {"a": 1}
# ================= Client Mock Tests =================
@pytest.mark.asyncio
async def test_call_llm_json_success():
@@ -107,3 +98,40 @@ async def test_call_llm_json_all_fail():
assert mock_call.call_count == 2 # Initial + 1 retry
@pytest.mark.asyncio
async def test_call_llm_fallback_requests():
    """Without litellm installed, call_llm must fall back to the urllib path."""
    # Canned HTTP body in the OpenAI chat-completions response shape.
    payload = json.dumps({
        "choices": [{"message": {"content": "Response from requests"}}]
    }).encode('utf-8')

    fake_response = MagicMock()
    fake_response.read.return_value = payload
    fake_response.__enter__.return_value = fake_response

    fake_config = MagicMock()
    fake_config.api_key = "test_key"
    fake_config.base_url = "http://test.api/v1"
    fake_config.model_name = "test-model"

    # Force the requests branch and intercept both config loading and urlopen.
    with patch("src.utils.llm.client.HAS_LITELLM", False), \
         patch("src.utils.llm.client.LLMConfig.from_mode", return_value=fake_config), \
         patch("urllib.request.urlopen", return_value=fake_response) as fake_urlopen:
        result = await call_llm("hello", mode=LLMMode.NORMAL)
        assert result == "Response from requests"
        fake_urlopen.assert_called_once()
        (request_obj,), _ = fake_urlopen.call_args
        # client.py turns http://test.api/v1 into .../v1/chat/completions
        assert request_obj.full_url == "http://test.api/v1/chat/completions"