refactor llm
This commit is contained in:
@@ -1,53 +0,0 @@
|
||||
"""
|
||||
通用 AI 任务批处理器。
|
||||
用于将串行的异步任务收集起来并行执行,优化 LLM 密集型场景的性能。
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Coroutine, Any, List
|
||||
|
||||
class AITaskBatch:
    """
    Batch runner for AI coroutine tasks.

    Collects coroutines that would otherwise run serially and executes
    them concurrently, improving throughput in LLM-heavy workloads.

    Example:
        ```python
        async with AITaskBatch() as batch:
            for item in items:
                batch.add(process_item(item))
        # all tasks have finished concurrently when the with-block exits
        ```
    """

    def __init__(self):
        # Pending coroutines; emptied each time run() executes them.
        self.tasks: List[Coroutine[Any, Any, Any]] = []

    def add(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Queue a coroutine for later execution (it is not started here).

        The coroutine should handle its own result (e.g. by mutating
        object state) or report it through an external collector.
        """
        self.tasks.append(coro)

    async def run(self) -> List[Any]:
        """Run all queued tasks concurrently and wait for completion.

        Returns:
            Results of all tasks, in the order they were added.
        """
        if not self.tasks:
            return []

        # Snapshot and clear the queue *before* awaiting, so the batch
        # is left in a clean, reusable state even if gather() raises.
        pending, self.tasks = self.tasks, []
        results = await asyncio.gather(*pending)
        return list(results)

    async def __aenter__(self) -> "AITaskBatch":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type:
            # An exception occurred inside the with-block: do not run the
            # tasks, but close the queued coroutines so Python does not
            # emit "coroutine was never awaited" warnings, then re-raise.
            for coro in self.tasks:
                coro.close()
            self.tasks = []
            return
        await self.run()
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
"""LLM 客户端核心调用逻辑"""
|
||||
|
||||
from pathlib import Path
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.run.log import log_llm_call
|
||||
from src.utils.config import CONFIG
|
||||
from .config import LLMMode, LLMConfig
|
||||
from .parser import parse_json
|
||||
from .prompt import build_prompt, load_template
|
||||
from .exceptions import LLMError, ParseError
|
||||
from src.run.log import log_llm_call
|
||||
|
||||
try:
    # Imported dynamically so PyInstaller's static analysis does not
    # pick up litellm as a hard dependency to bundle.
    import importlib
    litellm = importlib.import_module("litellm")
    HAS_LITELLM = True
except ImportError:
    # litellm is optional; call_llm falls back to the urllib client.
    litellm = None
    HAS_LITELLM = False
|
||||
|
||||
def _call_with_litellm(config: LLMConfig, prompt: str) -> str:
    """Perform a synchronous completion via litellm.

    Args:
        config: Connection settings (model name, API key, base URL).
        prompt: Prompt sent as a single user message.

    Returns:
        The raw text content of the first choice.
    """
    # Lazy dynamic import: keeps litellm out of PyInstaller's bundle.
    import importlib
    client = importlib.import_module("litellm")

    resp = client.completion(
        model=config.model_name,
        messages=[{"role": "user", "content": prompt}],
        api_key=config.api_key,
        base_url=config.base_url,
    )
    return resp.choices[0].message.content
|
||||
# Module-level semaphore limiting concurrent LLM requests; created lazily.
_SEMAPHORE: Optional[asyncio.Semaphore] = None


def _get_semaphore() -> asyncio.Semaphore:
    """Return the shared request semaphore, creating it on first use.

    The concurrency limit comes from CONFIG.ai.max_concurrent_requests
    (default 10 when unset).

    NOTE(review): the semaphore is cached for the process lifetime;
    assumes a single event loop is used — confirm if multiple loops run.
    """
    global _SEMAPHORE
    if _SEMAPHORE is None:
        _SEMAPHORE = asyncio.Semaphore(
            getattr(CONFIG.ai, "max_concurrent_requests", 10)
        )
    return _SEMAPHORE
