refactor llm
This commit is contained in:
226
src/utils/llm.py
226
src/utils/llm.py
@@ -1,226 +0,0 @@
|
||||
from litellm import completion
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import re
|
||||
import json5
|
||||
import os
|
||||
|
||||
from src.utils.config import CONFIG
|
||||
from src.utils.io import read_txt
|
||||
from src.run.log import log_llm_call
|
||||
from src.utils.strings import intentify_prompt_infos
|
||||
|
||||
def get_prompt(template: str, infos: dict) -> str:
    """Render a prompt template with the given info dict.

    Structured values (dicts/lists) are normalized by
    intentify_prompt_infos before formatting; see that helper for the
    per-key wrapping policy (e.g. avatar_infos stays unwrapped while
    general_action_infos is intent-wrapped).

    Args:
        template: a str.format-style template.
        infos: values to substitute into the template.

    Returns:
        str: the rendered prompt.
    """
    return template.format(**intentify_prompt_infos(infos))
|
||||
|
||||
|
||||
def call_llm(prompt: str, mode="normal") -> str:
    """Call the LLM synchronously and return its raw text reply.

    Args:
        prompt: prompt text, sent as a single user message.
        mode: "normal" or "fast"; selects the model from CONFIG.llm.

    Returns:
        str: the completion text produced by the model.

    Raises:
        ValueError: if mode is neither "normal" nor "fast".
    """
    # Resolve the model for the requested mode.
    if mode == "normal":
        model_name = CONFIG.llm.model_name
    elif mode == "fast":
        model_name = CONFIG.llm.fast_model_name
    else:
        raise ValueError(f"Invalid mode: {mode}")

    # The environment variable wins; the config file is the fallback.
    api_key = os.getenv("QWEN_API_KEY") or CONFIG.llm.key

    response = completion(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        api_key=api_key,
        base_url=CONFIG.llm.base_url,
    )
    result = response.choices[0].message.content

    log_llm_call(model_name, prompt, result)  # audit log of every call
    return result
|
||||
|
||||
async def call_llm_async(prompt: str, mode="normal") -> str:
    """Async wrapper around call_llm.

    Args:
        prompt: prompt text.
        mode: "normal" or "fast".

    Returns:
        str: the LLM's reply.
    """
    # call_llm blocks on network I/O, so run it in a worker thread.
    return await asyncio.to_thread(call_llm, prompt, mode)
|
||||
|
||||
def _extract_code_blocks(text: str):
|
||||
"""
|
||||
提取所有markdown代码块,返回 (lang, content) 列表。
|
||||
"""
|
||||
pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL)
|
||||
blocks = []
|
||||
for lang, content in pattern.findall(text):
|
||||
blocks.append((lang.strip().lower(), content.strip()))
|
||||
return blocks
|
||||
|
||||
|
||||
def _find_first_balanced_json_object(text: str):
|
||||
"""
|
||||
在整段文本中扫描并返回首个平衡的花括号 {...} 片段(忽略字符串中的括号)。
|
||||
找到则返回子串,否则返回None。
|
||||
"""
|
||||
depth = 0
|
||||
start_index = None
|
||||
in_string = False
|
||||
string_char = ''
|
||||
escape = False
|
||||
for idx, ch in enumerate(text):
|
||||
if in_string:
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == '\\':
|
||||
escape = True
|
||||
continue
|
||||
if ch == string_char:
|
||||
in_string = False
|
||||
continue
|
||||
if ch in ('"', "'"):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
continue
|
||||
if ch == '{':
|
||||
if depth == 0:
|
||||
start_index = idx
|
||||
depth += 1
|
||||
continue
|
||||
if ch == '}':
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
if depth == 0 and start_index is not None:
|
||||
return text[start_index:idx + 1]
|
||||
return None
|
||||
|
||||
|
||||
def parse_llm_response(res: str) -> dict:
    """Robust JSON-only parsing of an LLM response.

    Strategies, in order:
      1) fenced ```json/json5``` (or unlabeled) code blocks
      2) the first balanced {...} span in the free text
      3) json5 parse of the whole text as a last resort

    Args:
        res: raw LLM output.

    Returns:
        dict: the parsed object ({} for empty input).

    Raises:
        ValueError: if the last-resort parse yields a non-dict value.
        Exception: whatever json5 raises when nothing parses at all.
    """
    res = (res or '').strip()
    if not res:
        return {}

    # 1) Code blocks first (json/json5 or unlabeled only).
    for lang, block in _extract_code_blocks(res):
        if lang and lang not in ("json", "json5"):
            continue
        # Prefer a balanced object inside the block, else the whole block.
        span = _find_first_balanced_json_object(block)
        candidates = [span] if span else [block]
        for cand in candidates:
            if not cand:
                continue
            try:
                obj = json5.loads(cand)
            except Exception:
                continue
            if isinstance(obj, dict):
                return obj

    # 2) First balanced JSON object anywhere in the text.
    json_span = _find_first_balanced_json_object(res)
    if json_span:
        try:
            obj = json5.loads(json_span)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass

    # 3) Whole-text json5 fallback.
    obj = json5.loads(res)
    # Bug fix: the old code returned whatever json5 produced here (e.g. a
    # list), silently violating the declared dict contract.
    if not isinstance(obj, dict):
        raise ValueError(
            f"LLM response parsed to {type(obj).__name__}, expected dict"
        )
    return obj
