refactor llm config

This commit is contained in:
bridge
2025-12-30 22:27:27 +08:00
parent d55ada7d66
commit f14ea0b92e
5 changed files with 20 additions and 53 deletions

View File

@@ -185,9 +185,12 @@ You can also join the QQ group for discussion: 1071821688. Verification answer i
Edit `static/config.yml`:
```yaml
llm:
key: "your-api-key-here" # your API key
key: "your-api-key-here" # your API key
base_url: "https://api.xxx.com" # API base URL
model_name: "model-name" # main model name
fast_model_name: "fast-model" # fast model name
```
For supported models, refer to [litellm documentation](https://docs.litellm.ai/docs/providers)
Supports all API providers compatible with OpenAI interface format (e.g., Qwen, DeepSeek, SiliconFlow, OpenRouter, etc.)
4. Run:
Need to start both backend and frontend.

View File

@@ -189,9 +189,12 @@
在 `static/config.yml` 中配置LLM参数
```yaml
llm:
key: "your-api-key-here" # 你的API密钥
key: "your-api-key-here" # 你的API密钥,如"sk-xxx"
base_url: "https://api.xxx.com" # API地址,如"https://dashscope.aliyuncs.com/compatible-mode/v1"
model_name: "model-name" # 智能模型名称,如"qwen-plus"
fast_model_name: "fast-model" # 快速模型名称,如"qwen-fast"
```
具体支持的模型请参考 [litellm文档](https://docs.litellm.ai/docs/providers)
支持所有兼容 OpenAI 接口格式的 API 提供商(如通义千问、DeepSeek、硅基流动、OpenRouter 等)
4. 运行:
需要同时启动后端和前端。

View File

@@ -3,4 +3,4 @@ omegaconf>=2.3.0
json5>=0.9.0
fastapi>=0.100.0
uvicorn>=0.20.0
websockets>=11.0
websockets>=11.0

View File

@@ -14,12 +14,6 @@ from .parser import parse_json
from .prompt import build_prompt, load_template
from .exceptions import LLMError, ParseError
try:
import litellm
HAS_LITELLM = True
except ImportError:
HAS_LITELLM = False
# 模块级信号量,懒加载
_SEMAPHORE: Optional[asyncio.Semaphore] = None
@@ -38,14 +32,7 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
"Content-Type": "application/json",
"Authorization": f"Bearer {config.api_key}"
}
# 兼容 litellm 的 openai/ 前缀处理,以及其他常见前缀清理
model_name = config.model_name
for prefix in ["openai/", "azure/", "bedrock/"]:
if model_name.startswith(prefix):
model_name = model_name[len(prefix):]
break
data = {
"model": model_name,
"messages": [{"role": "user", "content": prompt}]
@@ -82,28 +69,13 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str:
"""
基础 LLM 调用,自动控制并发
使用 urllib 直接调用 OpenAI 兼容接口
"""
config = LLMConfig.from_mode(mode)
semaphore = _get_semaphore()
async with semaphore:
if HAS_LITELLM:
try:
# 使用 litellm 原生异步接口
response = await litellm.acompletion(
model=config.model_name,
messages=[{"role": "user", "content": prompt}],
api_key=config.api_key,
base_url=config.base_url,
)
result = response.choices[0].message.content
except Exception as e:
# 再次抛出以便上层处理,或者记录日志
raise Exception(f"LiteLLM call failed: {str(e)}") from e
else:
# 降级到 requests (在线程池中运行),实现 OpenAI 兼容接口
# 这样即使没有 litellm只要模型服务提供商支持 OpenAI 格式(如 Qwen, DeepSeek, LocalAI 等)均可工作
result = await asyncio.to_thread(_call_with_requests, config, prompt)
result = await asyncio.to_thread(_call_with_requests, config, prompt)
log_llm_call(config.model_name, prompt, result)
return result
@@ -188,17 +160,7 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
if config is None:
config = LLMConfig.from_mode(mode)
if HAS_LITELLM:
# 使用 litellm 同步接口
litellm.completion(
model=config.model_name,
messages=[{"role": "user", "content": "你好"}],
api_key=config.api_key,
base_url=config.base_url,
)
else:
# 直接调用 requests 实现
_call_with_requests(config, "test")
_call_with_requests(config, "test")
return True, ""
except Exception as e:
error_msg = str(e)

View File

@@ -99,12 +99,12 @@ async def test_call_llm_json_all_fail():
assert mock_call.call_count == 2 # Initial + 1 retry
@pytest.mark.asyncio
async def test_call_llm_fallback_requests():
"""测试没有 litellm 时降级到 requests"""
async def test_call_llm_with_urllib():
"""测试使用 urllib 调用 OpenAI 兼容接口"""
# 模拟 HTTP 响应内容
mock_response_content = json.dumps({
"choices": [{"message": {"content": "Response from requests"}}]
"choices": [{"message": {"content": "Response from API"}}]
}).encode('utf-8')
# Mock response object
@@ -119,13 +119,12 @@ async def test_call_llm_fallback_requests():
mock_config.model_name = "test-model"
# Patch 多个对象
with patch("src.utils.llm.client.HAS_LITELLM", False), \
patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config), \
with patch("src.utils.llm.client.LLMConfig.from_mode", return_value=mock_config), \
patch("urllib.request.urlopen", return_value=mock_response) as mock_urlopen:
result = await call_llm("hello", mode=LLMMode.NORMAL)
assert result == "Response from requests"
assert result == "Response from API"
# 验证 urlopen 被调用
mock_urlopen.assert_called_once()
@@ -134,4 +133,4 @@ async def test_call_llm_fallback_requests():
args, _ = mock_urlopen.call_args
request_obj = args[0]
# client.py 逻辑会把 http://test.api/v1 变成 http://test.api/v1/chat/completions
assert request_obj.full_url == "http://test.api/v1/chat/completions"
assert request_obj.full_url == "http://test.api/v1/chat/completions"