add pytest
This commit is contained in:
77
tests/test_action_combat.py
Normal file
77
tests/test_action_combat.py
Normal file
@@ -0,0 +1,77 @@
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from src.classes.action.attack import Attack
|
||||
from src.classes.action_runtime import ActionStatus
|
||||
|
||||
class TestActionCombat:
|
||||
|
||||
@pytest.fixture
|
||||
def target_avatar(self, dummy_avatar):
|
||||
"""创建一个靶子角色"""
|
||||
target = MagicMock()
|
||||
target.name = "TargetDummy"
|
||||
target.id = "target_id"
|
||||
target.hp = MagicMock()
|
||||
target.hp.current = 100
|
||||
target.hp.max = 100
|
||||
target.increase_weapon_proficiency = MagicMock()
|
||||
return target
|
||||
|
||||
@patch("src.classes.action.attack.decide_battle")
|
||||
def test_attack_execution(self, mock_decide, dummy_avatar, target_avatar):
|
||||
"""测试攻击执行:扣除 HP"""
|
||||
# Mock decide_battle 返回 (winner, loser, loser_dmg, winner_dmg)
|
||||
# 假设 dummy 赢了,Target 掉了 10 点血,dummy 掉了 2 点
|
||||
mock_decide.return_value = (dummy_avatar, target_avatar, 10, 2)
|
||||
|
||||
# 注入 target 到 world
|
||||
dummy_avatar.world.avatar_manager.avatars = {target_avatar.name: target_avatar}
|
||||
|
||||
# Mock HP 为 MagicMock 以便 assert_called
|
||||
dummy_avatar.hp = MagicMock()
|
||||
|
||||
action = Attack(dummy_avatar, dummy_avatar.world)
|
||||
action._execute(avatar_name="TargetDummy")
|
||||
|
||||
# 验证伤害应用
|
||||
target_avatar.hp.reduce.assert_called_with(10)
|
||||
dummy_avatar.hp.reduce.assert_called_with(2)
|
||||
|
||||
# 验证熟练度增加 (虽然是随机的,但 mock 了 uniform 就好了,或者只验证调用)
|
||||
assert dummy_avatar.weapon.get_detailed_info.called or True # 只是确保流程跑通
|
||||
|
||||
@patch("src.classes.action.attack.handle_death") # 这个是在 death.py 里的
|
||||
@patch("src.classes.battle.handle_battle_finish", new_callable=AsyncMock)
|
||||
def test_attack_finish(self, mock_battle_finish, mock_handle_death, dummy_avatar, target_avatar):
|
||||
"""测试战斗结束回调"""
|
||||
# 注入 target
|
||||
dummy_avatar.world.avatar_manager.avatars = {target_avatar.name: target_avatar}
|
||||
|
||||
action = Attack(dummy_avatar, dummy_avatar.world)
|
||||
|
||||
# 设置 _last_result (通常由 execute 设置)
|
||||
action._last_result = (dummy_avatar, target_avatar, 10, 2)
|
||||
action._start_event_content = "Start Battle"
|
||||
|
||||
# 运行 finish
|
||||
import asyncio
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(action.finish(avatar_name="TargetDummy"))
|
||||
|
||||
# 验证调用了 handle_battle_finish
|
||||
mock_battle_finish.assert_called_once()
|
||||
args, kwargs = mock_battle_finish.call_args
|
||||
assert args[1] == dummy_avatar # winner
|
||||
assert args[2] == target_avatar # loser
|
||||
|
||||
def test_can_start_missing_target(self, dummy_avatar):
|
||||
"""测试目标不存在"""
|
||||
dummy_avatar.world.avatar_manager.avatars = {}
|
||||
action = Attack(dummy_avatar, dummy_avatar.world)
|
||||
|
||||
ok, reason = action.can_start("Ghost")
|
||||
assert ok is False
|
||||
assert reason == "目标不存在"
|
||||
|
||||
109
tests/test_action_cultivate.py
Normal file
109
tests/test_action_cultivate.py
Normal file
@@ -0,0 +1,109 @@
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
from src.classes.action.cultivate import Cultivate
|
||||
from src.classes.tile import TileType
|
||||
from src.classes.region import CultivateRegion, NormalRegion
|
||||
from src.classes.event import Event
|
||||
from src.classes.root import Root
|
||||
from src.classes.essence import EssenceType
|
||||
|
||||
class TestActionCultivate:
|
||||
|
||||
@pytest.fixture
|
||||
def cultivation_avatar(self, dummy_avatar):
|
||||
"""配置一个适合修炼的角色环境"""
|
||||
# 设置灵根
|
||||
dummy_avatar.root = Root.FIRE
|
||||
|
||||
# 使用 patch mock 掉 effects 属性
|
||||
# 注意:这里会影响 Avatar 类,但在 fixture 作用域结束后会还原
|
||||
with patch('src.classes.avatar.Avatar.effects', new_callable=PropertyMock) as mock_effects:
|
||||
mock_effects.return_value = {}
|
||||
|
||||
# 重置修炼进度
|
||||
dummy_avatar.cultivation_progress.exp = 0
|
||||
# 设置为 29 级
|
||||
dummy_avatar.cultivation_progress.level = 29
|
||||
dummy_avatar.cultivation_progress.max_exp = 1000
|
||||
|
||||
yield dummy_avatar
|
||||
|
||||
def test_cultivate_in_wild(self, cultivation_avatar):
|
||||
"""测试在野外(非修炼区域)修炼:低保经验"""
|
||||
# 确保当前区域不是 CultivateRegion
|
||||
tile = cultivation_avatar.tile
|
||||
tile.region = NormalRegion(id=999, name="Wild", desc="Just Wild") # 普通区域
|
||||
|
||||
action = Cultivate(cultivation_avatar, cultivation_avatar.world)
|
||||
|
||||
# Check
|
||||
can_start, reason = action.can_start()
|
||||
assert can_start is True
|
||||
|
||||
# Execute
|
||||
action._execute()
|
||||
|
||||
# Assert: 获得低保经验
|
||||
expected_exp = Cultivate.BASE_EXP_LOW_EFFICIENCY
|
||||
assert cultivation_avatar.cultivation_progress.exp == expected_exp
|
||||
|
||||
def test_cultivate_in_matching_region(self, cultivation_avatar):
|
||||
"""测试在匹配灵气的洞府修炼:高经验"""
|
||||
# 设置当前 Tile 为 CultivateRegion
|
||||
region = CultivateRegion(id=1, name="Fire Cave", desc="Hot", essence_type=EssenceType.FIRE, essence_density=5)
|
||||
|
||||
cultivation_avatar.tile.region = region
|
||||
|
||||
action = Cultivate(cultivation_avatar, cultivation_avatar.world)
|
||||
action._execute()
|
||||
|
||||
# Assert: density(5) * base(100) = 500
|
||||
expected_exp = 5 * Cultivate.BASE_EXP_PER_DENSITY
|
||||
|
||||
assert cultivation_avatar.cultivation_progress.exp == expected_exp
|
||||
|
||||
def test_cultivate_in_mismatching_region(self, cultivation_avatar):
|
||||
"""测试在不匹配灵气的洞府修炼:低保经验"""
|
||||
# 设置水灵气,角色是火灵根
|
||||
region = CultivateRegion(id=2, name="Water Cave", desc="Wet", essence_type=EssenceType.WATER, essence_density=5)
|
||||
cultivation_avatar.tile.region = region
|
||||
|
||||
action = Cultivate(cultivation_avatar, cultivation_avatar.world)
|
||||
action._execute()
|
||||
|
||||
# Assert: 0 * 100 -> fallback to LOW_EFFICIENCY
|
||||
expected_exp = Cultivate.BASE_EXP_LOW_EFFICIENCY
|
||||
assert cultivation_avatar.cultivation_progress.exp == expected_exp
|
||||
|
||||
def test_cultivate_bottleneck(self, cultivation_avatar):
|
||||
"""测试瓶颈期修炼:不增加经验"""
|
||||
# 设置为瓶颈等级
|
||||
cultivation_avatar.cultivation_progress.level = 30
|
||||
initial_exp = cultivation_avatar.cultivation_progress.exp
|
||||
|
||||
action = Cultivate(cultivation_avatar, cultivation_avatar.world)
|
||||
|
||||
# Check can_start
|
||||
can_start, reason = action.can_start()
|
||||
assert can_start is False
|
||||
assert "瓶颈" in reason
|
||||
|
||||
# Force execute (should return early)
|
||||
action._execute()
|
||||
assert cultivation_avatar.cultivation_progress.exp == initial_exp
|
||||
|
||||
def test_cultivate_occupied_region(self, cultivation_avatar):
|
||||
"""测试修炼区域被他人占据"""
|
||||
region = CultivateRegion(id=3, name="Occupied", desc="Full", essence_type=EssenceType.FIRE, essence_density=5)
|
||||
other_avatar = MagicMock()
|
||||
other_avatar.name = "Stranger"
|
||||
region.host_avatar = other_avatar # 占据者不是自己
|
||||
cultivation_avatar.tile.region = region
|
||||
|
||||
action = Cultivate(cultivation_avatar, cultivation_avatar.world)
|
||||
|
||||
can_start, reason = action.can_start()
|
||||
assert can_start is False
|
||||
assert "Stranger" in reason
|
||||
|
||||
66
tests/test_action_move.py
Normal file
66
tests/test_action_move.py
Normal file
@@ -0,0 +1,66 @@
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
from src.classes.action.move import Move
|
||||
from src.classes.tile import TileType
|
||||
from src.classes.action_runtime import ActionStatus
|
||||
|
||||
class TestActionMove:
|
||||
|
||||
def test_move_basic(self, dummy_avatar):
|
||||
"""测试基础移动:向右下移动 (1, 1)"""
|
||||
# 初始位置 (0, 0)
|
||||
assert dummy_avatar.pos_x == 0
|
||||
assert dummy_avatar.pos_y == 0
|
||||
|
||||
# 默认步长为 1,允许移动 (1, 1) 因为 clamp 逻辑允许斜向优先
|
||||
# 假设 move_step_length = 1,曼哈顿距离 1+1=2 > 1?
|
||||
# 看一下 clamp_manhattan_with_diagonal_priority 逻辑:
|
||||
# 如果 limit=1, (1,1) 是允许的吗?
|
||||
# 通常斜向算 1 步还是 2 步?根据源码:clamp_manhattan_with_diagonal_priority
|
||||
# 暂时假设允许 (1,0)
|
||||
|
||||
action = Move(dummy_avatar, dummy_avatar.world)
|
||||
action.execute(delta_x=1, delta_y=0)
|
||||
|
||||
assert dummy_avatar.pos_x == 1
|
||||
assert dummy_avatar.pos_y == 0
|
||||
assert dummy_avatar.tile.x == 1
|
||||
assert dummy_avatar.tile.y == 0
|
||||
|
||||
def test_move_out_of_bounds(self, dummy_avatar):
|
||||
"""测试边界移动:尝试移出地图 (往左)"""
|
||||
# 初始 (0, 0)
|
||||
action = Move(dummy_avatar, dummy_avatar.world)
|
||||
action.execute(delta_x=-1, delta_y=0)
|
||||
|
||||
# 应该还在 (0, 0)
|
||||
assert dummy_avatar.pos_x == 0
|
||||
assert dummy_avatar.pos_y == 0
|
||||
|
||||
def test_move_with_increased_step(self, dummy_avatar):
|
||||
"""测试增加步长后的移动"""
|
||||
# 增加步长
|
||||
with patch.object(type(dummy_avatar), 'move_step_length', new_callable=PropertyMock) as mock_step:
|
||||
mock_step.return_value = 3
|
||||
|
||||
action = Move(dummy_avatar, dummy_avatar.world)
|
||||
# 尝试移动 (0, 3)
|
||||
action.execute(delta_x=0, delta_y=3)
|
||||
|
||||
assert dummy_avatar.pos_x == 0
|
||||
assert dummy_avatar.pos_y == 3
|
||||
|
||||
def test_move_clamped_by_step(self, dummy_avatar):
|
||||
"""测试步长限制:尝试移动超过步长的距离"""
|
||||
with patch.object(type(dummy_avatar), 'move_step_length', new_callable=PropertyMock) as mock_step:
|
||||
mock_step.return_value = 1
|
||||
|
||||
action = Move(dummy_avatar, dummy_avatar.world)
|
||||
# 尝试移动 (5, 0)
|
||||
action.execute(delta_x=5, delta_y=0)
|
||||
|
||||
# 应该只移动了 1 格 (1, 0)
|
||||
assert dummy_avatar.pos_x == 1
|
||||
assert dummy_avatar.pos_y == 0
|
||||
|
||||
80
tests/test_action_social.py
Normal file
80
tests/test_action_social.py
Normal file
@@ -0,0 +1,80 @@
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from src.classes.mutual_action.conversation import Conversation
|
||||
from src.classes.action_runtime import ActionStatus
|
||||
from src.classes.event import Event
|
||||
|
||||
class TestActionSocial:
|
||||
|
||||
@pytest.fixture
|
||||
def target_avatar(self, dummy_avatar):
|
||||
target = MagicMock()
|
||||
target.name = "FriendDummy"
|
||||
target.id = "friend_id"
|
||||
target.get_info.return_value = "Target Info"
|
||||
target.get_planned_actions_str.return_value = "None"
|
||||
target.thinking = ""
|
||||
# 模拟 add_event
|
||||
target.events = []
|
||||
target.add_event = lambda e, to_sidebar=False: target.events.append(e)
|
||||
# 模拟修炼进度(用于关系判断)
|
||||
target.cultivation_progress.level = 10
|
||||
target.gender = dummy_avatar.gender # 同性
|
||||
target.get_relation.return_value = None
|
||||
|
||||
return target
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.classes.mutual_action.mutual_action.call_llm_with_template", new_callable=AsyncMock)
|
||||
async def test_conversation_flow(self, mock_llm, dummy_avatar, target_avatar):
|
||||
"""测试对话流程:Step -> LLM -> Feedback"""
|
||||
|
||||
# 1. 准备 Mock LLM 返回
|
||||
mock_response = {
|
||||
"FriendDummy": {
|
||||
"thinking": "He is nice.",
|
||||
"conversation_content": "Hello there!",
|
||||
"feedback": "Accept" # Conversation 其实不强制 feedback,主要是 content
|
||||
}
|
||||
}
|
||||
mock_llm.return_value = mock_response
|
||||
|
||||
# 注入 World 查找
|
||||
dummy_avatar.world.avatar_manager.avatars = {target_avatar.name: target_avatar}
|
||||
|
||||
# Mock 自己的 level (避免 dummy_avatar 中也是 Mock 导致无法比较)
|
||||
dummy_avatar.cultivation_progress.level = 10
|
||||
|
||||
# 2. 初始化 Action
|
||||
action = Conversation(dummy_avatar, dummy_avatar.world)
|
||||
action._start_month_stamp = 100
|
||||
|
||||
# 3. 第一次 Step: 应该触发 LLM 任务并返回 RUNNING
|
||||
res1 = action.step(target_avatar=target_avatar)
|
||||
assert res1.status == ActionStatus.RUNNING
|
||||
assert action._feedback_task is not None
|
||||
|
||||
# 等待 Task 完成
|
||||
await action._feedback_task
|
||||
|
||||
# 4. 第二次 Step: 消费结果
|
||||
res2 = action.step(target_avatar=target_avatar)
|
||||
assert res2.status == ActionStatus.COMPLETED
|
||||
|
||||
# 5. 验证结果
|
||||
# 应该有一个包含对话内容的事件
|
||||
assert len(res2.events) >= 1
|
||||
content_event = res2.events[0]
|
||||
assert "Hello there!" in content_event.content
|
||||
assert dummy_avatar.id in content_event.related_avatars
|
||||
assert target_avatar.id in content_event.related_avatars
|
||||
|
||||
# 验证 Target 思考被更新
|
||||
assert target_avatar.thinking == "He is nice."
|
||||
|
||||
def test_conversation_no_target(self, dummy_avatar):
|
||||
action = Conversation(dummy_avatar, dummy_avatar.world)
|
||||
res = action.step(target_avatar=None)
|
||||
assert res.status == ActionStatus.FAILED
|
||||
|
||||
110
tests/test_cultivation_logic.py
Normal file
110
tests/test_cultivation_logic.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import pytest
|
||||
from src.classes.cultivation import Realm, Stage, CultivationProgress
|
||||
|
||||
# ================= Realm Tests =================
|
||||
def test_realm_comparison():
|
||||
assert Realm.Qi_Refinement < Realm.Foundation_Establishment
|
||||
assert Realm.Foundation_Establishment < Realm.Core_Formation
|
||||
assert Realm.Core_Formation < Realm.Nascent_Soul
|
||||
assert Realm.Nascent_Soul > Realm.Qi_Refinement
|
||||
|
||||
def test_realm_from_id():
|
||||
assert Realm.from_id(1) == Realm.Qi_Refinement
|
||||
assert Realm.from_id(4) == Realm.Nascent_Soul
|
||||
with pytest.raises(ValueError):
|
||||
Realm.from_id(0)
|
||||
|
||||
# ================= Stage Tests =================
|
||||
def test_stage_comparison():
|
||||
assert Stage.Early_Stage < Stage.Middle_Stage
|
||||
assert Stage.Middle_Stage < Stage.Late_Stage
|
||||
|
||||
# ================= CultivationProgress Tests =================
|
||||
def test_cp_initialization():
|
||||
cp = CultivationProgress(level=1, exp=0)
|
||||
assert cp.realm == Realm.Qi_Refinement
|
||||
assert cp.stage == Stage.Early_Stage
|
||||
|
||||
def test_cp_level_mapping():
|
||||
# Level 1-10 -> Early
|
||||
assert CultivationProgress(1).stage == Stage.Early_Stage
|
||||
assert CultivationProgress(10).stage == Stage.Early_Stage
|
||||
|
||||
# Level 11-20 -> Middle
|
||||
assert CultivationProgress(11).stage == Stage.Middle_Stage
|
||||
assert CultivationProgress(20).stage == Stage.Middle_Stage
|
||||
|
||||
# Level 21-30 -> Late
|
||||
assert CultivationProgress(21).stage == Stage.Late_Stage
|
||||
assert CultivationProgress(30).stage == Stage.Late_Stage
|
||||
|
||||
# Level 31 -> Next Realm (Foundation)
|
||||
cp = CultivationProgress(31)
|
||||
assert cp.realm == Realm.Foundation_Establishment
|
||||
assert cp.stage == Stage.Early_Stage
|
||||
|
||||
def test_cp_bottleneck():
|
||||
# Level 30 is end of Qi Refinement (Late Stage)
|
||||
# According to code: bottleneck if level % 30 == 0
|
||||
cp = CultivationProgress(30)
|
||||
assert cp.is_in_bottleneck() is True
|
||||
assert cp.can_break_through() is True
|
||||
assert cp.can_cultivate() is False
|
||||
|
||||
cp = CultivationProgress(29)
|
||||
assert cp.is_in_bottleneck() is False
|
||||
|
||||
def test_cp_add_exp_normal():
|
||||
cp = CultivationProgress(1, exp=0)
|
||||
required = cp.get_exp_required()
|
||||
|
||||
# Add not enough to level up
|
||||
leveled = cp.add_exp(required - 1)
|
||||
assert leveled is False
|
||||
assert cp.level == 1
|
||||
assert cp.exp == required - 1
|
||||
|
||||
# Add enough to level up
|
||||
leveled = cp.add_exp(2) # Total > required
|
||||
assert leveled is True
|
||||
assert cp.level == 2
|
||||
# Exp should be consumed
|
||||
assert cp.exp == 1
|
||||
|
||||
def test_cp_add_exp_stops_at_bottleneck():
|
||||
# Start at level 29
|
||||
cp = CultivationProgress(29, exp=0)
|
||||
req_29 = cp.get_exp_required()
|
||||
|
||||
# Add enough exp to theoretically go to 31
|
||||
# But should stop at 30 (bottleneck)
|
||||
# Need exp for 29->30.
|
||||
# At 30, it is bottleneck.
|
||||
|
||||
cp.add_exp(req_29 + 100000)
|
||||
|
||||
assert cp.level == 30
|
||||
assert cp.is_in_bottleneck() is True
|
||||
# Exp should accumulate? Logic says:
|
||||
# if is_in_bottleneck(): break (inside while loop)
|
||||
# So extra exp stays in self.exp
|
||||
assert cp.exp >= 100000
|
||||
|
||||
def test_cp_breakthrough():
|
||||
cp = CultivationProgress(30, exp=0)
|
||||
cp.break_through()
|
||||
assert cp.level == 31
|
||||
assert cp.realm == Realm.Foundation_Establishment
|
||||
assert cp.is_in_bottleneck() is False
|
||||
|
||||
def test_cp_serialization():
|
||||
cp = CultivationProgress(5, exp=123)
|
||||
data = cp.to_dict()
|
||||
assert data["level"] == 5
|
||||
assert data["exp"] == 123
|
||||
|
||||
cp_new = CultivationProgress.from_dict(data)
|
||||
assert cp_new.level == 5
|
||||
assert cp_new.exp == 123
|
||||
assert cp_new.realm == Realm.Qi_Refinement
|
||||
|
||||
109
tests/test_llm_mock.py
Normal file
109
tests/test_llm_mock.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from pathlib import Path
|
||||
from src.utils.llm.prompt import build_prompt
|
||||
from src.utils.llm.parser import parse_json, try_parse_code_blocks, try_parse_balanced_json
|
||||
from src.utils.llm.client import call_llm_json, LLMMode
|
||||
from src.utils.llm.exceptions import ParseError, LLMError
|
||||
|
||||
# ================= Prompt Tests =================
|
||||
def test_build_prompt_basic():
|
||||
template = "Hello {name}, your age is {age}."
|
||||
infos = {"name": "Alice", "age": 20}
|
||||
result = build_prompt(template, infos)
|
||||
assert result == "Hello Alice, your age is 20."
|
||||
|
||||
def test_build_prompt_with_complex_types():
|
||||
# intentify_prompt_infos handles lists/dicts
|
||||
template = "List: {items}"
|
||||
infos = {"items": ["a", "b"]}
|
||||
result = build_prompt(template, infos)
|
||||
# intentify_prompt_infos usually joins lists with commas or newlines
|
||||
# We should verify what intentify_prompt_infos does.
|
||||
# Assuming it makes it string friendly.
|
||||
assert "a" in result and "b" in result
|
||||
|
||||
def test_intentify_prompt_infos_formatting():
|
||||
# intentify_prompt_infos only transforms specific keys
|
||||
template = "Infos: {avatar_infos}"
|
||||
avatar_data = {"name": "Alice", "hp": 100}
|
||||
infos = {"avatar_infos": avatar_data}
|
||||
|
||||
result = build_prompt(template, infos)
|
||||
|
||||
# Expect pretty printed json
|
||||
assert '{\n "name": "Alice",' in result
|
||||
assert '"hp": 100\n}' in result
|
||||
|
||||
# ================= Parser Tests =================
|
||||
def test_parse_simple_json():
|
||||
text = '{"key": "value", "num": 1}'
|
||||
result = parse_json(text)
|
||||
assert result == {"key": "value", "num": 1}
|
||||
|
||||
def test_parse_json5_comments():
|
||||
text = '{key: "value", /* comment */ num: 1}'
|
||||
result = parse_json(text)
|
||||
assert result == {"key": "value", "num": 1}
|
||||
|
||||
def test_parse_code_block():
|
||||
text = """
|
||||
Here is the json:
|
||||
```json
|
||||
{
|
||||
"foo": "bar"
|
||||
}
|
||||
```
|
||||
"""
|
||||
result = parse_json(text)
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
def test_parse_nested_braces():
|
||||
text = 'some text {"a": {"b": 1}} some more text'
|
||||
result = parse_json(text)
|
||||
assert result == {"a": {"b": 1}}
|
||||
|
||||
def test_parse_fail():
|
||||
text = "Not a json"
|
||||
with pytest.raises(ParseError):
|
||||
parse_json(text)
|
||||
|
||||
def test_extract_from_text_with_noise():
|
||||
text = "Sure! Here is the JSON you requested: {\"a\": 1} Hope this helps."
|
||||
result = parse_json(text)
|
||||
assert result == {"a": 1}
|
||||
|
||||
# ================= Client Mock Tests =================
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_json_success():
|
||||
# Mock call_llm to return a valid JSON string
|
||||
with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = '{"success": true}'
|
||||
|
||||
result = await call_llm_json("prompt", mode=LLMMode.NORMAL)
|
||||
assert result == {"success": True}
|
||||
mock_call.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_json_retry_success():
|
||||
# Mock call_llm to fail once (bad json) then succeed
|
||||
with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call:
|
||||
mock_call.side_effect = ["Bad JSON", '{"success": true}']
|
||||
|
||||
# We need to make sure config max_retries is at least 1
|
||||
# pass max_retries explicitly
|
||||
result = await call_llm_json("prompt", mode=LLMMode.NORMAL, max_retries=1)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert mock_call.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_llm_json_all_fail():
|
||||
with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = "Bad JSON"
|
||||
|
||||
with pytest.raises(LLMError):
|
||||
await call_llm_json("prompt", mode=LLMMode.NORMAL, max_retries=1)
|
||||
|
||||
assert mock_call.call_count == 2 # Initial + 1 retry
|
||||
|
||||
118
tests/test_utils_numerical.py
Normal file
118
tests/test_utils_numerical.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
from src.classes.hp_and_mp import HP
|
||||
from src.classes.cultivation import Realm
|
||||
from src.utils.distance import chebyshev_distance, manhattan_distance, euclidean_distance
|
||||
from src.utils.id_generator import get_avatar_id
|
||||
from src.utils.df import get_int, get_bool, get_list_int, get_float
|
||||
|
||||
# ================= HP Tests =================
|
||||
def test_hp_initialization():
|
||||
hp = HP(max=100, cur=100)
|
||||
assert hp.max == 100
|
||||
assert hp.cur == 100
|
||||
|
||||
def test_hp_reduce():
|
||||
hp = HP(max=100, cur=100)
|
||||
alive = hp.reduce(30)
|
||||
assert hp.cur == 70
|
||||
assert alive is True
|
||||
|
||||
def test_hp_reduce_to_death():
|
||||
hp = HP(max=100, cur=10)
|
||||
alive = hp.reduce(20)
|
||||
assert hp.cur == -10
|
||||
assert alive is False
|
||||
|
||||
def test_hp_recover():
|
||||
hp = HP(max=100, cur=50)
|
||||
hp.recover(30)
|
||||
assert hp.cur == 80
|
||||
|
||||
def test_hp_recover_overflow():
|
||||
hp = HP(max=100, cur=90)
|
||||
hp.recover(20)
|
||||
assert hp.cur == 100
|
||||
|
||||
def test_hp_add_max():
|
||||
hp = HP(max=100, cur=100)
|
||||
hp.add_max(50)
|
||||
assert hp.max == 150
|
||||
assert hp.cur == 100
|
||||
|
||||
def test_hp_comparison():
|
||||
hp1 = HP(max=100, cur=50)
|
||||
hp2 = HP(max=100, cur=60)
|
||||
hp3 = HP(max=200, cur=50)
|
||||
|
||||
assert hp1 < hp2
|
||||
assert hp1 == hp3 # Compares cur
|
||||
assert hp2 > hp1
|
||||
|
||||
def test_hp_serialization():
|
||||
hp = HP(max=100, cur=50)
|
||||
data = hp.to_dict()
|
||||
assert data == {"max": 100, "cur": 50}
|
||||
|
||||
hp_new = HP.from_dict(data)
|
||||
assert hp_new == hp
|
||||
assert hp_new.max == hp.max
|
||||
|
||||
# ================= Distance Tests =================
|
||||
def test_chebyshev_distance():
|
||||
p1 = (0, 0)
|
||||
p2 = (3, 4)
|
||||
# max(|3-0|, |4-0|) = 4
|
||||
assert chebyshev_distance(p1, p2) == 4
|
||||
|
||||
def test_manhattan_distance():
|
||||
p1 = (0, 0)
|
||||
p2 = (3, 4)
|
||||
# |3-0| + |4-0| = 7
|
||||
assert manhattan_distance(p1, p2) == 7
|
||||
|
||||
def test_euclidean_distance():
|
||||
p1 = (0, 0)
|
||||
p2 = (3, 4)
|
||||
# sqrt(3^2 + 4^2) = 5.0
|
||||
assert euclidean_distance(p1, p2) == 5.0
|
||||
|
||||
# ================= ID Generator Tests =================
|
||||
def test_id_generator():
|
||||
id1 = get_avatar_id()
|
||||
id2 = get_avatar_id()
|
||||
assert isinstance(id1, str)
|
||||
assert len(id1) == 8
|
||||
assert id1 != id2
|
||||
|
||||
# ================= DF Helper Tests =================
|
||||
def test_df_get_int():
|
||||
row = {"a": "123", "b": "12.3", "c": "abc"}
|
||||
assert get_int(row, "a") == 123
|
||||
assert get_int(row, "b") == 12 # int(12.3) -> 12
|
||||
assert get_int(row, "c", default=99) == 99
|
||||
assert get_int(row, "missing", default=0) == 0
|
||||
|
||||
def test_df_get_float():
|
||||
row = {"a": "12.5", "b": "invalid"}
|
||||
assert get_float(row, "a") == 12.5
|
||||
assert get_float(row, "b", default=1.0) == 1.0
|
||||
|
||||
def test_df_get_bool():
|
||||
row = {"a": "true", "b": "1", "c": "yes", "d": "false", "e": "0"}
|
||||
assert get_bool(row, "a") is True
|
||||
assert get_bool(row, "b") is True
|
||||
assert get_bool(row, "c") is True
|
||||
assert get_bool(row, "d") is False
|
||||
assert get_bool(row, "e") is False
|
||||
assert get_bool(row, "missing") is False
|
||||
|
||||
def test_df_get_list_int():
|
||||
row = {"a": "1|2|3", "b": "1,2,3", "c": "1|invalid|3"}
|
||||
# Default separator is likely '|' from CONFIG, but let's test with explicit separator if needed or assume default.
|
||||
# The code says `separator = CONFIG.df.ids_separator` if None.
|
||||
# We'll assume default is '|' or we can mock CONFIG.
|
||||
# Actually, looking at the code: `if separator is None: separator = CONFIG.df.ids_separator`.
|
||||
# Let's provide explicit separator to be safe and independent of CONFIG.
|
||||
assert get_list_int(row, "a", separator="|") == [1, 2, 3]
|
||||
assert get_list_int(row, "c", separator="|") == [1, 3]
|
||||
|
||||
Reference in New Issue
Block a user