diff --git a/src/classes/ai.py b/src/classes/ai.py index 61ab9ce..8a0e1c3 100644 --- a/src/classes/ai.py +++ b/src/classes/ai.py @@ -11,7 +11,6 @@ from src.classes.world import World from src.classes.event import Event, NULL_EVENT from src.utils.llm import call_ai_action from src.classes.typings import ACTION_NAME_PARAMS_PAIRS -from src.utils.config import CONFIG from src.classes.actions import ACTION_INFOS_STR if TYPE_CHECKING: @@ -20,8 +19,6 @@ if TYPE_CHECKING: class AI(ABC): """ 抽象AI:统一采用批量接口。 - 原先的 GroupAI(多个角色的AI)语义被保留并上移到此基类。 - 子类需实现 _decide(world, avatars) 返回每个 Avatar 的 (action_name, action_params, thinking)。 """ @abstractmethod @@ -31,24 +28,14 @@ class AI(ABC): async def decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str, Event]]: """ 决定做什么,同时生成对应的事件。 - 一个 AI 支持批量生成多个 avatar 的动作。 - 这对 LLM AI 节省时间和 token 非常有意义。 + 由于底层 LLM 调用已接入全局任务池,此处直接并发执行所有任务即可。 """ - results = {} - max_decide_num = CONFIG.ai.max_decide_num - - # 使用 asyncio.gather 并行执行多个批次的决策 - tasks = [] - for i in range(0, len(avatars_to_decide), max_decide_num): - tasks.append(self._decide(world, avatars_to_decide[i:i+max_decide_num])) - - if tasks: - batch_results_list = await asyncio.gather(*tasks) - for batch_result in batch_results_list: - results.update(batch_result) + # 调用具体的决策逻辑 + results = await self._decide(world, avatars_to_decide) - for avatar, result in list(results.items()): - action_name_params_pairs, avatar_thinking, short_term_objective = result # type: ignore + # 补全 Event 字段 + for avatar in list(results.keys()): + action_name_params_pairs, avatar_thinking, short_term_objective = results[avatar] # type: ignore # 不在决策阶段生成开始事件,提交阶段统一触发 results[avatar] = (action_name_params_pairs, avatar_thinking, short_term_objective, NULL_EVENT) @@ -57,18 +44,14 @@ class AI(ABC): class LLMAI(AI): """ LLM AI - 一些思考: - AI动作应该分两类: - 1. 长期动作,比如要持续很长一段时间的行为 - 2. 突发应对动作,比如突然有人要攻击NPC,这个时候的反应 """ async def _decide(self, world: World, avatars_to_decide: list[Avatar]) -> dict[Avatar, tuple[ACTION_NAME_PARAMS_PAIRS, str, str]]: """ 异步决策逻辑:通过LLM决定执行什么动作和参数 - 改动:支持每个角色仅获取其已知区域的世界信息,并发调用 LLM。 """ general_action_infos = ACTION_INFOS_STR + async def decide_one(avatar: Avatar): # 获取基于该角色已知区域的世界信息(包含距离计算) world_info = world.get_info(avatar=avatar, detailed=True) @@ -86,6 +69,7 @@ class LLMAI(AI): res = await call_ai_action(info) return avatar, res + # 直接并发所有任务 tasks = [decide_one(avatar) for avatar in avatars_to_decide] results_list = await asyncio.gather(*tasks) @@ -96,20 +80,20 @@ class LLMAI(AI): r = res[avatar.name] # 仅接受 action_name_params_pairs,不再支持单个 action_name/action_params - raw_pairs = r["action_name_params_pairs"] + raw_pairs = r.get("action_name_params_pairs", []) pairs: ACTION_NAME_PARAMS_PAIRS = [] + for p in raw_pairs: if isinstance(p, list) and len(p) == 2: pairs.append((p[0], p[1])) elif isinstance(p, dict) and "action_name" in p and "action_params" in p: pairs.append((p["action_name"], p["action_params"])) else: - # 跳过无法解析的项 continue # 至少有一个 if not pairs: - raise ValueError(f"LLM未返回有效的action_name_params_pairs: {r}") + continue # Skip if no valid actions found avatar_thinking = r.get("avatar_thinking", r.get("thinking", "")) short_term_objective = r.get("short_term_objective", "") @@ -117,4 +101,4 @@ class LLMAI(AI): return results -llm_ai = LLMAI() \ No newline at end of file +llm_ai = LLMAI() diff --git a/src/classes/avatar/__init__.py b/src/classes/avatar/__init__.py index 9392bde..3a021d4 100644 --- a/src/classes/avatar/__init__.py +++ b/src/classes/avatar/__init__.py @@ -7,7 +7,6 @@ from src.classes.avatar.core import ( Avatar, Gender, gender_strs, - MAX_HISTORY_EVENTS, ) from src.classes.avatar.info_presenter import ( @@ -23,7 +22,6 @@ __all__ = [ "Avatar", "Gender", "gender_strs", - "MAX_HISTORY_EVENTS", # 信息展示函数 "get_avatar_info", "get_avatar_structured_info", diff --git a/src/classes/avatar/action_mixin.py b/src/classes/avatar/action_mixin.py index 5e3e51a..96321e2 100644 --- a/src/classes/avatar/action_mixin.py +++ b/src/classes/avatar/action_mixin.py @@ -106,11 +106,6 @@ class ActionMixin: return start_event return None - def peek_next_plan(self: "Avatar") -> Optional[ActionPlan]: - if not self.planned_actions: - return None - return self.planned_actions[0] - async def tick_action(self: "Avatar") -> List[Event]: """ 推进当前动作一步;返回过程中由动作内部产生的事件(通过 add_event 收集)。 diff --git a/src/classes/avatar/core.py b/src/classes/avatar/core.py index 01e2eef..ecf5799 100644 --- a/src/classes/avatar/core.py +++ b/src/classes/avatar/core.py @@ -60,9 +60,6 @@ gender_strs = { Gender.FEMALE: "女", } -# 历史事件的最大数量 -MAX_HISTORY_EVENTS = 10 - @dataclass class Avatar( @@ -90,7 +87,6 @@ class Avatar( root: Root = field(default_factory=lambda: random.choice(list(Root))) personas: List[Persona] = field(default_factory=list) technique: Technique | None = None - history_events: List[Event] = field(default_factory=list) _pending_events: List[Event] = field(default_factory=list) current_action: Optional[ActionInstance] = None planned_actions: List[ActionPlan] = field(default_factory=list) @@ -206,25 +202,6 @@ class Avatar( # ========== 区域与位置 ========== - def is_in_region(self, region: Region | None) -> bool: - current_region = self.tile.region - if current_region is None: - tile = self.world.map.get_tile(self.pos_x, self.pos_y) - current_region = tile.region - return current_region == region - - def get_co_region_avatars(self, avatars: List["Avatar"]) -> List["Avatar"]: - """返回与自己处于同一区域的角色列表(不含自己)。""" - if self.tile is None: - return [] - same_region: list[Avatar] = [] - for other in avatars: - if other is self or other.tile is None: - continue - if other.tile.region == self.tile.region: - same_region.append(other) - return same_region - def _init_known_regions(self): """初始化已知区域:当前位置 + 宗门驻地""" if self.tile and self.tile.region: diff --git a/src/classes/avatar/inventory_mixin.py b/src/classes/avatar/inventory_mixin.py index ef1abd8..03a79dc 100644 --- a/src/classes/avatar/inventory_mixin.py +++ b/src/classes/avatar/inventory_mixin.py @@ -59,19 +59,6 @@ class InventoryMixin: return True - def has_item(self: "Avatar", item: "Item", quantity: int = 1) -> bool: - """ - 检查是否拥有足够数量的物品 - - Args: - item: 要检查的物品 - quantity: 需要的数量,默认为1 - - Returns: - bool: 是否拥有足够数量的物品 - """ - return item in self.items and self.items[item] >= quantity - def get_item_quantity(self: "Avatar", item: "Item") -> int: """ 获取指定物品的数量 diff --git a/src/classes/relation_resolver.py b/src/classes/relation_resolver.py index 2f77637..9ceafb4 100644 --- a/src/classes/relation_resolver.py +++ b/src/classes/relation_resolver.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Tuple, Optional +import asyncio from src.classes.relation import ( Relation, @@ -18,8 +19,6 @@ from src.utils.config import CONFIG if TYPE_CHECKING: from src.classes.avatar import Avatar -from src.utils.ai_batch import AITaskBatch - class RelationResolver: TEMPLATE_PATH = CONFIG.paths.templates / "relation_update.txt" @@ -137,25 +136,10 @@ class RelationResolver: """ if not pairs: return [] - - events = [] - - # 使用 asyncio.gather 而不是 AITaskBatch.gather,因为 AITaskBatch 没有 gather 方法 - import asyncio - tasks = [] - for a, b in pairs: - # 创建协程任务但不立即 await - tasks.append(RelationResolver.resolve_pair(a, b)) - - if not tasks: - return [] - + # 并发执行所有任务 + tasks = [RelationResolver.resolve_pair(a, b) for a, b in pairs] results = await asyncio.gather(*tasks) - # 收集结果 - for res in results: - if res and isinstance(res, Event): - events.append(res) - - return events + # 过滤掉 None 结果 (resolve_pair 失败或无变化时返回 None) + return [res for res in results if res] diff --git a/src/sim/simulator.py b/src/sim/simulator.py index d68bb97..3bbd417 100644 --- a/src/sim/simulator.py +++ b/src/sim/simulator.py @@ -207,9 +207,8 @@ class Simulator: # 使用 gather 并行触发奇遇 tasks = [try_trigger_fortune(avatar) for avatar in living_avatars] results = await asyncio.gather(*tasks) - for res in results: - if res: - events.extend(res) + + events.extend([e for res in results if res for e in res]) return events diff --git a/src/utils/ai_batch.py b/src/utils/ai_batch.py deleted file mode 100644 index e8b72f1..0000000 --- a/src/utils/ai_batch.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -通用 AI 任务批处理器。 -用于将串行的异步任务收集起来并行执行,优化 LLM 密集型场景的性能。 -""" -import asyncio -from typing import Coroutine, Any, List - -class AITaskBatch: - """ - AI 任务批处理器。 - - 使用示例: - ```python - async with AITaskBatch() as batch: - for item in items: - batch.add(process_item(item)) - # with 块结束时,所有任务已并发执行完毕 - ``` - """ - def __init__(self): - self.tasks: List[Coroutine[Any, Any, Any]] = [] - - def add(self, coro: Coroutine[Any, Any, Any]) -> None: - """ - 添加一个协程任务到池中(不立即执行)。 - 注意:传入的协程应该自行处理结果(如修改对象状态),或者通过外部变量收集结果。 - """ - self.tasks.append(coro) - - async def run(self) -> List[Any]: - """ - 并行执行池中所有任务,并等待全部完成。 - 返回所有任务的结果列表(顺序与添加顺序一致)。 - """ - if not self.tasks: - return [] - - # 使用 gather 并发执行 - results = await asyncio.gather(*self.tasks) - - # 清空任务队列 - self.tasks = [] - return list(results) - - async def __aenter__(self) -> "AITaskBatch": - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - # 如果 with 块内部发生异常,不执行任务,直接抛出 - if exc_type: - return - await self.run() - diff --git a/src/utils/llm/client.py b/src/utils/llm/client.py index 95ba180..54b4c0f 100644 --- a/src/utils/llm/client.py +++ b/src/utils/llm/client.py @@ -1,35 +1,35 @@ """LLM 客户端核心调用逻辑""" -from pathlib import Path import json import urllib.request import urllib.error +import asyncio +from pathlib import Path +from typing import Optional +from src.run.log import log_llm_call +from src.utils.config import CONFIG from .config import LLMMode, LLMConfig from .parser import parse_json from .prompt import build_prompt, load_template from .exceptions import LLMError, ParseError -from src.run.log import log_llm_call try: - # 使用动态导入,避免 PyInstaller 静态分析将其作为依赖打包 - import importlib - importlib.import_module("litellm") - has_litellm = True + import litellm + HAS_LITELLM = True except ImportError: - has_litellm = False + HAS_LITELLM = False -def _call_with_litellm(config: LLMConfig, prompt: str) -> str: - """使用 litellm 调用""" - import importlib - litellm = importlib.import_module("litellm") - response = litellm.completion( - model=config.model_name, - messages=[{"role": "user", "content": prompt}], - api_key=config.api_key, - base_url=config.base_url, - ) - return response.choices[0].message.content +# 模块级信号量,懒加载 +_SEMAPHORE: Optional[asyncio.Semaphore] = None + + +def _get_semaphore() -> asyncio.Semaphore: + global _SEMAPHORE + if _SEMAPHORE is None: + limit = getattr(CONFIG.ai, "max_concurrent_requests", 10) + _SEMAPHORE = asyncio.Semaphore(limit) + return _SEMAPHORE def _call_with_requests(config: LLMConfig, prompt: str) -> str: @@ -39,17 +39,14 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str: "Authorization": f"Bearer {config.api_key}" } - # 处理模型名称:去除 'openai/' 前缀(针对 litellm 的兼容性配置) - model_name = config.model_name - if model_name.startswith("openai/"): - model_name = model_name.replace("openai/", "") + # 兼容 litellm 的 openai/ 前缀处理 + model_name = config.model_name.replace("openai/", "") data = { "model": model_name, "messages": [{"role": "user", "content": prompt}] } - # 处理 URL url = config.base_url if not url: raise ValueError("Base URL is required for requests mode") @@ -57,9 +54,7 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str: if "chat/completions" not in url: url = url.rstrip("/") if not url.endswith("/v1"): - # 尝试智能追加 v1,如果用户没写 - # 但有些服务可能不需要 v1,这里保守起见,如果没 v1 且没 chat/completions,直接加 /chat/completions - # 假设用户配置的是类似 https://api.openai.com/v1 + # 简单启发式:如果不是显式 v1 结尾,也加上 pass url = f"{url}/chat/completions" @@ -75,53 +70,37 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str: result = json.loads(response.read().decode('utf-8')) return result['choices'][0]['message']['content'] except urllib.error.HTTPError as e: - error_content = e.read().decode('utf-8') - raise Exception(f"LLM Request failed {e.code}: {error_content}") + raise Exception(f"LLM Request failed {e.code}: {e.read().decode('utf-8')}") except Exception as e: raise Exception(f"LLM Request failed: {str(e)}") async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str: """ - 最基础的 LLM 调用,返回原始文本 - - Args: - prompt: 输入提示词 - mode: 调用模式 - - Returns: - str: LLM 返回的原始文本 + 基础 LLM 调用,自动控制并发 """ - import asyncio - - # 获取配置 config = LLMConfig.from_mode(mode) + semaphore = _get_semaphore() - # 调用逻辑 - def _call(): - # try: - # return _call_with_litellm(config, prompt) - # except ImportError: - # # 如果没有 litellm,降级使用 requests - # return _call_with_requests(config, prompt) - try: - if has_litellm: - return _call_with_litellm(config, prompt) - else: - return _call_with_requests(config, prompt) - except Exception as e: - # litellm 可能抛出其他错误,如果仅仅是导入错误我们降级 - # 如果是 litellm 内部错误(如 api key 错误),应该抛出 - # 但为了稳健,如果 litellm 失败,是否尝试 request? - # 用户只说了 "没有的话(if no litellm)",通常指安装。 - # 所以 catch ImportError 是对的。 - raise e + async with semaphore: + if HAS_LITELLM: + try: + # 使用 litellm 原生异步接口 + response = await litellm.acompletion( + model=config.model_name, + messages=[{"role": "user", "content": prompt}], + api_key=config.api_key, + base_url=config.base_url, + ) + result = response.choices[0].message.content + except Exception as e: + # 再次抛出以便上层处理,或者记录日志 + raise Exception(f"LiteLLM call failed: {str(e)}") from e + else: + # 降级到 requests (在线程池中运行) + result = await asyncio.to_thread(_call_with_requests, config, prompt) - result = await asyncio.to_thread(_call) - - # 记录日志 log_llm_call(config.model_name, prompt, result) - return result @@ -130,22 +109,8 @@ async def call_llm_json( mode: LLMMode = LLMMode.NORMAL, max_retries: int | None = None ) -> dict: - """ - 调用 LLM 并解析为 JSON,内置重试机制 - - Args: - prompt: 输入提示词 - mode: 调用模式 - max_retries: 最大重试次数,None 则从配置读取 - - Returns: - dict: 解析后的 JSON 对象 - - Raises: - LLMError: 解析失败且重试次数用尽时抛出 - """ + """调用 LLM 并解析为 JSON,带重试""" if max_retries is None: - from src.utils.config import CONFIG max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0)) last_error = None @@ -156,14 +121,9 @@ async def call_llm_json( except ParseError as e: last_error = e if attempt < max_retries: - continue # 继续重试 - # 最后一次失败,抛出详细错误 - raise LLMError( - f"解析失败(重试 {max_retries} 次后)", - cause=last_error - ) from last_error - - # 不应该到这里,但为了类型检查 + continue + raise LLMError(f"解析失败(重试 {max_retries} 次后)", cause=last_error) from last_error + raise LLMError("未知错误") @@ -173,37 +133,13 @@ async def call_llm_with_template( mode: LLMMode = LLMMode.NORMAL, max_retries: int | None = None ) -> dict: - """ - 使用模板调用 LLM(最常用的高级接口) - - Args: - template_path: 模板文件路径 - infos: 要填充的信息字典 - mode: 调用模式 - max_retries: 最大重试次数,None 则从配置读取 - - Returns: - dict: 解析后的 JSON 对象 - """ + """使用模板调用 LLM""" template = load_template(template_path) prompt = build_prompt(template, infos) return await call_llm_json(prompt, mode, max_retries) -async def call_ai_action( - infos: dict, - mode: LLMMode = LLMMode.NORMAL -) -> dict: - """ - AI 行动决策专用接口 - - Args: - infos: 行动决策所需信息 - mode: 调用模式 - - Returns: - dict: AI 行动决策结果 - """ - from src.utils.config import CONFIG +async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict: + """AI 行动决策专用接口""" template_path = CONFIG.paths.templates / "ai.txt" return await call_llm_with_template(template_path, infos, mode) diff --git a/src/utils/llm/parser.py b/src/utils/llm/parser.py index 56542bb..2e7c4db 100644 --- a/src/utils/llm/parser.py +++ b/src/utils/llm/parser.py @@ -7,107 +7,26 @@ from .exceptions import ParseError def parse_json(text: str) -> dict: """ - 主解析入口,依次尝试多种策略 - - Args: - text: 待解析的文本 - - Returns: - dict: 解析结果 - - Raises: - ParseError: 所有策略均失败时抛出 + 解析 JSON,支持从 markdown 代码块提取或直接解析 """ text = (text or '').strip() if not text: return {} - strategies = [ - try_parse_code_blocks, - try_parse_balanced_json, - try_parse_whole_text, - ] - - errors = [] - for strategy in strategies: - result = strategy(text) - if result is not None: - return result - errors.append(f"{strategy.__name__}") - - # 抛出详细错误 - raise ParseError( - f"所有解析策略均失败: {', '.join(errors)}", - raw_text=text[:500] # 只保留前 500 字符避免日志过长 - ) - - -def try_parse_code_blocks(text: str) -> dict | None: - """ - 尝试从代码块解析 JSON - - Args: - text: 待解析的文本 - - Returns: - dict | None: 解析成功返回字典,失败返回 None - """ + # 策略1: 尝试从 Markdown 代码块提取 + # 优先匹配 json/json5 块,如果没有指定语言的块也尝试 blocks = _extract_code_blocks(text) - - # 只处理 json/json5 或未标注语言的代码块 - for lang, block in blocks: - if lang and lang not in ("json", "json5"): - continue - - # 先在块内找平衡对象 - span = _find_balanced_json_object(block) - candidates = [span] if span else [block] - - for cand in candidates: - if not cand: - continue + for lang, content in blocks: + if not lang or lang in ("json", "json5"): try: - obj = json5.loads(cand) + obj = json5.loads(content) if isinstance(obj, dict): return obj except Exception: continue - return None - - -def try_parse_balanced_json(text: str) -> dict | None: - """ - 尝试提取并解析平衡的 JSON 对象 - - Args: - text: 待解析的文本 - - Returns: - dict | None: 解析成功返回字典,失败返回 None - """ - json_span = _find_balanced_json_object(text) - if json_span: - try: - obj = json5.loads(json_span) - if isinstance(obj, dict): - return obj - except Exception: - pass - - return None - - -def try_parse_whole_text(text: str) -> dict | None: - """ - 尝试整体解析为 JSON - - Args: - text: 待解析的文本 - - Returns: - dict | None: 解析成功返回字典,失败返回 None - """ + # 策略2: 尝试整体解析 + # 有时候 LLM 不会输出 markdown,直接输出 json try: obj = json5.loads(text) if isinstance(obj, dict): @@ -115,71 +34,17 @@ def try_parse_whole_text(text: str) -> dict | None: except Exception: pass - return None + # 失败 + raise ParseError( + "无法解析 JSON: 未找到有效的 JSON 对象或代码块", + raw_text=text[:500] + ) def _extract_code_blocks(text: str) -> list[tuple[str, str]]: - """ - 提取所有 markdown 代码块 - - Args: - text: 待提取的文本 - - Returns: - list[tuple[str, str]]: (语言, 内容) 元组列表 - """ + """提取 markdown 代码块""" pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL) blocks = [] for lang, content in pattern.findall(text): blocks.append((lang.strip().lower(), content.strip())) return blocks - - -def _find_balanced_json_object(text: str) -> str | None: - """ - 在文本中扫描并返回首个平衡的花括号 {...} 片段 - 忽略字符串中的括号 - - Args: - text: 待扫描的文本 - - Returns: - str | None: 找到则返回子串,否则返回 None - """ - depth = 0 - start_index = None - in_string = False - string_char = '' - escape = False - - for idx, ch in enumerate(text): - if in_string: - if escape: - escape = False - continue - if ch == '\\': - escape = True - continue - if ch == string_char: - in_string = False - continue - - if ch in ('"', "'"): - in_string = True - string_char = ch - continue - - if ch == '{': - if depth == 0: - start_index = idx - depth += 1 - continue - - if ch == '}': - if depth > 0: - depth -= 1 - if depth == 0 and start_index is not None: - return text[start_index:idx + 1] - - return None - diff --git a/static/config.yml b/static/config.yml index 16dc675..7cec36b 100644 --- a/static/config.yml +++ b/static/config.yml @@ -14,7 +14,7 @@ paths: saves: assets/saves/ ai: - max_decide_num: 3 + max_concurrent_requests: 10 max_parse_retries: 3 game: diff --git a/tests/test_llm_mock.py b/tests/test_llm_mock.py index bc3d0f6..6f81f6e 100644 --- a/tests/test_llm_mock.py +++ b/tests/test_llm_mock.py @@ -1,9 +1,10 @@ import pytest +import json from unittest.mock import MagicMock, patch, AsyncMock from pathlib import Path from src.utils.llm.prompt import build_prompt -from src.utils.llm.parser import parse_json, try_parse_code_blocks, try_parse_balanced_json -from src.utils.llm.client import call_llm_json, LLMMode +from src.utils.llm.parser import parse_json +from src.utils.llm.client import call_llm_json, call_llm, LLMMode from src.utils.llm.exceptions import ParseError, LLMError # ================= Prompt Tests ================= @@ -58,21 +59,11 @@ def test_parse_code_block(): result = parse_json(text) assert result == {"foo": "bar"} -def test_parse_nested_braces(): - text = 'some text {"a": {"b": 1}} some more text' - result = parse_json(text) - assert result == {"a": {"b": 1}} - def test_parse_fail(): text = "Not a json" with pytest.raises(ParseError): parse_json(text) -def test_extract_from_text_with_noise(): - text = "Sure! Here is the JSON you requested: {\"a\": 1} Hope this helps." - result = parse_json(text) - assert result == {"a": 1} - # ================= Client Mock Tests ================= @pytest.mark.asyncio async def test_call_llm_json_success(): @@ -107,3 +98,40 @@ async def test_call_llm_json_all_fail(): assert mock_call.call_count == 2 # Initial + 1 retry +@pytest.mark.asyncio +async def test_call_llm_fallback_requests(): + """测试没有 litellm 时降级到 requests""" + + # 模拟 HTTP 响应内容 + mock_response_content = json.dumps({ + "choices": [{"message": {"content": "Response from requests"}}] + }).encode('utf-8') + + # Mock response object + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + # Mock Config + mock_config = MagicMock() + mock_config.api_key = "test_key" + mock_config.base_url = "http://test.api/v1" + mock_config.model_name = "test-model" + + # Patch 多个对象 + with patch("src.utils.llm.client.HAS_LITELLM", False), \ + patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config), \ + patch("urllib.request.urlopen", return_value=mock_response) as mock_urlopen: + + result = await call_llm("hello", mode=LLMMode.NORMAL) + + assert result == "Response from requests" + + # 验证 urlopen 被调用 + mock_urlopen.assert_called_once() + + # 验证请求参数 + args, _ = mock_urlopen.call_args + request_obj = args[0] + # client.py 逻辑会把 http://test.api/v1 变成 http://test.api/v1/chat/completions + assert request_obj.full_url == "http://test.api/v1/chat/completions"