|
||||
|
||||
|
||||
def call_and_parse_llm(prompt: str, mode: str = "normal") -> dict:
    """Call the LLM and parse the reply, retrying on parse failures.

    The retry budget comes from CONFIG.ai.max_parse_retries (default 0).
    Failures raised by the LLM call itself are NOT retried — only parse
    failures are.

    Returns:
        dict: the parsed response.

    Raises:
        ValueError: once the retry budget is exhausted.
    """
    max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0))
    last_err: Exception | None = None
    attempts = max_retries + 1
    while attempts > 0:
        attempts -= 1
        res = call_llm(prompt, mode)  # deliberately outside the try: call errors propagate
        try:
            return parse_llm_response(res)
        except Exception as exc:
            last_err = exc
    raise ValueError(f"LLM响应解析失败,已重试 {max_retries} 次") from last_err
|
||||
|
||||
|
||||
async def call_and_parse_llm_async(prompt: str, mode: str = "normal") -> dict:
    """Async variant: call the LLM and parse the reply, retrying on parse failures.

    The retry budget comes from CONFIG.ai.max_parse_retries (default 0).
    Only parse failures are retried; call errors propagate immediately.

    Returns:
        dict: the parsed response.

    Raises:
        ValueError: once the retry budget is exhausted.
    """
    max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0))
    last_err: Exception | None = None
    attempts = max_retries + 1
    while attempts > 0:
        attempts -= 1
        res = await call_llm_async(prompt, mode)  # outside the try: call errors propagate
        try:
            return parse_llm_response(res)
        except Exception as exc:
            last_err = exc
    raise ValueError(f"LLM响应解析失败,已重试 {max_retries} 次") from last_err
|
||||
|
||||
|
||||
def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> dict:
    """Load a template, render the prompt, call the LLM, and parse the reply."""
    prompt = get_prompt(read_txt(template_path), infos)
    return call_and_parse_llm(prompt, mode)
|
||||
|
||||
async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> dict:
    """Async variant: load a template, render the prompt, call the LLM, parse the reply."""
    prompt = get_prompt(read_txt(template_path), infos)
    return await call_and_parse_llm_async(prompt, mode)
|
||||
|
||||
def get_ai_prompt_and_call_llm(infos: dict, mode="normal") -> dict:
    """Call the LLM with the AI-action template (templates/ai.txt) and parse the reply."""
    return get_prompt_and_call_llm(CONFIG.paths.templates / "ai.txt", infos, mode)
|
||||
|
||||
async def get_ai_prompt_and_call_llm_async(infos: dict, mode="normal") -> dict:
    """Async variant: call the LLM with the AI-action template (templates/ai.txt)."""
    return await get_prompt_and_call_llm_async(CONFIG.paths.templates / "ai.txt", infos, mode)
