From 666e19215d70419d8008bccdb1400936d16b5881 Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Tue, 20 Jan 2026 23:27:04 -0800 Subject: [PATCH] test: add comprehensive LLM failure scenario tests (#79) * test: add comprehensive LLM failure scenario tests Add 29 tests covering: - HTTP errors: 401, 403, 404, 500, timeout, connection refused - Parse errors: invalid JSON, empty response, array instead of object - Retry logic: retry on parse failure, max retries exceeded - Connectivity test: friendly error messages for common failures - Configuration validation: missing base URL, URL normalization - Async call_llm: success and error propagation - Exception classes: LLMError, ParseError Closes #71 * refactor: add comment explaining unreachable code for type checker --- src/utils/llm/client.py | 5 +- tests/test_llm_failures.py | 693 +++++++++++++++++++++++++++++++++++++ 2 files changed, 696 insertions(+), 2 deletions(-) create mode 100644 tests/test_llm_failures.py diff --git a/src/utils/llm/client.py b/src/utils/llm/client.py index 831c135..ef3399b 100644 --- a/src/utils/llm/client.py +++ b/src/utils/llm/client.py @@ -90,7 +90,7 @@ async def call_llm_json( if max_retries is None: max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0)) - last_error = None + last_error: ParseError | None = None for attempt in range(max_retries + 1): response = await call_llm(prompt, mode) try: @@ -100,7 +100,8 @@ async def call_llm_json( if attempt < max_retries: continue raise LLMError(f"解析失败(重试 {max_retries} 次后)", cause=last_error) from last_error - + + # This should never be reached, but satisfies type checker. raise LLMError("未知错误") diff --git a/tests/test_llm_failures.py b/tests/test_llm_failures.py new file mode 100644 index 0000000..bcc9fa3 --- /dev/null +++ b/tests/test_llm_failures.py @@ -0,0 +1,693 @@ +""" +Tests for LLM failure scenarios and error handling. 
+
+## What's Tested
+
+- HTTP error handling (401, 403, 404, 500, timeout)
+- Parse error handling (invalid JSON, empty response)
+- Retry logic (retry on parse failure, max retries exceeded)
+- Connectivity test (llm_test_connectivity returns friendly error messages)
+- Configuration validation (missing base URL, URL normalization)
+
+## Why These Tests Matter
+
+LLM calls are critical to the game's AI decision making. When LLM fails:
+1. The game should handle errors gracefully, not crash.
+2. Users should see friendly error messages, not raw exceptions.
+3. Retry logic should work correctly for transient failures.
+"""
+
+import pytest
+import json
+import urllib.error
+from unittest.mock import patch, MagicMock, AsyncMock
+from io import BytesIO
+
+from src.utils.llm.client import (
+    call_llm,
+    call_llm_json,
+    call_llm_with_template,
+    call_llm_with_task_name,
+    test_connectivity as llm_test_connectivity,
+    _call_with_requests,
+    LLMMode,
+)
+from src.utils.llm.config import LLMConfig
+from src.utils.llm.parser import parse_json
+from src.utils.llm.exceptions import LLMError, ParseError
+
+
+def make_http_error(url: str, code: int, msg: str, body: bytes) -> urllib.error.HTTPError:
+    """Create an HTTPError for testing. The hdrs param type is incorrectly typed in stubs."""
+    return urllib.error.HTTPError(
+        url=url,
+        code=code,
+        msg=msg,
+        hdrs=None,  # type: ignore[arg-type]
+        fp=BytesIO(body)
+    )
+
+
+class TestHTTPErrors:
+    """Tests for HTTP error handling in LLM client."""
+
+    def test_401_unauthorized(self):
+        """Test handling of 401 Unauthorized (invalid API key)."""
+        config = LLMConfig(
+            model_name="test-model",
+            api_key="invalid-key",
+            base_url="http://test.api/v1"
+        )
+
+        # Create a mock HTTPError for 401.
+        
+ http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=401, + msg="Unauthorized", + body=b'{"error": {"message": "Invalid API key"}}' + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + with pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "401" in str(exc_info.value) + assert "Invalid API key" in str(exc_info.value) + + def test_403_forbidden(self): + """Test handling of 403 Forbidden (access denied).""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=403, + msg="Forbidden", + body=b'{"error": {"message": "Access denied"}}' + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + with pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "403" in str(exc_info.value) + + def test_404_not_found(self): + """Test handling of 404 Not Found (wrong URL).""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/wrong-path" + ) + + http_error = make_http_error( + url="http://test.api/wrong-path/chat/completions", + code=404, + msg="Not Found", + body=b'{"error": {"message": "Not found"}}' + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + with pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "404" in str(exc_info.value) + + def test_500_server_error(self): + """Test handling of 500 Internal Server Error.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=500, + msg="Internal Server Error", + body=b'{"error": {"message": "Internal server error"}}' + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + with 
pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "500" in str(exc_info.value) + + def test_timeout_error(self): + """Test handling of connection timeout.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + timeout_error = TimeoutError("Connection timed out") + + with patch("urllib.request.urlopen", side_effect=timeout_error): + with pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "timed out" in str(exc_info.value).lower() + + def test_connection_refused(self): + """Test handling of connection refused error.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://localhost:9999" + ) + + connection_error = ConnectionRefusedError("Connection refused") + + with patch("urllib.request.urlopen", side_effect=connection_error): + with pytest.raises(Exception) as exc_info: + _call_with_requests(config, "test prompt") + + assert "Connection refused" in str(exc_info.value) + + +class TestParseErrors: + """Tests for JSON parse error handling.""" + + def test_invalid_json_response(self): + """Test handling of invalid JSON in LLM response.""" + text = "This is not valid JSON at all" + + with pytest.raises(ParseError) as exc_info: + parse_json(text) + + assert "无法解析 JSON" in str(exc_info.value) + + def test_empty_response(self): + """Test handling of empty response.""" + result = parse_json("") + assert result == {} + + result = parse_json(" ") + assert result == {} + + def test_json_array_instead_of_object(self): + """Test handling of JSON array when object expected.""" + text = '[1, 2, 3]' + + with pytest.raises(ParseError): + parse_json(text) + + def test_partial_json(self): + """Test handling of incomplete/truncated JSON.""" + text = '{"key": "value", "incomplete' + + with pytest.raises(ParseError): + parse_json(text) + + def test_json_with_markdown_but_invalid_content(self): + """Test 
handling of markdown code block with invalid JSON.""" + text = """ + Here is the response: + ```json + {not valid json} + ``` + """ + + with pytest.raises(ParseError): + parse_json(text) + + +class TestRetryLogic: + """Tests for LLM retry logic.""" + + @pytest.mark.asyncio + async def test_retry_on_parse_failure_then_success(self): + """Test that retry works when first response is unparseable.""" + with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call: + # First call returns invalid JSON, second returns valid JSON. + mock_call.side_effect = [ + "Invalid JSON response", + '{"success": true}' + ] + + result = await call_llm_json("prompt", max_retries=1) + + assert result == {"success": True} + assert mock_call.call_count == 2 + + @pytest.mark.asyncio + async def test_max_retries_exceeded(self): + """Test that LLMError is raised after max retries.""" + with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call: + # All calls return invalid JSON. + mock_call.return_value = "Always invalid" + + with pytest.raises(LLMError) as exc_info: + await call_llm_json("prompt", max_retries=2) + + # Should have tried 3 times (initial + 2 retries). + assert mock_call.call_count == 3 + assert "重试" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_no_retry_when_max_retries_zero(self): + """Test that no retry happens when max_retries=0.""" + with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call: + mock_call.return_value = "Invalid JSON" + + with pytest.raises(LLMError): + await call_llm_json("prompt", max_retries=0) + + # Should only try once. 
+ assert mock_call.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_preserves_mode(self): + """Test that retry uses the same LLM mode.""" + with patch("src.utils.llm.client.call_llm", new_callable=AsyncMock) as mock_call: + mock_call.side_effect = [ + "Invalid", + '{"ok": true}' + ] + + await call_llm_json("prompt", mode=LLMMode.FAST, max_retries=1) + + # Both calls should use FAST mode. + for call in mock_call.call_args_list: + assert call.kwargs.get("mode") == LLMMode.FAST or call.args[1] == LLMMode.FAST + + +class TestConnectivityTest: + """Tests for llm_test_connectivity function.""" + + def test_connectivity_success(self): + """Test successful connectivity check.""" + mock_response_content = json.dumps({ + "choices": [{"message": {"content": "OK"}}] + }).encode('utf-8') + + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + mock_config = LLMConfig( + model_name="test-model", + api_key="valid-key", + base_url="http://test.api/v1" + ) + + with patch("urllib.request.urlopen", return_value=mock_response): + success, error = llm_test_connectivity(config=mock_config) + + assert success is True + assert error == "" + + def test_connectivity_invalid_api_key(self): + """Test connectivity check with invalid API key.""" + http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=401, + msg="Unauthorized", + body=b'{"error": {"message": "Incorrect API key"}}' + ) + + mock_config = LLMConfig( + model_name="test-model", + api_key="invalid-key", + base_url="http://test.api/v1" + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + assert "API Key 无效" in error + + def test_connectivity_forbidden(self): + """Test connectivity check with 403 Forbidden.""" + http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=403, 
+ msg="Forbidden", + body=b'{"error": {"message": "Forbidden"}}' + ) + + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + assert "访问被拒绝" in error + + def test_connectivity_not_found(self): + """Test connectivity check with 404 Not Found.""" + http_error = make_http_error( + url="http://test.api/wrong/chat/completions", + code=404, + msg="Not Found", + body=b'{"error": {"message": "Not found"}}' + ) + + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/wrong" + ) + + with patch("urllib.request.urlopen", side_effect=http_error): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + assert "服务地址不存在" in error + + def test_connectivity_timeout(self): + """Test connectivity check with timeout.""" + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + with patch("urllib.request.urlopen", side_effect=TimeoutError("timeout")): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + assert "超时" in error + + def test_connectivity_connection_error(self): + """Test connectivity check with connection error.""" + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://localhost:9999" + ) + + with patch("urllib.request.urlopen", side_effect=ConnectionError("Connection refused")): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + assert "无法连接" in error + + def test_connectivity_with_mode_instead_of_config(self): + """Test connectivity using mode parameter (config=None path).""" + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + 
mock_response_content = json.dumps({ + "choices": [{"message": {"content": "OK"}}] + }).encode('utf-8') + + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + with patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config) as mock_from_mode, \ + patch("urllib.request.urlopen", return_value=mock_response): + # Pass mode, not config - this exercises line 161. + success, error = llm_test_connectivity(mode=LLMMode.FAST) + + assert success is True + mock_from_mode.assert_called_once_with(LLMMode.FAST) + + def test_connectivity_unknown_error(self): + """Test connectivity with unknown error returns raw message.""" + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + # An error that doesn't match any known pattern. + unknown_error = Exception("Some weird error xyz123") + + with patch("urllib.request.urlopen", side_effect=unknown_error): + success, error = llm_test_connectivity(config=mock_config) + + assert success is False + # Should return the raw error message. 
+ assert "Some weird error xyz123" in error + + +class TestConfigurationValidation: + """Tests for configuration validation.""" + + def test_missing_base_url(self): + """Test error when base URL is missing.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="" + ) + + with pytest.raises(ValueError) as exc_info: + _call_with_requests(config, "test prompt") + + assert "Base URL is required" in str(exc_info.value) + + def test_url_normalization_adds_chat_completions(self): + """Test that URL is normalized to include chat/completions.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + mock_response_content = json.dumps({ + "choices": [{"message": {"content": "OK"}}] + }).encode('utf-8') + + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + with patch("urllib.request.urlopen", return_value=mock_response) as mock_urlopen: + _call_with_requests(config, "test") + + # Verify URL was normalized. + args, _ = mock_urlopen.call_args + request_obj = args[0] + assert request_obj.full_url == "http://test.api/v1/chat/completions" + + def test_url_normalization_preserves_existing_path(self): + """Test that URL already containing chat/completions is not modified.""" + config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1/chat/completions" + ) + + mock_response_content = json.dumps({ + "choices": [{"message": {"content": "OK"}}] + }).encode('utf-8') + + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + with patch("urllib.request.urlopen", return_value=mock_response) as mock_urlopen: + _call_with_requests(config, "test") + + args, _ = mock_urlopen.call_args + request_obj = args[0] + # Should not double the path. 
+ assert request_obj.full_url == "http://test.api/v1/chat/completions" + + +class TestAsyncCallLLM: + """Tests for async call_llm function.""" + + @pytest.mark.asyncio + async def test_call_llm_success(self): + """Test successful async LLM call.""" + mock_response_content = json.dumps({ + "choices": [{"message": {"content": "Hello from LLM"}}] + }).encode('utf-8') + + mock_response = MagicMock() + mock_response.read.return_value = mock_response_content + mock_response.__enter__.return_value = mock_response + + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + with patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config), \ + patch("urllib.request.urlopen", return_value=mock_response): + result = await call_llm("test prompt") + + assert result == "Hello from LLM" + + @pytest.mark.asyncio + async def test_call_llm_propagates_error(self): + """Test that errors from _call_with_requests propagate.""" + mock_config = LLMConfig( + model_name="test-model", + api_key="test-key", + base_url="http://test.api/v1" + ) + + http_error = make_http_error( + url="http://test.api/v1/chat/completions", + code=500, + msg="Internal Server Error", + body=b'{"error": "Server error"}' + ) + + with patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config), \ + patch("urllib.request.urlopen", side_effect=http_error): + with pytest.raises(Exception) as exc_info: + await call_llm("test prompt") + + assert "500" in str(exc_info.value) + + +class TestLLMErrorException: + """Tests for LLMError exception class.""" + + def test_llm_error_with_cause(self): + """Test LLMError preserves cause exception.""" + cause = ParseError("Parse failed", raw_text="bad json") + error = LLMError("LLM call failed", cause=cause) + + assert error.cause is cause + assert "LLM call failed" in str(error) + + def test_llm_error_with_context(self): + """Test LLMError stores context.""" + error = LLMError("Failed", 
prompt="test", retries=3) + + assert error.context["prompt"] == "test" + assert error.context["retries"] == 3 + + def test_parse_error_stores_raw_text(self): + """Test ParseError stores raw text.""" + error = ParseError("Invalid JSON", raw_text="not json") + + assert error.raw_text == "not json" + + +class TestCallLLMWithTemplate: + """Tests for call_llm_with_template function.""" + + @pytest.mark.asyncio + async def test_call_with_template_success(self): + """Test successful call with template.""" + with patch("src.utils.llm.client.load_template", return_value="Hello {name}!"), \ + patch("src.utils.llm.client.call_llm_json", new_callable=AsyncMock) as mock_call: + mock_call.return_value = {"greeting": "Hello World!"} + + result = await call_llm_with_template( + template_path="test.txt", + infos={"name": "World"} + ) + + assert result == {"greeting": "Hello World!"} + # Verify prompt was built correctly. + mock_call.assert_called_once() + call_args = mock_call.call_args + assert "Hello World!" 
in call_args[0][0] + + @pytest.mark.asyncio + async def test_call_with_template_passes_mode(self): + """Test that mode is passed through.""" + with patch("src.utils.llm.client.load_template", return_value="Test"), \ + patch("src.utils.llm.client.call_llm_json", new_callable=AsyncMock) as mock_call: + mock_call.return_value = {} + + await call_llm_with_template( + template_path="test.txt", + infos={}, + mode=LLMMode.FAST + ) + + call_args = mock_call.call_args + assert call_args[0][1] == LLMMode.FAST + + @pytest.mark.asyncio + async def test_call_with_template_passes_max_retries(self): + """Test that max_retries is passed through.""" + with patch("src.utils.llm.client.load_template", return_value="Test"), \ + patch("src.utils.llm.client.call_llm_json", new_callable=AsyncMock) as mock_call: + mock_call.return_value = {} + + await call_llm_with_template( + template_path="test.txt", + infos={}, + max_retries=5 + ) + + call_args = mock_call.call_args + assert call_args[0][2] == 5 + + +class TestCallLLMWithTaskName: + """Tests for call_llm_with_task_name function.""" + + @pytest.mark.asyncio + async def test_call_with_task_name_uses_task_mode(self): + """Test that task mode is determined from task name.""" + with patch("src.utils.llm.client.get_task_mode", return_value=LLMMode.FAST) as mock_get_mode, \ + patch("src.utils.llm.client.call_llm_with_template", new_callable=AsyncMock) as mock_call, \ + patch("src.utils.llm.client.CONFIG") as mock_config: + mock_config.llm.mode = "default" + mock_call.return_value = {} + + await call_llm_with_task_name( + task_name="test_task", + template_path="test.txt", + infos={} + ) + + mock_get_mode.assert_called_once_with("test_task") + # call_llm_with_template(template_path, infos, mode, max_retries) + call_args = mock_call.call_args[0] + assert call_args[2] == LLMMode.FAST + + @pytest.mark.asyncio + async def test_call_with_task_name_global_mode_override(self): + """Test that global mode overrides task mode.""" + with 
patch("src.utils.llm.client.get_task_mode", return_value=LLMMode.FAST), \ + patch("src.utils.llm.client.call_llm_with_template", new_callable=AsyncMock) as mock_call, \ + patch("src.utils.llm.client.CONFIG") as mock_config: + # Global mode is "normal", should override task mode "fast". + mock_config.llm.mode = "normal" + mock_call.return_value = {} + + await call_llm_with_task_name( + task_name="test_task", + template_path="test.txt", + infos={} + ) + + call_args = mock_call.call_args + # Should use NORMAL due to global override. + assert call_args[0][2] == LLMMode.NORMAL + + @pytest.mark.asyncio + async def test_call_with_task_name_passes_max_retries(self): + """Test that max_retries is passed through.""" + with patch("src.utils.llm.client.get_task_mode", return_value=LLMMode.NORMAL), \ + patch("src.utils.llm.client.call_llm_with_template", new_callable=AsyncMock) as mock_call, \ + patch("src.utils.llm.client.CONFIG") as mock_config: + mock_config.llm.mode = "default" + mock_call.return_value = {} + + await call_llm_with_task_name( + task_name="test_task", + template_path="test.txt", + infos={}, + max_retries=3 + ) + + call_args = mock_call.call_args + assert call_args[0][3] == 3