update llm
This commit is contained in:
@@ -9,9 +9,10 @@ import asyncio
|
||||
|
||||
from src.classes.world import World
|
||||
from src.classes.event import Event, NULL_EVENT
|
||||
from src.utils.llm import call_ai_action
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.classes.typings import ACTION_NAME_PARAMS_PAIRS
|
||||
from src.classes.actions import ACTION_INFOS_STR
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.classes.avatar import Avatar
|
||||
@@ -66,7 +67,8 @@ class LLMAI(AI):
|
||||
"world_info": world_info,
|
||||
"general_action_infos": general_action_infos,
|
||||
}
|
||||
res = await call_ai_action(info)
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
res = await call_llm_with_task_name("action_decision", template_path, info)
|
||||
return avatar, res
|
||||
|
||||
# 直接并发所有任务
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from src.classes.event import Event
|
||||
from src.utils.config import CONFIG
|
||||
from src.utils.llm import call_llm_with_template, LLMMode
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.run.log import get_logger
|
||||
from src.classes.actions import ACTION_INFOS_STR
|
||||
|
||||
@@ -92,8 +92,8 @@ async def generate_long_term_objective(avatar: "Avatar") -> Optional[LongTermObj
|
||||
"general_action_infos": ACTION_INFOS_STR,
|
||||
}
|
||||
|
||||
# 调用LLM并自动解析JSON(使用fast模型)
|
||||
response_data = await call_llm_with_template(template_path, infos, LLMMode.NORMAL)
|
||||
# 调用LLM并自动解析JSON(使用配置的模型模式)
|
||||
response_data = await call_llm_with_task_name("long_term_objective", template_path, infos)
|
||||
|
||||
content = response_data.get("long_term_objective", "").strip()
|
||||
|
||||
|
||||
@@ -5,9 +5,8 @@ from typing import TYPE_CHECKING
|
||||
import asyncio
|
||||
|
||||
from src.classes.action.action import DefineAction, ActualActionMixin, LLMAction
|
||||
from src.classes.tile import get_avatar_distance
|
||||
from src.classes.event import Event
|
||||
from src.utils.llm import call_llm_with_template, LLMMode
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.utils.config import CONFIG
|
||||
from src.classes.relation import relation_display_names, Relation
|
||||
from src.classes.relations import get_possible_new_relations
|
||||
@@ -86,7 +85,7 @@ class MutualAction(DefineAction, LLMAction, ActualActionMixin, TargetingMixin):
|
||||
async def _call_llm_feedback(self, infos: dict) -> dict:
|
||||
"""异步调用 LLM 获取反馈"""
|
||||
template_path = self._get_template_path()
|
||||
return await call_llm_with_template(template_path, infos, LLMMode.FAST)
|
||||
return await call_llm_with_task_name("interaction_feedback", template_path, infos)
|
||||
|
||||
def _set_target_immediate_action(self, target_avatar: "Avatar", action_name: str, action_params: dict) -> None:
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from src.classes.event import Event
|
||||
from src.utils.config import CONFIG
|
||||
from src.utils.llm import call_llm_with_template, LLMMode
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.run.log import get_logger
|
||||
|
||||
from src.classes.nickname_data import Nickname
|
||||
@@ -75,7 +75,7 @@ async def generate_nickname(avatar: "Avatar") -> Optional[dict]:
|
||||
}
|
||||
|
||||
# 调用LLM并自动解析JSON
|
||||
response_data = await call_llm_with_template(template_path, infos, LLMMode.NORMAL)
|
||||
response_data = await call_llm_with_task_name("nickname", template_path, infos)
|
||||
|
||||
nickname = response_data.get("nickname", "").strip()
|
||||
thinking = response_data.get("thinking", "")
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.classes.relations import (
|
||||
)
|
||||
from src.classes.calendar import get_date_str
|
||||
from src.classes.event import Event
|
||||
from src.utils.llm import call_llm_with_template, LLMMode
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -64,7 +64,7 @@ class RelationResolver:
|
||||
"""
|
||||
infos = RelationResolver._build_prompt_data(avatar_a, avatar_b)
|
||||
|
||||
result = await call_llm_with_template(RelationResolver.TEMPLATE_PATH, infos, mode=LLMMode.FAST)
|
||||
result = await call_llm_with_task_name("relation_resolver", RelationResolver.TEMPLATE_PATH, infos)
|
||||
|
||||
changed = result.get("changed", False)
|
||||
if not changed:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from src.utils.llm import call_llm_with_template
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.utils.config import CONFIG
|
||||
import json
|
||||
|
||||
@@ -33,7 +33,8 @@ async def make_decision(
|
||||
|
||||
# 3. 调用 AI
|
||||
template_path = CONFIG.paths.templates / "single_choice.txt"
|
||||
result = await call_llm_with_template(
|
||||
result = await call_llm_with_task_name(
|
||||
"single_choice",
|
||||
template_path,
|
||||
infos={
|
||||
"avatar_infos": avatar_infos,
|
||||
|
||||
@@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
||||
from src.classes.avatar import Avatar
|
||||
|
||||
from src.utils.config import CONFIG
|
||||
from src.utils.llm import call_llm_with_template, LLMMode
|
||||
from src.utils.llm import call_llm_with_task_name
|
||||
from src.classes.relations import (
|
||||
process_relation_changes,
|
||||
get_relation_change_context
|
||||
@@ -118,7 +118,7 @@ class StoryTeller:
|
||||
infos = StoryTeller._build_template_data(event, res, avatar_infos, prompt, *actors)
|
||||
|
||||
# 移除了 try-except 块,允许异常向上冒泡,以便 Fail Fast
|
||||
data = await call_llm_with_template(template_path, infos, LLMMode.FAST)
|
||||
data = await call_llm_with_task_name("story_teller", template_path, infos)
|
||||
story = data.get("story", "").strip()
|
||||
|
||||
if story:
|
||||
|
||||
@@ -7,16 +7,17 @@ LLM 调用模块
|
||||
- call_llm_with_template: 使用模板调用(最常用)
|
||||
"""
|
||||
|
||||
from .client import call_llm, call_llm_json, call_llm_with_template, call_ai_action
|
||||
from .config import LLMMode
|
||||
from .client import call_llm, call_llm_json, call_llm_with_template, call_llm_with_task_name
|
||||
from .config import LLMMode, get_task_mode
|
||||
from .exceptions import LLMError, ParseError, ConfigError
|
||||
|
||||
__all__ = [
|
||||
"call_llm",
|
||||
"call_llm_json",
|
||||
"call_llm_with_template",
|
||||
"call_ai_action",
|
||||
"call_llm_with_task_name",
|
||||
"LLMMode",
|
||||
"get_task_mode",
|
||||
"LLMError",
|
||||
"ParseError",
|
||||
"ConfigError",
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional
|
||||
|
||||
from src.run.log import log_llm_call
|
||||
from src.utils.config import CONFIG
|
||||
from .config import LLMMode, LLMConfig
|
||||
from .config import LLMMode, LLMConfig, get_task_mode
|
||||
from .parser import parse_json
|
||||
from .prompt import build_prompt, load_template
|
||||
from .exceptions import LLMError, ParseError
|
||||
@@ -139,7 +139,24 @@ async def call_llm_with_template(
|
||||
return await call_llm_json(prompt, mode, max_retries)
|
||||
|
||||
|
||||
async def call_ai_action(infos: dict, mode: LLMMode = LLMMode.NORMAL) -> dict:
|
||||
"""AI 行动决策专用接口"""
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
return await call_llm_with_template(template_path, infos, mode)
|
||||
async def call_llm_with_task_name(
|
||||
task_name: str,
|
||||
template_path: Path | str,
|
||||
infos: dict,
|
||||
max_retries: int | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
根据任务名称自动选择 LLM 模式并调用
|
||||
|
||||
Args:
|
||||
task_name: 任务名称,用于在 config.yml 中查找对应的模式
|
||||
template_path: 模板路径
|
||||
infos: 模板参数
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
dict: LLM 返回的 JSON 数据
|
||||
"""
|
||||
mode = get_task_mode(task_name)
|
||||
return await call_llm_with_template(template_path, infos, mode, max_retries)
|
||||
|
||||
|
||||
@@ -46,3 +46,19 @@ class LLMConfig:
|
||||
base_url=CONFIG.llm.base_url
|
||||
)
|
||||
|
||||
|
||||
def get_task_mode(task_name: str) -> LLMMode:
|
||||
"""
|
||||
获取指定任务的 LLM 调用模式
|
||||
|
||||
Args:
|
||||
task_name: 任务名称 (配置在 llm.default_modes 下的 key)
|
||||
|
||||
Returns:
|
||||
LLMMode: 对应的模式,如果未配置则默认返回 NORMAL
|
||||
"""
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
# 获取配置的模式字符串,默认 normal
|
||||
mode_str = getattr(CONFIG.llm.default_modes, task_name, "normal")
|
||||
return LLMMode(mode_str)
|
||||
|
||||
@@ -7,6 +7,14 @@ llm:
|
||||
model_name: "openai/qwen-plus"
|
||||
fast_model_name: "openai/qwen-flash"
|
||||
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
default_modes:
|
||||
action_decision: "normal"
|
||||
long_term_objective: "normal"
|
||||
nickname: "normal"
|
||||
single_choice: "normal"
|
||||
relation_resolver: "fast"
|
||||
story_teller: "fast"
|
||||
interaction_feedback: "fast"
|
||||
|
||||
paths:
|
||||
templates: static/templates/
|
||||
|
||||
Reference in New Issue
Block a user