Files
cultivation-world-simulator/tests/test_llm_mock.py
2025-12-20 22:13:26 +08:00

138 lines
4.8 KiB
Python

import pytest
import json
from unittest.mock import MagicMock, patch, AsyncMock
from pathlib import Path
from src.utils.llm.prompt import build_prompt
from src.utils.llm.parser import parse_json
from src.utils.llm.client import call_llm_json, call_llm, LLMMode
from src.utils.llm.exceptions import ParseError, LLMError
# ================= Prompt Tests =================
def test_build_prompt_basic():
    """Each ``{placeholder}`` in the template is replaced by its info value."""
    rendered = build_prompt(
        "Hello {name}, your age is {age}.",
        {"name": "Alice", "age": 20},
    )
    assert rendered == "Hello Alice, your age is 20."
def test_build_prompt_with_complex_types():
    """A list-valued info is rendered into the prompt without losing elements.

    The exact list formatting (separator, brackets) is left to build_prompt;
    this test only checks that every element shows up in the output.
    """
    rendered = build_prompt("List: {items}", {"items": ["a", "b"]})
    assert all(element in rendered for element in ("a", "b"))
def test_intentify_prompt_infos_formatting():
    """The ``avatar_infos`` value is rendered as pretty-printed JSON.

    Only specific info keys are transformed before substitution;
    ``avatar_infos`` is one of them, so the dict is expected to appear as
    multi-line JSON in the prompt.
    """
    avatar = {"name": "Alice", "hp": 100}
    rendered = build_prompt("Infos: {avatar_infos}", {"avatar_infos": avatar})
    # NOTE(review): the spaces after each "\n" must match the JSON indent width
    # used by build_prompt; confirm against src/utils/llm/prompt.py.
    for fragment in ('{\n "name": "Alice",', '"hp": 100\n}'):
        assert fragment in rendered
# ================= Parser Tests =================
def test_parse_simple_json():
    """Strict JSON text parses into the equivalent dict."""
    expected = {"key": "value", "num": 1}
    assert parse_json('{"key": "value", "num": 1}') == expected
def test_parse_json5_comments():
    """JSON5-style input (unquoted keys, block comments) is accepted."""
    parsed = parse_json('{key: "value", /* comment */ num: 1}')
    assert parsed == {"key": "value", "num": 1}
def test_parse_code_block():
    """JSON wrapped in a fenced ```json code block inside prose is extracted."""
    # Built line by line; equivalent to the multi-line literal
    # "\nHere is the json:\n```json\n{\n\"foo\": \"bar\"\n}\n```\n".
    reply = "\n".join([
        "",
        "Here is the json:",
        "```json",
        "{",
        '"foo": "bar"',
        "}",
        "```",
        "",
    ])
    assert parse_json(reply) == {"foo": "bar"}
def test_parse_fail():
    """Text containing no JSON at all raises ParseError."""
    with pytest.raises(ParseError):
        parse_json("Not a json")
# ================= Client Mock Tests =================
@pytest.mark.asyncio
async def test_call_llm_json_success():
    """A well-formed JSON reply is parsed and returned after a single call."""
    with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as fake_call:
        # The patched coroutine resolves straight to a valid JSON document.
        fake_call.return_value = '{"success": true}'

        payload = await call_llm_json("prompt", mode=LLMMode.NORMAL)

        assert payload == {"success": True}
        assert fake_call.call_count == 1
@pytest.mark.asyncio
async def test_call_llm_json_retry_success():
    """An unparseable first reply triggers one retry whose valid JSON is returned."""
    replies = ["Bad JSON", '{"success": true}']
    with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as fake_call:
        fake_call.side_effect = replies

        # max_retries is passed explicitly so the test does not depend on the
        # configured default being at least 1.
        payload = await call_llm_json("prompt", mode=LLMMode.NORMAL, max_retries=1)

        assert payload == {"success": True}
        # Both scripted replies were consumed: the bad one, then the good one.
        assert fake_call.call_count == len(replies)
@pytest.mark.asyncio
async def test_call_llm_json_all_fail():
    """LLMError is raised once every attempt has returned unparseable text."""
    retries = 1
    with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as fake_call:
        fake_call.return_value = "Bad JSON"

        with pytest.raises(LLMError):
            await call_llm_json("prompt", mode=LLMMode.NORMAL, max_retries=retries)

        # One initial attempt plus `retries` retries.
        assert fake_call.call_count == 1 + retries
@pytest.mark.asyncio
async def test_call_llm_fallback_requests():
    """Without litellm, call_llm falls back to a plain HTTP request.

    The fallback is exercised by patching ``urllib.request.urlopen``. The test
    checks both the returned message content and the URL of the request sent.
    """
    # Body of a minimal chat-completions style HTTP response.
    completion = {"choices": [{"message": {"content": "Response from requests"}}]}

    fake_response = MagicMock()
    fake_response.read.return_value = json.dumps(completion).encode('utf-8')
    # Entering the mock as a context manager yields the same mock, in case the
    # client uses `with urlopen(...) as resp:`.
    fake_response.__enter__.return_value = fake_response

    fake_config = MagicMock(
        api_key="test_key",
        base_url="http://test.api/v1",
        model_name="test-model",
    )

    with patch("src.utils.llm.client.HAS_LITELLM", False), \
         patch("src.utils.llm.client.LLMConfig.from_mode", return_value=fake_config), \
         patch("urllib.request.urlopen", return_value=fake_response) as fake_urlopen:
        reply = await call_llm("hello", mode=LLMMode.NORMAL)

        assert reply == "Response from requests"

        # Exactly one HTTP request went out...
        fake_urlopen.assert_called_once()
        # ...and its first positional argument is the request object, whose URL
        # client.py builds by appending "/chat/completions" to the base URL.
        sent_request = fake_urlopen.call_args.args[0]
        assert sent_request.full_url == "http://test.api/v1/chat/completions"