From 4f9c6d4c79d92c0d821ab75fa2bd8a9f4a9e1a60 Mon Sep 17 00:00:00 2001 From: bridge Date: Sat, 30 Aug 2025 21:58:43 +0800 Subject: [PATCH] add llm ai --- src/classes/action.py | 19 +++++++++++++++++-- src/classes/ai.py | 20 +++++++++++++------- src/classes/avatar.py | 12 ++++++++++-- src/classes/cultivation.py | 5 ++++- src/classes/tile.py | 13 ++++++++++++- src/run/run.py | 2 +- src/utils/config.py | 4 ++++ src/utils/io.py | 8 ++++++++ src/utils/llm.py | 24 +++++++++++++++++++++++- {src/static => static}/config.yml | 5 ++++- static/templates/ai.txt | 18 ++++++++++++++++++ 11 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 src/utils/io.py rename {src/static => static}/config.yml (58%) create mode 100644 static/templates/ai.txt diff --git a/src/classes/action.py b/src/classes/action.py index f3728ec..7c8035d 100644 --- a/src/classes/action.py +++ b/src/classes/action.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING import random +import json from src.classes.essence import Essence, EssenceType from src.classes.root import Root, corres_essence_type @@ -48,6 +49,7 @@ class Move(DefineAction): """ 最基础的移动动作,在tile之间进行切换。 """ + COMMENT = "移动到某个相对位置" def execute(self, delta_x: int, delta_y: int) -> Event|NullEvent: """ 移动到某个tile @@ -71,10 +73,13 @@ class MoveToRegion(DefineAction): """ 移动到某个region """ - def execute(self, region: Region) -> Event|NullEvent: + COMMENT = "移动到某个区域" + def execute(self, region: Region|str) -> Event|NullEvent: """ 移动到某个region """ + if isinstance(region, str): + region = self.world.map.region_names[region] cur_loc = (self.avatar.pos_x, self.avatar.pos_y) region_center_loc = region.center_loc delta_x = region_center_loc[0] - cur_loc[0] @@ -89,6 +94,7 @@ class Cultivate(DefineAction): """ 修炼动作,可以增加修仙进度。 """ + COMMENT = "修炼,增进修为" def execute(self) -> Event|NullEvent: """ 修炼 @@ -116,6 +122,7 @@ class Breakthrough(DefineAction): """ 突破境界 """ + COMMENT = "尝试突破境界" def calc_success_rate(self) -> float: """ 计算突破境界的成功率 @@ -137,4 +144,12 @@ class Breakthrough(DefineAction): return Event(self.world.year, self.world.month, f"{self.avatar.name} 突破境界{res}") -ALL_ACTION_CLASSES = [Move, Cultivate, Breakthrough, MoveToRegion] \ No newline at end of file +ALL_ACTION_CLASSES = [Move, Cultivate, Breakthrough, MoveToRegion] +# 不包括Move +ACTION_SPACE = [ + # {"action": "Move", "params": {"delta_x": int, "delta_y": int}, "comment": Move.COMMENT}, + {"action": "Cultivate", "params": {}, "comment": Cultivate.COMMENT}, + {"action": "Breakthrough", "params": {}, "comment": Breakthrough.COMMENT}, + {"action": "MoveToRegion", "params": {"region": "region_name"}, "comment": MoveToRegion.COMMENT}, +] +ACTION_SPACE_STR = json.dumps(ACTION_SPACE, ensure_ascii=False) \ No newline at end of file diff --git a/src/classes/ai.py b/src/classes/ai.py index 9ff5c47..540b5f3 100644 --- a/src/classes/ai.py +++ b/src/classes/ai.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod from src.classes.world import World from src.classes.tile import Region from src.classes.root import corres_essence_type +from src.classes.action import ACTION_SPACE_STR +from src.utils.llm import get_ai_prompt_and_call_llm class AI(ABC): """ @@ -23,12 +25,6 @@ class AI(ABC): """ pass - # def create_event(self, world: World, content: str) -> Event: - # """ - # 创建事件 - # """ - # return Event(world.year, world.month, content) - class RuleAI(AI): """ 规则AI @@ -69,4 +65,14 @@ class LLMAI(AI): """ 决定做什么 """ - pass \ No newline at end of file + action_space_str = ACTION_SPACE_STR + avatar_infos_str = str(self.avatar) + regions_str = "\n".join([str(region) for region in world.map.regions.values()]) + dict_info = { + "action_space": action_space_str, + "avatar_infos": avatar_infos_str, + "regions": regions_str + } + res = get_ai_prompt_and_call_llm(dict_info) + action_name, action_params = res["action_name"], res["action_params"] + return action_name, action_params \ No newline at end of file diff --git a/src/classes/avatar.py b/src/classes/avatar.py index 8f68a67..bdd7de7 100644 --- a/src/classes/avatar.py +++ b/src/classes/avatar.py @@ -12,7 +12,7 @@ from src.classes.cultivation import CultivationProgress, Realm from src.classes.root import Root from src.classes.age import Age from src.utils.strings import to_snake_case -from src.classes.ai import AI, RuleAI +from src.classes.ai import AI, RuleAI, LLMAI class Gender(Enum): MALE = "male" @@ -52,9 +52,17 @@ class Avatar: 在Avatar创建后自动绑定基础动作和AI """ self.tile = self.world.map.get_tile(self.pos_x, self.pos_y) - self.ai = RuleAI(self) + self.ai = LLMAI(self) + # self.ai = RuleAI(self) self._bind_basic_actions() + def __str__(self) -> str: + """ + 获取avatar的详细信息 + 尽量多打一些,因为会用来给LLM进行决策 + """ + return f"Avatar(id={self.id}, 性别={self.gender}, 年龄={self.age}, name={self.name}, 区域={self.tile.region.name}, 灵根={self.root.value}, 境界={self.cultivation_progress})" + def _bind_basic_actions(self): """ 绑定基础动作,如移动等 diff --git a/src/classes/cultivation.py b/src/classes/cultivation.py index c792751..f8ed0a3 100644 --- a/src/classes/cultivation.py +++ b/src/classes/cultivation.py @@ -148,4 +148,7 @@ class CultivationProgress: """ 检查是否可以突破 """ - return self.level in level_to_break_through.keys() \ No newline at end of file + return self.level in level_to_break_through.keys() + + def __str__(self) -> str: + return f"{self.realm.value}{self.stage.value}({self.level}级)。可以突破:{self.can_break_through()}" \ No newline at end of file diff --git a/src/classes/tile.py b/src/classes/tile.py index cb52970..edb6fa7 100644 --- a/src/classes/tile.py +++ b/src/classes/tile.py @@ -45,6 +45,15 @@ class Region(): def __post_init__(self): self.id = next(region_id_counter) + def __str__(self) -> str: + return f"区域。名字:{self.name},描述:{self.description},最浓的灵气:{self.get_most_dense_essence()}, 灵气值:{self.get_most_dense_essence_value()}" + + def get_most_dense_essence(self) -> EssenceType: + return max(self.essence.density.items(), key=lambda x: x[1])[0] + + def get_most_dense_essence_value(self) -> int: + most_dense_essence = self.get_most_dense_essence() + return self.essence.density[most_dense_essence] def __hash__(self) -> int: return hash(self.id) @@ -74,7 +83,8 @@ class Map(): """ def __init__(self, width: int, height: int): self.tiles = {} - self.regions = {} + self.regions = {} # region_id -> region + self.region_names = {} # region_name -> region self.width = width self.height = height @@ -101,6 +111,7 @@ class Map(): region.center_loc = center_loc region.area = len(locs) self.regions[region.id] = region + self.region_names[name] = region return region def get_center_locs(self, locs: list[tuple[int, int]]) -> tuple[int, int]: diff --git a/src/run/run.py b/src/run/run.py index 23acc2a..2af4650 100644 --- a/src/run/run.py +++ b/src/run/run.py @@ -93,7 +93,7 @@ def main(): sim = Simulator(world) # 创建角色,传入当前年份确保年龄与生日匹配 - sim.avatars.update(make_avatars(world, count=14, current_year=world.year)) + sim.avatars.update(make_avatars(world, count=2, current_year=world.year)) front = Front( simulator=sim, diff --git a/src/utils/config.py b/src/utils/config.py index a024872..c8a7641 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -32,6 +32,10 @@ def load_config(): # 合并配置,local_config优先级更高 config = OmegaConf.merge(base_config, local_config) + + # 把paths下的所有值pathlib化 + for key, value in config.paths.items(): + config.paths[key] = Path(value) return config diff --git a/src/utils/io.py b/src/utils/io.py new file mode 100644 index 0000000..e95e2f6 --- /dev/null +++ b/src/utils/io.py @@ -0,0 +1,8 @@ +from pathlib import Path + +def read_txt(path: Path) -> str: + """ + 读入中文txt文件 + """ + with open(path, "r", encoding="utf-8") as f: + return f.read() \ No newline at end of file diff --git a/src/utils/llm.py b/src/utils/llm.py index 0f4f73f..30a061d 100644 --- a/src/utils/llm.py +++ b/src/utils/llm.py @@ -1,7 +1,10 @@ from litellm import completion from langchain.prompts import PromptTemplate +from pathlib import Path +import json from src.utils.config import CONFIG +from src.utils.io import read_txt def get_prompt(template: str, infos: dict) -> str: """ @@ -31,4 +34,23 @@ def call_llm(prompt: str) -> str: ) # 返回生成的内容 - return response.choices[0].message.content \ No newline at end of file + return response.choices[0].message.content + +def get_prompt_and_call_llm(template_path: Path, infos: dict) -> str: + """ + 根据模板,获取提示词,并调用LLM + """ + template = read_txt(template_path) + prompt = get_prompt(template, infos) + res = call_llm(prompt) + json_res = json.loads(res) + print(f"prompt = {prompt}") + print(f"res = {res}") + return json_res + +def get_ai_prompt_and_call_llm(infos: dict) -> dict: + """ + 根据模板,获取提示词,并调用LLM + """ + template_path = CONFIG.paths.templates / "ai.txt" + return get_prompt_and_call_llm(template_path, infos) \ No newline at end of file diff --git a/src/static/config.yml b/static/config.yml similarity index 58% rename from src/static/config.yml rename to static/config.yml index 34cf973..970ecfb 100644 --- a/src/static/config.yml +++ b/static/config.yml @@ -1,4 +1,7 @@ llm: # 填入litellm支持的model name和key model_name: - key: \ No newline at end of file + key: + +paths: + templates: static/templates/ \ No newline at end of file diff --git a/static/templates/ai.txt b/static/templates/ai.txt new file mode 100644 index 0000000..663b53c --- /dev/null +++ b/static/templates/ai.txt @@ -0,0 +1,18 @@ + + + +你是一个决策者,这是一个修仙的仙侠世界,你负责来决定一些NPC的下一步行为。 +每个角色均拥有的动作空间和需要的参数为: +{action_space} +世界地图上存在的区域为: +{regions} +你需要进行决策的NPC的基本信息为: +{avatar_infos} + +注意,只返回json格式的动作 +返回格式: +{{ + "avatar_id": ..., + "action_name": ..., + "action_params": ..., +}} \ No newline at end of file