add llm ai
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
import random
|
||||
import json
|
||||
|
||||
from src.classes.essence import Essence, EssenceType
|
||||
from src.classes.root import Root, corres_essence_type
|
||||
@@ -48,6 +49,7 @@ class Move(DefineAction):
|
||||
"""
|
||||
最基础的移动动作,在tile之间进行切换。
|
||||
"""
|
||||
COMMENT = "移动到某个相对位置"
|
||||
def execute(self, delta_x: int, delta_y: int) -> Event|NullEvent:
|
||||
"""
|
||||
移动到某个tile
|
||||
@@ -71,10 +73,13 @@ class MoveToRegion(DefineAction):
|
||||
"""
|
||||
移动到某个region
|
||||
"""
|
||||
def execute(self, region: Region) -> Event|NullEvent:
|
||||
COMMENT = "移动到某个区域"
|
||||
def execute(self, region: Region|str) -> Event|NullEvent:
|
||||
"""
|
||||
移动到某个region
|
||||
"""
|
||||
if isinstance(region, str):
|
||||
region = self.world.map.region_names[region]
|
||||
cur_loc = (self.avatar.pos_x, self.avatar.pos_y)
|
||||
region_center_loc = region.center_loc
|
||||
delta_x = region_center_loc[0] - cur_loc[0]
|
||||
@@ -89,6 +94,7 @@ class Cultivate(DefineAction):
|
||||
"""
|
||||
修炼动作,可以增加修仙进度。
|
||||
"""
|
||||
COMMENT = "修炼,增进修为"
|
||||
def execute(self) -> Event|NullEvent:
|
||||
"""
|
||||
修炼
|
||||
@@ -116,6 +122,7 @@ class Breakthrough(DefineAction):
|
||||
"""
|
||||
突破境界
|
||||
"""
|
||||
COMMENT = "尝试突破境界"
|
||||
def calc_success_rate(self) -> float:
|
||||
"""
|
||||
计算突破境界的成功率
|
||||
@@ -137,4 +144,12 @@ class Breakthrough(DefineAction):
|
||||
return Event(self.world.year, self.world.month, f"{self.avatar.name} 突破境界{res}")
|
||||
|
||||
|
||||
ALL_ACTION_CLASSES = [Move, Cultivate, Breakthrough, MoveToRegion]
|
||||
ALL_ACTION_CLASSES = [Move, Cultivate, Breakthrough, MoveToRegion]
|
||||
# 不包括Move
|
||||
ACTION_SPACE = [
|
||||
# {"action": "Move", "params": {"delta_x": int, "delta_y": int}, "comment": Move.COMMENT},
|
||||
{"action": "Cultivate", "params": {}, "comment": Cultivate.COMMENT},
|
||||
{"action": "Breakthrough", "params": {}, "comment": Breakthrough.COMMENT},
|
||||
{"action": "MoveToRegion", "params": {"region": "region_name"}, "comment": MoveToRegion.COMMENT},
|
||||
]
|
||||
ACTION_SPACE_STR = json.dumps(ACTION_SPACE, ensure_ascii=False)
|
||||
@@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
||||
from src.classes.world import World
|
||||
from src.classes.tile import Region
|
||||
from src.classes.root import corres_essence_type
|
||||
from src.classes.action import ACTION_SPACE_STR
|
||||
from src.utils.llm import get_ai_prompt_and_call_llm
|
||||
|
||||
class AI(ABC):
|
||||
"""
|
||||
@@ -23,12 +25,6 @@ class AI(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# def create_event(self, world: World, content: str) -> Event:
|
||||
# """
|
||||
# 创建事件
|
||||
# """
|
||||
# return Event(world.year, world.month, content)
|
||||
|
||||
class RuleAI(AI):
|
||||
"""
|
||||
规则AI
|
||||
@@ -69,4 +65,14 @@ class LLMAI(AI):
|
||||
"""
|
||||
决定做什么
|
||||
"""
|
||||
pass
|
||||
action_space_str = ACTION_SPACE_STR
|
||||
avatar_infos_str = str(self.avatar)
|
||||
regions_str = "\n".join([str(region) for region in world.map.regions.values()])
|
||||
dict_info = {
|
||||
"action_space": action_space_str,
|
||||
"avatar_infos": avatar_infos_str,
|
||||
"regions": regions_str
|
||||
}
|
||||
res = get_ai_prompt_and_call_llm(dict_info)
|
||||
action_name, action_params = res["action_name"], res["action_params"]
|
||||
return action_name, action_params
|
||||
@@ -12,7 +12,7 @@ from src.classes.cultivation import CultivationProgress, Realm
|
||||
from src.classes.root import Root
|
||||
from src.classes.age import Age
|
||||
from src.utils.strings import to_snake_case
|
||||
from src.classes.ai import AI, RuleAI
|
||||
from src.classes.ai import AI, RuleAI, LLMAI
|
||||
|
||||
class Gender(Enum):
|
||||
MALE = "male"
|
||||
@@ -52,9 +52,17 @@ class Avatar:
|
||||
在Avatar创建后自动绑定基础动作和AI
|
||||
"""
|
||||
self.tile = self.world.map.get_tile(self.pos_x, self.pos_y)
|
||||
self.ai = RuleAI(self)
|
||||
self.ai = LLMAI(self)
|
||||
# self.ai = RuleAI(self)
|
||||
self._bind_basic_actions()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
获取avatar的详细信息
|
||||
尽量多打一些,因为会用来给LLM进行决策
|
||||
"""
|
||||
return f"Avatar(id={self.id}, 性别={self.gender}, 年龄={self.age}, name={self.name}, 区域={self.tile.region.name}, 灵根={self.root.value}, 境界={self.cultivation_progress})"
|
||||
|
||||
def _bind_basic_actions(self):
|
||||
"""
|
||||
绑定基础动作,如移动等
|
||||
|
||||
@@ -148,4 +148,7 @@ class CultivationProgress:
|
||||
"""
|
||||
检查是否可以突破
|
||||
"""
|
||||
return self.level in level_to_break_through.keys()
|
||||
return self.level in level_to_break_through.keys()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.realm.value}{self.stage.value}({self.level}级)。可以突破:{self.can_break_through()}"
|
||||
@@ -45,6 +45,15 @@ class Region():
|
||||
def __post_init__(self):
|
||||
self.id = next(region_id_counter)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"区域。名字:{self.name},描述:{self.description},最浓的灵气:{self.get_most_dense_essence()}, 灵气值:{self.get_most_dense_essence_value()}"
|
||||
|
||||
def get_most_dense_essence(self) -> EssenceType:
|
||||
return max(self.essence.density.items(), key=lambda x: x[1])[0]
|
||||
|
||||
def get_most_dense_essence_value(self) -> int:
|
||||
most_dense_essence = self.get_most_dense_essence()
|
||||
return self.essence.density[most_dense_essence]
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
@@ -74,7 +83,8 @@ class Map():
|
||||
"""
|
||||
def __init__(self, width: int, height: int):
|
||||
self.tiles = {}
|
||||
self.regions = {}
|
||||
self.regions = {} # region_id -> region
|
||||
self.region_names = {} # region_name -> region
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
@@ -101,6 +111,7 @@ class Map():
|
||||
region.center_loc = center_loc
|
||||
region.area = len(locs)
|
||||
self.regions[region.id] = region
|
||||
self.region_names[name] = region
|
||||
return region
|
||||
|
||||
def get_center_locs(self, locs: list[tuple[int, int]]) -> tuple[int, int]:
|
||||
|
||||
@@ -93,7 +93,7 @@ def main():
|
||||
sim = Simulator(world)
|
||||
|
||||
# 创建角色,传入当前年份确保年龄与生日匹配
|
||||
sim.avatars.update(make_avatars(world, count=14, current_year=world.year))
|
||||
sim.avatars.update(make_avatars(world, count=2, current_year=world.year))
|
||||
|
||||
front = Front(
|
||||
simulator=sim,
|
||||
|
||||
@@ -32,6 +32,10 @@ def load_config():
|
||||
|
||||
# 合并配置,local_config优先级更高
|
||||
config = OmegaConf.merge(base_config, local_config)
|
||||
|
||||
# 把paths下的所有值pathlib化
|
||||
for key, value in config.paths.items():
|
||||
config.paths[key] = Path(value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
8
src/utils/io.py
Normal file
8
src/utils/io.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pathlib import Path
|
||||
|
||||
def read_txt(path: Path) -> str:
|
||||
"""
|
||||
读入中文txt文件
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
@@ -1,7 +1,10 @@
|
||||
from litellm import completion
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from src.utils.config import CONFIG
|
||||
from src.utils.io import read_txt
|
||||
|
||||
def get_prompt(template: str, infos: dict) -> str:
|
||||
"""
|
||||
@@ -31,4 +34,23 @@ def call_llm(prompt: str) -> str:
|
||||
)
|
||||
|
||||
# 返回生成的内容
|
||||
return response.choices[0].message.content
|
||||
return response.choices[0].message.content
|
||||
|
||||
def get_prompt_and_call_llm(template_path: Path, infos: dict) -> str:
|
||||
"""
|
||||
根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template = read_txt(template_path)
|
||||
prompt = get_prompt(template, infos)
|
||||
res = call_llm(prompt)
|
||||
json_res = json.loads(res)
|
||||
print(f"prompt = {prompt}")
|
||||
print(f"res = {res}")
|
||||
return json_res
|
||||
|
||||
def get_ai_prompt_and_call_llm(infos: dict) -> dict:
|
||||
"""
|
||||
根据模板,获取提示词,并调用LLM
|
||||
"""
|
||||
template_path = CONFIG.paths.templates / "ai.txt"
|
||||
return get_prompt_and_call_llm(template_path, infos)
|
||||
@@ -1,4 +1,7 @@
|
||||
llm:
|
||||
# 填入litellm支持的model name和key
|
||||
model_name:
|
||||
key:
|
||||
key:
|
||||
|
||||
paths:
|
||||
templates: static/templates/
|
||||
18
static/templates/ai.txt
Normal file
18
static/templates/ai.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
|
||||
你是一个决策者,这是一个修仙的仙侠世界,你负责来决定一些NPC的下一步行为。
|
||||
每个角色均拥有的动作空间和需要的参数为:
|
||||
{action_space}
|
||||
世界地图上存在的区域为:
|
||||
{regions}
|
||||
你需要进行决策的NPC的基本信息为:
|
||||
{avatar_infos}
|
||||
|
||||
注意,只返回json格式的动作
|
||||
返回格式:
|
||||
{{
|
||||
"avatar_id": ...,
|
||||
"action_name": ...,
|
||||
"action_params": ...,
|
||||
}}
|
||||
Reference in New Issue
Block a user