|
||||
24
src/utils/llm/__init__.py
Normal file
24
src/utils/llm/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
LLM 调用模块
|
||||
|
||||
提供三个核心 API:
|
||||
- call_llm: 基础调用,返回原始文本
|
||||
- call_llm_json: 调用并解析为 JSON
|
||||
- call_llm_with_template: 使用模板调用(最常用)
|
||||
"""
|
||||
|
||||
from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action
|
||||
from .config import LLMMode
|
||||
from .exceptions import LLMError, ParseError, ConfigError
|
||||
|
||||
__all__ = [
|
||||
"call_llm",
|
||||
"call_llm_json",
|
||||
"call_llm_with_template",
|
||||
"call_ai_action",
|
||||
"LLMMode",
|
||||
"LLMError",
|
||||
"ParseError",
|
||||
"ConfigError",
|
||||
]
|
||||
|
||||
129
src/utils/llm/client.py
Normal file
129
src/utils/llm/client.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""LLM 客户端核心调用逻辑"""
|
||||
|
||||
from pathlib import Path
|
||||
from litellm import completion
|
||||
|
||||
from .config import LLMMode, LLMConfig
|
||||
from .parser import parse_json
|
||||
from .prompt import build_prompt, load_template
|
||||
from .exceptions import LLMError, ParseError
|
||||
from src.run.log import log_llm_call
|
||||
|
||||
|
||||
async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str:
    """Lowest-level LLM call; returns the raw completion text.

    Args:
        prompt: prompt text, sent as a single user message.
        mode: which configured model to use.

    Returns:
        str: raw text produced by the model.
    """
    import asyncio

    cfg = LLMConfig.from_mode(mode)

    def _sync_call() -> str:
        # litellm's completion is blocking network I/O.
        reply = completion(
            model=cfg.model_name,
            messages=[{"role": "user", "content": prompt}],
            api_key=cfg.api_key,
            base_url=cfg.base_url,
        )
        return reply.choices[0].message.content

    # Run the blocking call off the event loop.
    text = await asyncio.to_thread(_sync_call)

    log_llm_call(cfg.model_name, prompt, text)  # audit log of every call
    return text
|
||||
|
||||
|
||||
async def call_llm_json(
    prompt: str,
    mode: LLMMode = LLMMode.NORMAL,
    max_retries: int | None = None
) -> dict:
    """Call the LLM and parse its reply as JSON, with built-in retries.

    Args:
        prompt: prompt text.
        mode: which configured model to use.
        max_retries: extra attempts after the first; None reads
            CONFIG.ai.max_parse_retries (default 0).

    Returns:
        dict: the parsed JSON object.

    Raises:
        LLMError: when every attempt fails to parse.
    """
    if max_retries is None:
        from src.utils.config import CONFIG
        max_retries = int(getattr(CONFIG.ai, "max_parse_retries", 0))

    last_error: ParseError | None = None
    for _ in range(max_retries + 1):
        response = await call_llm(prompt, mode)
        try:
            return parse_json(response)
        except ParseError as e:
            last_error = e  # remember and retry

    # Fix: raise once after the loop. The old version raised inside the
    # except branch on the last attempt, leaving a dead
    # `raise LLMError("未知错误")` fallthrough at the end of the function.
    raise LLMError(
        f"解析失败(重试 {max_retries} 次后)",
        cause=last_error
    ) from last_error
|
||||
|
||||
|
||||
async def call_llm_with_template(
    template_path: Path | str,
    infos: dict,
    mode: LLMMode = LLMMode.NORMAL,
    max_retries: int | None = None
) -> dict:
    """Template-driven LLM call (the most common high-level entry point).

    Args:
        template_path: path to the template file.
        infos: values to fill into the template.
        mode: which configured model to use.
        max_retries: extra parse attempts; None reads the config default.

    Returns:
        dict: the parsed JSON object.
    """
    prompt = build_prompt(load_template(template_path), infos)
    return await call_llm_json(prompt, mode, max_retries)
