feat: data reload system
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Callable, TYPE_CHECKING, Coroutine
|
||||
|
||||
@@ -14,6 +14,11 @@ from src.run.log import get_logger
|
||||
if TYPE_CHECKING:
|
||||
from src.classes.world import World
|
||||
|
||||
@dataclass
|
||||
class History:
|
||||
text: str = ""
|
||||
modifications: dict = field(default_factory=dict)
|
||||
|
||||
class HistoryManager:
|
||||
"""
|
||||
历史管理器
|
||||
@@ -142,7 +147,7 @@ class HistoryManager:
|
||||
sect = sects_by_id.get(sid)
|
||||
if sect:
|
||||
old_name = sect.name
|
||||
self._update_obj_attrs(sect, data)
|
||||
self._update_obj_attrs(sect, data, "sects", sid_str)
|
||||
|
||||
# 同步 sects_by_name 索引
|
||||
if sect.name != old_name:
|
||||
@@ -161,8 +166,8 @@ class HistoryManager:
|
||||
def _apply_item_changes(self, result: Dict[str, Any]):
|
||||
"""处理物品/功法变更"""
|
||||
self._update_techniques(result.get("techniques_change", {}))
|
||||
self._update_items(result.get("weapons_change", {}), weapons_by_name)
|
||||
self._update_items(result.get("auxiliarys_change", {}), None)
|
||||
self._update_items(result.get("weapons_change", {}), weapons_by_name, "weapons")
|
||||
self._update_items(result.get("auxiliarys_change", {}), None, "auxiliaries")
|
||||
|
||||
# --- Update Logic ---
|
||||
|
||||
@@ -177,7 +182,7 @@ class HistoryManager:
|
||||
if self.world and self.world.map:
|
||||
region = self.world.map.regions.get(rid)
|
||||
if region:
|
||||
self._update_obj_attrs(region, data)
|
||||
self._update_obj_attrs(region, data, "regions", rid_str)
|
||||
self.logger.info(f"[History] 区域变更 - ID: {rid}, Name: {region.name}, Desc: {region.desc}")
|
||||
count += 1
|
||||
except Exception as e:
|
||||
@@ -197,7 +202,7 @@ class HistoryManager:
|
||||
tech = techniques_by_id.get(tid)
|
||||
if tech:
|
||||
old_name = tech.name
|
||||
self._update_obj_attrs(tech, data)
|
||||
self._update_obj_attrs(tech, data, "techniques", tid_str)
|
||||
|
||||
if tech.name != old_name:
|
||||
if old_name in techniques_by_name:
|
||||
@@ -212,7 +217,7 @@ class HistoryManager:
|
||||
if count > 0:
|
||||
self.logger.info(f"[History] 更新了 {count} 本功法")
|
||||
|
||||
def _update_items(self, changes: Dict[str, Any], by_name_index: Optional[Dict[str, Any]]):
|
||||
def _update_items(self, changes: Dict[str, Any], by_name_index: Optional[Dict[str, Any]], category: str):
|
||||
"""更新物品 (ItemRegistry)"""
|
||||
if not changes: return
|
||||
|
||||
@@ -223,7 +228,7 @@ class HistoryManager:
|
||||
item = ItemRegistry.get(iid)
|
||||
if item:
|
||||
old_name = item.name
|
||||
self._update_obj_attrs(item, data)
|
||||
self._update_obj_attrs(item, data, category, iid_str)
|
||||
|
||||
if by_name_index is not None and item.name != old_name:
|
||||
if old_name in by_name_index:
|
||||
@@ -238,12 +243,25 @@ class HistoryManager:
|
||||
if count > 0:
|
||||
self.logger.info(f"[History] 更新了 {count} 件装备")
|
||||
|
||||
def _update_obj_attrs(self, obj: Any, data: Dict[str, Any]):
|
||||
"""通用属性更新 helper"""
|
||||
def _update_obj_attrs(self, obj: Any, data: Dict[str, Any], category: str = None, id_str: str = None):
|
||||
"""通用属性更新 helper,并记录差分"""
|
||||
|
||||
recorded_changes = {}
|
||||
|
||||
if "name" in data and data["name"]:
|
||||
obj.name = str(data["name"])
|
||||
val = str(data["name"])
|
||||
obj.name = val
|
||||
recorded_changes["name"] = val
|
||||
|
||||
if "desc" in data and data["desc"]:
|
||||
obj.desc = str(data["desc"])
|
||||
val = str(data["desc"])
|
||||
obj.desc = val
|
||||
recorded_changes["desc"] = val
|
||||
|
||||
# 记录差分到 World
|
||||
if category and id_str and recorded_changes and self.world:
|
||||
self.world.record_modification(category, id_str, recorded_changes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 模拟运行
|
||||
|
||||
@@ -8,6 +8,7 @@ from src.classes.avatar_manager import AvatarManager
|
||||
from src.classes.event_manager import EventManager
|
||||
from src.classes.circulation import CirculationManager
|
||||
from src.classes.gathering.gathering import GatheringManager
|
||||
from src.classes.history import History
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.classes.avatar import Avatar
|
||||
@@ -29,8 +30,8 @@ class World():
|
||||
circulation: CirculationManager = field(default_factory=CirculationManager)
|
||||
# Gathering 管理器
|
||||
gathering_manager: GatheringManager = field(default_factory=GatheringManager)
|
||||
# 世界历史文本
|
||||
history: str = ""
|
||||
# 世界历史
|
||||
history: "History" = field(default_factory=lambda: History())
|
||||
|
||||
def get_info(self, detailed: bool = False, avatar: Optional["Avatar"] = None) -> dict:
|
||||
"""
|
||||
@@ -54,7 +55,25 @@ class World():
|
||||
|
||||
def set_history(self, history_text: str):
|
||||
"""设置世界历史文本"""
|
||||
self.history = history_text
|
||||
self.history.text = history_text
|
||||
|
||||
def record_modification(self, category: str, id_str: str, changes: dict):
|
||||
"""
|
||||
记录历史修改差分
|
||||
|
||||
Args:
|
||||
category: 修改类别 (sects, regions, techniques, weapons, auxiliaries)
|
||||
id_str: 对象 ID 字符串
|
||||
changes: 修改的属性字典
|
||||
"""
|
||||
if category not in self.history.modifications:
|
||||
self.history.modifications[category] = {}
|
||||
|
||||
if id_str not in self.history.modifications[category]:
|
||||
self.history.modifications[category][id_str] = {}
|
||||
|
||||
# 累加修改(后来的覆盖前面的)
|
||||
self.history.modifications[category][id_str].update(changes)
|
||||
|
||||
@property
|
||||
def static_info(self) -> dict:
|
||||
@@ -75,8 +94,8 @@ class World():
|
||||
"购物": "在城市区域可以购买练气级别丹药、兵器。购买丹药后会立刻服用强化自身。购买兵器可以帮自己切换兵器类型为顺手的类型。",
|
||||
"拍卖会": "每隔一段不确定的时间会有神秘人组织的拍卖会,或许有好货出售。"
|
||||
}
|
||||
if self.history:
|
||||
desc["历史"] = self.history
|
||||
if self.history.text:
|
||||
desc["历史"] = self.history.text
|
||||
return desc
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,6 +6,8 @@ import subprocess
|
||||
import time
|
||||
import threading
|
||||
import signal
|
||||
import random
|
||||
from omegaconf import OmegaConf
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from typing import List, Optional
|
||||
@@ -39,8 +41,6 @@ from src.classes.long_term_objective import set_user_long_term_objective, clear_
|
||||
from src.sim.save.save_game import save_game, list_saves
|
||||
from src.sim.load.load_game import load_game
|
||||
from src.utils import protagonist as prot_utils
|
||||
import random
|
||||
from omegaconf import OmegaConf
|
||||
from src.utils.llm.client import test_connectivity
|
||||
from src.utils.llm.config import LLMConfig, LLMMode
|
||||
from src.run.data_loader import reload_all_static_data
|
||||
|
||||
@@ -41,6 +41,103 @@ from src.classes.relation import Relation
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
|
||||
def apply_history_modifications(world, modifications):
|
||||
"""
|
||||
回放历史修改记录,恢复世界状态
|
||||
"""
|
||||
if not modifications:
|
||||
return
|
||||
|
||||
print(f"正在回放历史差分 ({len(modifications)} 个分类)...")
|
||||
|
||||
# 导入需要修改的对象容器
|
||||
from src.classes.sect import sects_by_id, sects_by_name
|
||||
from src.classes.technique import techniques_by_id, techniques_by_name
|
||||
from src.classes.item_registry import ItemRegistry
|
||||
|
||||
# 1. 宗门修改
|
||||
sects_mod = modifications.get("sects", {})
|
||||
for sid_str, changes in sects_mod.items():
|
||||
try:
|
||||
sid = int(sid_str)
|
||||
sect = sects_by_id.get(sid)
|
||||
if sect:
|
||||
old_name = sect.name
|
||||
# 应用修改
|
||||
if "name" in changes: sect.name = changes["name"]
|
||||
if "desc" in changes: sect.desc = changes["desc"]
|
||||
# 同步索引
|
||||
if sect.name != old_name:
|
||||
if old_name in sects_by_name: del sects_by_name[old_name]
|
||||
sects_by_name[sect.name] = sect
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. 区域修改
|
||||
regions_mod = modifications.get("regions", {})
|
||||
for rid_str, changes in regions_mod.items():
|
||||
try:
|
||||
rid = int(rid_str)
|
||||
region = world.map.regions.get(rid)
|
||||
if region:
|
||||
if "name" in changes: region.name = changes["name"]
|
||||
if "desc" in changes: region.desc = changes["desc"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. 功法修改
|
||||
techniques_mod = modifications.get("techniques", {})
|
||||
for tid_str, changes in techniques_mod.items():
|
||||
try:
|
||||
tid = int(tid_str)
|
||||
tech = techniques_by_id.get(tid)
|
||||
if tech:
|
||||
old_name = tech.name
|
||||
if "name" in changes: tech.name = changes["name"]
|
||||
if "desc" in changes: tech.desc = changes["desc"]
|
||||
if tech.name != old_name:
|
||||
if old_name in techniques_by_name: del techniques_by_name[old_name]
|
||||
techniques_by_name[tech.name] = tech
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 4. 武器修改 (通过 ItemRegistry)
|
||||
weapons_mod = modifications.get("weapons", {})
|
||||
from src.classes.weapon import weapons_by_name
|
||||
for iid_str, changes in weapons_mod.items():
|
||||
try:
|
||||
iid = int(iid_str)
|
||||
item = ItemRegistry.get(iid)
|
||||
if item:
|
||||
old_name = item.name
|
||||
if "name" in changes: item.name = changes["name"]
|
||||
if "desc" in changes: item.desc = changes["desc"]
|
||||
if item.name != old_name:
|
||||
if old_name in weapons_by_name: del weapons_by_name[old_name]
|
||||
weapons_by_name[item.name] = item
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 5. 辅助装备修改 (通过 ItemRegistry)
|
||||
aux_mod = modifications.get("auxiliaries", {})
|
||||
from src.classes.auxiliary import auxiliaries_by_name
|
||||
for iid_str, changes in aux_mod.items():
|
||||
try:
|
||||
iid = int(iid_str)
|
||||
item = ItemRegistry.get(iid)
|
||||
if item:
|
||||
old_name = item.name
|
||||
if "name" in changes: item.name = changes["name"]
|
||||
if "desc" in changes: item.desc = changes["desc"]
|
||||
if item.name != old_name:
|
||||
if old_name in auxiliaries_by_name: del auxiliaries_by_name[old_name]
|
||||
auxiliaries_by_name[item.name] = item
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("历史差分回放完成。")
|
||||
|
||||
|
||||
def get_events_db_path(save_path: Path) -> Path:
|
||||
"""
|
||||
根据存档路径计算事件数据库路径。
|
||||
@@ -109,9 +206,13 @@ def load_game(save_path: Optional[Path] = None) -> Tuple["World", "Simulator", L
|
||||
)
|
||||
|
||||
# 恢复世界历史
|
||||
history = world_data.get("history", "")
|
||||
if history:
|
||||
world.set_history(history)
|
||||
history_data = world_data.get("history", {})
|
||||
world.history.text = history_data.get("text", "")
|
||||
world.history.modifications = history_data.get("modifications", {})
|
||||
|
||||
# 恢复并回放历史修改记录(关键修复:在加载角色前还原规则)
|
||||
if world.history.modifications:
|
||||
apply_history_modifications(world, world.history.modifications)
|
||||
|
||||
# 重建天地灵机
|
||||
from src.classes.celestial_phenomenon import celestial_phenomena_by_id
|
||||
|
||||
@@ -123,7 +123,10 @@ def save_game(
|
||||
# 出世物品流转
|
||||
"circulation": world.circulation.to_save_dict(),
|
||||
# 世界历史
|
||||
"history": world.history,
|
||||
"history": {
|
||||
"text": world.history.text,
|
||||
"modifications": world.history.modifications
|
||||
},
|
||||
}
|
||||
|
||||
# 保存所有Avatar(第一阶段:不含relations)
|
||||
|
||||
Reference in New Issue
Block a user