update llm: drop langchain, add retrying call-and-parse helpers
@@ -6,7 +6,6 @@ PyYAML>=6.0
 # LLM integration
 litellm>=1.0.0
 omegaconf>=2.3.0
-langchain>=0.1.0
 json5>=0.9.0

 # Development and testing (optional)
@@ -1,5 +1,4 @@
 from litellm import completion
-from langchain.prompts import PromptTemplate
 from pathlib import Path
 import asyncio
 import re
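
The body of call_llm is unchanged and outside this diff. For context, a hypothetical sketch of a litellm-backed implementation consistent with the `from litellm import completion` import and the placeholder values in the config file changed below; the function name, defaults, and the mode-to-model mapping are assumptions, not the repo's actual code:

    from litellm import completion

    # Hypothetical sketch; the real call_llm body is not part of this diff.
    # Defaults mirror the placeholder values in the config file changed below;
    # a "fast" mode would presumably swap in fast_model_name instead.
    def call_llm_sketch(prompt: str,
                        model: str = "your-model-name",
                        api_key: str = "your-api-key",
                        api_base: str = "your-base-url-of-llm") -> str:
        response = completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            api_key=api_key,
            api_base=api_base,
        )
        # litellm mirrors the OpenAI response shape.
        return response.choices[0].message.content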
@@ -14,14 +13,13 @@ def get_prompt(template: str, infos: dict) -> str:
     """
     Build the prompt from the template.
     """
-    prompt_template = PromptTemplate(template=template)
     # Convert structured values (dict/list) to JSON strings.
     # Strategy:
     # - avatar_infos: no intent wrapper (the template already documents it as dict[Name, info])
     # - general_action_infos: always wrap with intent to make the semantics explicit
     # - other container types: wrap with intent by default
     processed_infos = intentify_prompt_infos(infos)
-    return prompt_template.format(**processed_infos)
+    return template.format(**processed_infos)


 def call_llm(prompt: str, mode="normal") -> str:
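
With langchain gone, templating is plain str.format over {placeholder} fields. A minimal sketch of the substitution, using a hypothetical template and a stand-in for intentify_prompt_infos (whose real implementation is outside this diff):

    import json

    # Hypothetical template using str.format placeholders.
    template = "Known avatars:\n{avatar_infos}\nDecide the next action."

    # Stand-in for intentify_prompt_infos: structured values are serialized
    # to JSON strings before substitution (per the strategy comments above).
    processed_infos = {"avatar_infos": json.dumps({"Alice": {"mood": "calm"}})}

    print(template.format(**processed_infos))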
@@ -161,25 +159,55 @@ def parse_llm_response(res: str) -> dict:
     return obj


-def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> str:
+def call_and_parse_llm(prompt: str, mode: str = "normal") -> dict:
+    """
+    Combine the LLM call with response parsing, retrying on parse failure as configured.
+    Returns a dict on success; raises once all retries are exhausted.
+    """
+    max_retries = int(getattr(CONFIG.llm, "max_parse_retries", 0))
+    last_err: Exception | None = None
+    for _ in range(1 + max_retries):
+        res = call_llm(prompt, mode)
+        try:
+            return parse_llm_response(res)
+        except Exception as e:
+            last_err = e
+            continue
+    raise ValueError(f"Failed to parse the LLM response after {max_retries} retries") from last_err
+
+
+async def call_and_parse_llm_async(prompt: str, mode: str = "normal") -> dict:
+    """
+    Async version: combine the LLM call with response parsing, retrying on parse failure as configured.
+    Returns a dict on success; raises once all retries are exhausted.
+    """
+    max_retries = int(getattr(CONFIG.llm, "max_parse_retries", 0))
+    last_err: Exception | None = None
+    for _ in range(1 + max_retries):
+        res = await call_llm_async(prompt, mode)
+        try:
+            return parse_llm_response(res)
+        except Exception as e:
+            last_err = e
+            continue
+    raise ValueError(f"Failed to parse the LLM response after {max_retries} retries") from last_err
+
+
+def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> dict:
     """
     Build the prompt from the template and call the LLM.
     """
     template = read_txt(template_path)
     prompt = get_prompt(template, infos)
-    res = call_llm(prompt, mode)
-    json_res = parse_llm_response(res)
-    return json_res
+    return call_and_parse_llm(prompt, mode)
+

-async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> str:
+async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> dict:
     """
     Async version: build the prompt from the template and call the LLM.
     """
     template = read_txt(template_path)
     prompt = get_prompt(template, infos)
-    res = await call_llm_async(prompt, mode)
-    json_res = parse_llm_response(res)
-    return json_res
+    return await call_and_parse_llm_async(prompt, mode)

 def get_ai_prompt_and_call_llm(infos: dict, mode="normal") -> dict:
     """
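
The wrapper makes one initial attempt plus max_parse_retries retries, keeping the last exception for the final raise. A standalone sketch of the same accounting, with a plain JSON parser standing in for the LLM round-trip and canned responses for illustration:

    import json

    def retry_parse(responses: list[str], max_retries: int) -> dict:
        # Same shape as call_and_parse_llm: 1 initial attempt + max_retries retries.
        last_err: Exception | None = None
        for attempt in range(1 + max_retries):
            try:
                return json.loads(responses[attempt])
            except Exception as e:
                last_err = e
                continue
        raise ValueError(f"failed after {max_retries} retries") from last_err

    # Two malformed responses, then a valid one: succeeds on the third attempt.
    print(retry_parse(["oops", "{bad", '{"action": "move"}'], max_retries=2))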
@@ -4,6 +4,7 @@ llm:
   fast_model_name: "your-fast-model-name"
   key: "your-api-key"
   base_url: "your-base-url-of-llm"
+  max_parse_retries: 3

 paths:
   templates: static/templates/
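
call_and_parse_llm reads this key via getattr(CONFIG.llm, "max_parse_retries", 0), falling back to 0 retries when it is absent. A minimal sketch of the lookup with omegaconf; the config path is hypothetical, and the exact fallback behavior for missing keys depends on OmegaConf's struct flag:

    from omegaconf import OmegaConf

    # Hypothetical path; the repo's real config location is not shown here.
    CONFIG = OmegaConf.load("config.yaml")

    # Mirrors the lookup in call_and_parse_llm.
    max_retries = int(getattr(CONFIG.llm, "max_parse_retries", 0))
    print(max_retries)  # -> 3 with the hunk above applied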
@@ -11,7 +12,7 @@ paths:

 ai:
   mode: "llm" # "rule" or "llm"
-  max_decide_num: 3
+  max_decide_num: 4

 game:
   init_npc_num: 9