|
||||
|
||||
|
||||
def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
@@ -39,17 +39,14 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
"Authorization": f"Bearer {config.api_key}"
|
||||
}
|
||||
|
||||
# 处理模型名称:去除 'openai/' 前缀(针对 litellm 的兼容性配置)
|
||||
model_name = config.model_name
|
||||
if model_name.startswith("openai/"):
|
||||
model_name = model_name.replace("openai/", "")
|
||||
# 兼容 litellm 的 openai/ 前缀处理
|
||||
model_name = config.model_name.replace("openai/", "")
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}
|
||||
|
||||
# 处理 URL
|
||||
url = config.base_url
|
||||
if not url:
|
||||
raise ValueError("Base URL is required for requests mode")
|
||||
@@ -57,9 +54,7 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
if "chat/completions" not in url:
|
||||
url = url.rstrip("/")
|
||||
if not url.endswith("/v1"):
|
||||
# 尝试智能追加 v1,如果用户没写
|
||||
# 但有些服务可能不需要 v1,这里保守起见,如果没 v1 且没 chat/completions,直接加 /chat/completions
|
||||
# 假设用户配置的是类似 https://api.openai.com/v1
|
||||
# 简单启发式:如果不是显式 v1 结尾,也加上
|
||||
pass
|
||||
url = f"{url}/chat/completions"
|
||||
|
||||
@@ -75,53 +70,37 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
result = json.loads(response.read().decode('utf-8'))
|
||||
return result['choices'][0]['message']['content']
|
||||
except urllib.error.HTTPError as e:
|
||||
error_content = e.read().decode('utf-8')
|
||||
raise Exception(f"LLM Request failed {e.code}: {error_content}")
|
||||
raise Exception(f"LLM Request failed {e.code}: {e.read().decode('utf-8')}")
|
||||
except Exception as e:
|
||||
raise Exception(f"LLM Request failed: {str(e)}")
|
||||
|
||||
|
||||
async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str:
    """Basic LLM call with automatic concurrency control.

    Args:
        prompt: Prompt text.
        mode: Call mode selecting the model configuration.

    Returns:
        The raw text returned by the model.

    Raises:
        Exception: If the litellm call fails (the original error is
            attached as the cause).
    """
    config = LLMConfig.from_mode(mode)
    semaphore = _get_semaphore()

    async with semaphore:
        if HAS_LITELLM:
            try:
                # Use litellm's native async interface so the event loop
                # is not blocked by the request.
                response = await litellm.acompletion(
                    model=config.model_name,
                    messages=[{"role": "user", "content": prompt}],
                    api_key=config.api_key,
                    base_url=config.base_url,
                )
                result = response.choices[0].message.content
            except Exception as e:
                raise Exception(f"LiteLLM call failed: {str(e)}") from e
        else:
            # Fallback: run the blocking urllib client in a worker thread.
            result = await asyncio.to_thread(_call_with_requests, config, prompt)

    log_llm_call(config.model_name, prompt, result)

    return result
