fix action bugs
This commit is contained in:
@@ -17,7 +17,7 @@ def get_prompt(template: str, infos: dict) -> str:
|
||||
return prompt_template.format(**infos)
|
||||
|
||||
|
||||
def call_llm(prompt: str) -> str:
|
||||
def call_llm(prompt: str, mode="normal") -> str:
|
||||
"""
|
||||
调用LLM
|
||||
|
||||
@@ -27,7 +27,12 @@ def call_llm(prompt: str) -> str:
|
||||
str: LLM返回的结果
|
||||
"""
|
||||
# 从配置中获取模型信息
|
||||
model_name = CONFIG.llm.model_name
|
||||
if mode == "normal":
|
||||
model_name = CONFIG.llm.model_name
|
||||
elif mode == "fast":
|
||||
model_name = CONFIG.llm.fast_model_name
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
api_key = CONFIG.llm.key
|
||||
base_url = CONFIG.llm.base_url
|
||||
# 调用litellm的completion函数
|
||||
@@ -43,7 +48,7 @@ def call_llm(prompt: str) -> str:
|
||||
log_llm_call(model_name, prompt, result) # 记录日志
|
||||
return result
|
||||
|
||||
async def call_llm_async(prompt: str) -> str:
|
||||
async def call_llm_async(prompt: str, mode="normal") -> str:
|
||||
"""
|
||||
异步调用LLM
|
||||
|
||||
@@ -53,7 +58,7 @@ async def call_llm_async(prompt: str) -> str:
|
||||
str: LLM返回的结果
|
||||
"""
|
||||
# 使用asyncio.to_thread包装同步调用
|
||||
result = await asyncio.to_thread(call_llm, prompt)
|
||||
result = await asyncio.to_thread(call_llm, prompt, mode)
|
||||
return result
|
||||
|
||||
def parse_llm_response(res: str) -> dict:
|
||||
@@ -69,38 +74,36 @@ def parse_llm_response(res: str) -> dict:
|
||||
|
||||
return json5.loads(res)
|
||||
|
||||
def get_prompt_and_call_llm(template_path: Path, infos: dict) -> str:
|
||||
def get_prompt_and_call_llm(template_path: Path, infos: dict, mode="normal") -> str:
|
||||
"""
|
||||
根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template = read_txt(template_path)
|
||||
prompt = get_prompt(template, infos)
|
||||
res = call_llm(prompt)
|
||||
res = call_llm(prompt, mode)
|
||||
json_res = parse_llm_response(res)
|
||||
return json_res
|
||||
|
||||
async def get_prompt_and_call_llm_async(template_path: Path, infos: dict) -> str:
|
||||
async def get_prompt_and_call_llm_async(template_path: Path, infos: dict, mode="normal") -> str:
|
||||
"""
|
||||
异步版本:根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template = read_txt(template_path)
|
||||
prompt = get_prompt(template, infos)
|
||||
res = await call_llm_async(prompt)
|
||||
res = await call_llm_async(prompt, mode)
|
||||
json_res = parse_llm_response(res)
|
||||
# print(f"prompt = {prompt}")
|
||||
# print(f"json_res = {json_res}")
|
||||
return json_res
|
||||
|
||||
def get_ai_prompt_and_call_llm(infos: dict) -> dict:
|
||||
def get_ai_prompt_and_call_llm(infos: dict, mode="normal") -> dict:
|
||||
"""
|
||||
根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
return get_prompt_and_call_llm(template_path, infos)
|
||||
return get_prompt_and_call_llm(template_path, infos, mode)
|
||||
|
||||
async def get_ai_prompt_and_call_llm_async(infos: dict) -> dict:
|
||||
async def get_ai_prompt_and_call_llm_async(infos: dict, mode="normal") -> dict:
|
||||
"""
|
||||
异步版本:根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
return await get_prompt_and_call_llm_async(template_path, infos)
|
||||
return await get_prompt_and_call_llm_async(template_path, infos, mode)
|
||||
Reference in New Issue
Block a user