|
||||
|
||||
|
||||
async def call_ai_action(
    infos: dict,
    mode: LLMMode = LLMMode.NORMAL
) -> dict:
    """AI action-decision convenience wrapper.

    Uses the fixed template at CONFIG.paths.templates / "ai.txt".

    Args:
        infos: information needed for the action decision.
        mode: which configured model to use.

    Returns:
        dict: the decision produced by the model.
    """
    from src.utils.config import CONFIG

    return await call_llm_with_template(
        CONFIG.paths.templates / "ai.txt", infos, mode
    )
|
||||
|
||||
48
src/utils/llm/config.py
Normal file
48
src/utils/llm/config.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""LLM 配置管理"""
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
|
||||
class LLMMode(str, Enum):
    """LLM invocation mode (str-valued so it compares equal to plain strings)."""
    # Default model (CONFIG.llm.model_name — see LLMConfig.from_mode)
    NORMAL = "normal"
    # Faster model (CONFIG.llm.fast_model_name)
    FAST = "fast"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMConfig:
    """Immutable bundle of the settings needed for one LLM call."""
    model_name: str  # model identifier passed to litellm
    api_key: str     # resolved API key
    base_url: str    # API endpoint base URL

    @classmethod
    def from_mode(cls, mode: LLMMode) -> 'LLMConfig':
        """Build the configuration for a given mode from the global CONFIG.

        Args:
            mode: LLM invocation mode.

        Returns:
            LLMConfig: the resolved configuration.
        """
        from src.utils.config import CONFIG

        if mode == LLMMode.NORMAL:
            model_name = CONFIG.llm.model_name
        else:
            # Any non-NORMAL mode (currently only FAST) uses the fast model.
            model_name = CONFIG.llm.fast_model_name

        # The environment variable takes precedence over the config file.
        return cls(
            model_name=model_name,
            api_key=os.getenv("QWEN_API_KEY") or CONFIG.llm.key,
            base_url=CONFIG.llm.base_url,
        )
|
||||
|
||||
24
src/utils/llm/exceptions.py
Normal file
24
src/utils/llm/exceptions.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""LLM 相关异常定义"""
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""LLM 相关错误的基类"""
|
||||
|
||||
def __init__(self, message: str, *, cause: Exception | None = None, **context):
|
||||
super().__init__(message)
|
||||
self.cause = cause
|
||||
self.context = context
|
||||
|
||||
|
||||
class ParseError(LLMError):
    """Raised when an LLM reply cannot be parsed as JSON.

    Attributes:
        raw_text: the (possibly truncated) text that failed to parse;
            also mirrored into the base-class `context` dict.
    """

    def __init__(self, message: str, *, raw_text: str = ""):
        super().__init__(message, raw_text=raw_text)
        self.raw_text = raw_text
|
||||
|
||||
|
||||
class ConfigError(LLMError):
    """Raised for invalid or missing LLM configuration."""
|
||||
|
||||
185
src/utils/llm/parser.py
Normal file
185
src/utils/llm/parser.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""JSON 解析逻辑"""
|
||||
|
||||
import re
|
||||
import json5
|
||||
from .exceptions import ParseError
|
||||
|
||||
|
||||
def parse_json(text: str) -> dict:
    """Parse LLM output into a dict, trying several strategies in order.

    Order: fenced code blocks, first balanced {...} span, whole text.

    Args:
        text: raw LLM output.

    Returns:
        dict: the parsed object ({} for empty input).

    Raises:
        ParseError: when no strategy yields a dict.
    """
    text = (text or '').strip()
    if not text:
        return {}

    failed = []
    for attempt in (try_parse_code_blocks, try_parse_balanced_json, try_parse_whole_text):
        parsed = attempt(text)
        if parsed is not None:
            return parsed
        failed.append(attempt.__name__)

    raise ParseError(
        f"所有解析策略均失败: {', '.join(failed)}",
        raw_text=text[:500]  # cap the payload so logs stay readable
    )