|
||||
|
||||
|
||||
@@ -130,22 +109,8 @@ async def call_llm_json(
|
||||
mode: LLMMode = LLMMode.NORMAL,
|
||||
max_retries: int | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
调用 LLM 并解析为 JSON,内置重试机制
|
||||
|
||||
Args:
|
||||
prompt: 输入提示词
|
||||
mode: 调用模式
|
||||
max_retries: 最大重试次数,None 则从配置读取
|
||||
|
||||
Returns:
|
||||
dict: 解析后的 JSON 对象
|
||||
|
||||
Raises:
|
||||
LLMError: 解析失败且重试次数用尽时抛出
|
||||
"""
|
||||
"""调用 LLM 并解析为 JSON,带重试"""
|
||||
if max_retries is None:
|
||||
from src.utils.config import CONFIG
|
||||
max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0))
|
||||
|
||||
last_error = None
|
||||
@@ -156,14 +121,9 @@ async def call_llm_json(
|
||||
except ParseError as e:
|
||||
last_error = e
|
||||
if attempt < max_retries:
|
||||
continue # 继续重试
|
||||
# 最后一次失败,抛出详细错误
|
||||
raise LLMError(
|
||||
f"解析失败(重试 {max_retries} 次后)",
|
||||
cause=last_error
|
||||
) from last_error
|
||||
|
||||
# 不应该到这里,但为了类型检查
|
||||
continue
|
||||
raise LLMError(f"解析失败(重试 {max_retries} 次后)", cause=last_error) from last_error
|
||||
|
||||
raise LLMError("未知错误")
|
||||
|
||||
|
||||
async def call_llm_with_template(
    template_path: Path,
    infos: dict,
    mode: LLMMode = LLMMode.NORMAL,
    max_retries: int | None = None
) -> dict:
    """Call the LLM through a prompt template (the common high-level entry).

    Args:
        template_path: Path to the template file.
        infos: Values to interpolate into the template.
        mode: Call mode.
        max_retries: Maximum parse retries; None reads the configured default.

    Returns:
        The parsed JSON object.
    """
    template = load_template(template_path)
    prompt = build_prompt(template, infos)
    return await call_llm_json(prompt, mode, max_retries)
|
||||
|
||||
|
||||
async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict:
    """AI action-decision entry point.

    Args:
        infos: Information required for the action decision.
        mode: Call mode.

    Returns:
        The parsed decision result.
    """
    # CONFIG is already imported at module level; no local import needed.
    template_path = CONFIG.paths.templates / "ai.txt"
    return await call_llm_with_template(template_path, infos, mode)
|
||||
|
||||
@@ -7,107 +7,26 @@ from .exceptions import ParseError
|
||||
|
||||
def parse_json(text: str) -> dict:
    """Parse JSON from LLM output, trying several strategies in order.

    Strategies: markdown code blocks, first balanced {...} span, then
    the whole text.

    Args:
        text: Text to parse.

    Returns:
        The parsed JSON object; {} for empty input.

    Raises:
        ParseError: If every strategy fails.
    """
    text = (text or '').strip()
    if not text:
        return {}

    strategies = [
        try_parse_code_blocks,
        try_parse_balanced_json,
        try_parse_whole_text,
    ]

    failed = []
    for strategy in strategies:
        result = strategy(text)
        if result is not None:
            return result
        failed.append(strategy.__name__)

    # Every strategy failed: raise with the list of what was tried.
    raise ParseError(
        f"所有解析策略均失败: {', '.join(failed)}",
        raw_text=text[:500]  # cap the echoed text to keep logs small
    )
|
||||
|
||||
|
||||
def try_parse_code_blocks(text: str) -> dict | None:
    """Try to parse JSON from markdown code blocks.

    Only blocks tagged json/json5 (or untagged) are considered. Within
    each block, the first balanced {...} span is preferred over the raw
    block content.

    Args:
        text: Text to parse.

    Returns:
        The parsed dict on success, None otherwise.
    """
    for lang, block in _extract_code_blocks(text):
        # Skip blocks explicitly tagged with a non-JSON language.
        if lang and lang not in ("json", "json5"):
            continue

        span = _find_balanced_json_object(block)
        for candidate in ([span] if span else [block]):
            if not candidate:
                continue
            try:
                parsed = json5.loads(candidate)
            except Exception:
                continue
            if isinstance(parsed, dict):
                return parsed

    return None
|
||||
|
||||
|
||||
def try_parse_balanced_json(text: str) -> dict | None:
    """Try to parse the first balanced {...} span found in the text.

    Args:
        text: Text to parse.

    Returns:
        The parsed dict on success, None otherwise.
    """
    span = _find_balanced_json_object(text)
    if not span:
        return None
    try:
        parsed = json5.loads(span)
    except Exception:
        return None
    return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def try_parse_whole_text(text: str) -> dict | None:
    """Try to parse the entire text as a single JSON document.

    Some LLMs emit bare JSON without any markdown fences.

    Args:
        text: Text to parse.

    Returns:
        The parsed dict on success, None otherwise.
    """
    try:
        parsed = json5.loads(text)
        if isinstance(parsed, dict):
            return parsed
    except Exception:
        pass
    return None
|
||||
# 失败
|
||||
raise ParseError(
|
||||
"无法解析 JSON: 未找到有效的 JSON 对象或代码块",
|
||||
raw_text=text[:500]
|
||||
)
|
||||
|
||||
|
||||
def _extract_code_blocks(text: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
提取所有 markdown 代码块
|
||||
|
||||
Args:
|
||||
text: 待提取的文本
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str]]: (语言, 内容) 元组列表
|
||||
"""
|
||||
"""提取 markdown 代码块"""
|
||||
pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL)
|
||||
blocks = []
|
||||
for lang, content in pattern.findall(text):
|
||||
blocks.append((lang.strip().lower(), content.strip()))
|
||||
return blocks
|
||||
|
||||
|
||||
def _find_balanced_json_object(text: str) -> str | None:
|
||||
"""
|
||||
在文本中扫描并返回首个平衡的花括号 {...} 片段
|
||||
忽略字符串中的括号
|
||||
|
||||
Args:
|
||||
text: 待扫描的文本
|
||||
|
||||
Returns:
|
||||
str | None: 找到则返回子串,否则返回 None
|
||||
"""
|
||||
depth = 0
|
||||
start_index = None
|
||||
in_string = False
|
||||
string_char = ''
|
||||
escape = False
|
||||
|
||||
for idx, ch in enumerate(text):
|
||||
if in_string:
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == '\\':
|
||||
escape = True
|
||||
continue
|
||||
if ch == string_char:
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if ch in ('"', "'"):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
continue
|
||||
|
||||
if ch == '{':
|
||||
if depth == 0:
|
||||
start_index = idx
|
||||
depth += 1
|
||||
continue
|
||||
|
||||
if ch == '}':
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
if depth == 0 and start_index is not None:
|
||||
return text[start_index:idx + 1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user