|
||||
|
||||
|
||||
def try_parse_code_blocks(text: str) -> dict | None:
    """Parse JSON out of fenced markdown code blocks.

    Only blocks tagged json/json5 (or left untagged) are considered.

    Args:
        text: text to parse.

    Returns:
        dict | None: the first dict found, or None.
    """
    for lang, body in _extract_code_blocks(text):
        if lang not in ("", "json", "json5"):
            continue

        # Prefer a balanced {...} span inside the block, else the whole block.
        span = _find_balanced_json_object(body)
        candidates = [span] if span else [body]

        for candidate in candidates:
            if not candidate:
                continue
            try:
                parsed = json5.loads(candidate)
            except Exception:
                continue
            if isinstance(parsed, dict):
                return parsed

    return None
|
||||
|
||||
|
||||
def try_parse_balanced_json(text: str) -> dict | None:
    """Parse the first balanced {...} span found in the text.

    Args:
        text: text to parse.

    Returns:
        dict | None: the parsed dict, or None when nothing usable is found.
    """
    span = _find_balanced_json_object(text)
    if not span:
        return None
    try:
        parsed = json5.loads(span)
    except Exception:
        return None
    return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def try_parse_whole_text(text: str) -> dict | None:
    """Parse the entire text as JSON5.

    Args:
        text: text to parse.

    Returns:
        dict | None: the parsed dict, or None on failure / non-dict result.
    """
    try:
        parsed = json5.loads(text)
    except Exception:
        return None
    return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _extract_code_blocks(text: str) -> list[tuple[str, str]]:
|
||||
"""
|
||||
提取所有 markdown 代码块
|
||||
|
||||
Args:
|
||||
text: 待提取的文本
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str]]: (语言, 内容) 元组列表
|
||||
"""
|
||||
pattern = re.compile(r"```([^\n`]*)\n([\s\S]*?)```", re.DOTALL)
|
||||
blocks = []
|
||||
for lang, content in pattern.findall(text):
|
||||
blocks.append((lang.strip().lower(), content.strip()))
|
||||
return blocks
|
||||
|
||||
|
||||
def _find_balanced_json_object(text: str) -> str | None:
|
||||
"""
|
||||
在文本中扫描并返回首个平衡的花括号 {...} 片段
|
||||
忽略字符串中的括号
|
||||
|
||||
Args:
|
||||
text: 待扫描的文本
|
||||
|
||||
Returns:
|
||||
str | None: 找到则返回子串,否则返回 None
|
||||
"""
|
||||
depth = 0
|
||||
start_index = None
|
||||
in_string = False
|
||||
string_char = ''
|
||||
escape = False
|
||||
|
||||
for idx, ch in enumerate(text):
|
||||
if in_string:
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == '\\':
|
||||
escape = True
|
||||
continue
|
||||
if ch == string_char:
|
||||
in_string = False
|
||||
continue
|
||||
|
||||
if ch in ('"', "'"):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
continue
|
||||
|
||||
if ch == '{':
|
||||
if depth == 0:
|
||||
start_index = idx
|
||||
depth += 1
|
||||
continue
|
||||
|
||||
if ch == '}':
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
if depth == 0 and start_index is not None:
|
||||
return text[start_index:idx + 1]
|
||||
|
||||
return None
|
||||
|
||||
34
src/utils/llm/prompt.py
Normal file
34
src/utils/llm/prompt.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""提示词处理"""
|
||||
|
||||
from pathlib import Path
|
||||
from src.utils.strings import intentify_prompt_infos
|
||||
|
||||
|
||||
def build_prompt(template: str, infos: dict) -> str:
    """Build a prompt by filling a template with the given info dict.

    Structured values are normalized by intentify_prompt_infos before
    formatting.

    Args:
        template: a str.format-style template.
        infos: values to substitute.

    Returns:
        str: the finished prompt.
    """
    return template.format(**intentify_prompt_infos(infos))
|
||||
|
||||
|
||||
def load_template(path: Path | str) -> str:
|
||||
"""
|
||||
加载模板文件
|
||||
|
||||
Args:
|
||||
path: 模板文件路径
|
||||
|
||||
Returns:
|
||||
str: 模板内容
|
||||
"""
|
||||
path = Path(path)
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
Reference in New Issue
Block a user