add persona conditions
This commit is contained in:
@@ -15,6 +15,21 @@ class Alignment(Enum):
|
||||
def get_info(self) -> str:
|
||||
return alignment_strs[self] + ": " + alignment_infos[self]
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
"""
|
||||
允许与同类或字符串比较:
|
||||
- Alignment: 恒等比较
|
||||
- str: 同时支持英文值(value)与中文显示(__str__)
|
||||
"""
|
||||
if isinstance(other, Alignment):
|
||||
return self is other
|
||||
if isinstance(other, str):
|
||||
return other == self.value or other == str(self)
|
||||
return False
|
||||
|
||||
|
||||
alignment_strs = {
|
||||
Alignment.RIGHTEOUS: "正",
|
||||
|
||||
@@ -93,9 +93,9 @@ class Avatar:
|
||||
|
||||
# 最大寿元已在 Age 构造时基于境界初始化
|
||||
|
||||
# 如果personas列表为空,则随机分配两个不互斥的persona
|
||||
# 如果personas列表为空,则随机分配两个符合条件且不互斥的persona
|
||||
if not self.personas:
|
||||
self.personas = get_random_compatible_personas(persona_num)
|
||||
self.personas = get_random_compatible_personas(persona_num, avatar=self)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from src.utils.df import game_configs
|
||||
from src.utils.config import CONFIG
|
||||
|
||||
ids_separator = CONFIG.df.ids_separator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# 仅用于类型检查,避免运行时循环导入
|
||||
from src.classes.avatar import Avatar
|
||||
|
||||
@dataclass
|
||||
class Persona:
|
||||
"""
|
||||
@@ -17,6 +21,7 @@ class Persona:
|
||||
prompt: str
|
||||
exclusion_ids: List[int]
|
||||
weight: float
|
||||
condition: str
|
||||
|
||||
def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
"""从配表加载persona数据"""
|
||||
@@ -34,6 +39,9 @@ def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
weight_val = row.get("weight", 1)
|
||||
weight_str = str(weight_val).strip()
|
||||
weight = float(weight_str) if weight_str and weight_str.lower() != "nan" else 1.0
|
||||
# 条件:可为空
|
||||
condition_val = row.get("condition", "")
|
||||
condition = "" if str(condition_val) == "nan" else str(condition_val).strip()
|
||||
|
||||
persona = Persona(
|
||||
id=int(row["id"]),
|
||||
@@ -41,6 +49,7 @@ def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
prompt=str(row["prompt"]),
|
||||
exclusion_ids=exclusion_ids,
|
||||
weight=weight,
|
||||
condition=condition,
|
||||
)
|
||||
personas_by_id[persona.id] = persona
|
||||
personas_by_name[persona.name] = persona
|
||||
@@ -50,12 +59,32 @@ def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
# 从配表加载persona数据
|
||||
personas_by_id, personas_by_name = _load_personas()
|
||||
|
||||
def get_random_compatible_personas(num_personas: int = 2) -> List[Persona]:
|
||||
def _is_persona_allowed(persona_id: int, already_selected_ids: set[int], avatar: Optional["Avatar"]) -> bool:
|
||||
"""
|
||||
统一判断:persona 是否允许被选择(条件 + 互斥)。
|
||||
- 条件:当存在 avatar 且配置了 condition 时,通过安全 eval 判断。
|
||||
- 互斥:与已选 persona 双向互斥。
|
||||
"""
|
||||
persona = personas_by_id[persona_id]
|
||||
# 条件判定
|
||||
if avatar is not None and persona.condition:
|
||||
allowed = bool(eval(persona.condition, {"__builtins__": {}}, {"avatar": avatar}))
|
||||
if not allowed:
|
||||
return False
|
||||
# 与已选互斥检查(双向)
|
||||
for sid in already_selected_ids:
|
||||
other = personas_by_id[sid]
|
||||
if (persona_id in other.exclusion_ids) or (sid in persona.exclusion_ids):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_random_compatible_personas(num_personas: int = 2, avatar: Optional["Avatar"] = None) -> List[Persona]:
|
||||
"""
|
||||
随机选择指定数量的互相不冲突的persona
|
||||
|
||||
Args:
|
||||
num_personas: 需要选择的persona数量,默认为2
|
||||
avatar: 可选,若提供则按 persona.condition 过滤
|
||||
|
||||
Returns:
|
||||
List[Persona]: 互相不冲突的persona列表
|
||||
@@ -63,32 +92,25 @@ def get_random_compatible_personas(num_personas: int = 2) -> List[Persona]:
|
||||
Raises:
|
||||
ValueError: 如果无法找到足够数量的兼容persona
|
||||
"""
|
||||
all_persona_ids = set(personas_by_id.keys())
|
||||
# 初始候选:若提供 avatar,则先按条件过滤;否则全量
|
||||
initial_ids = set(personas_by_id.keys())
|
||||
if avatar is not None:
|
||||
initial_ids = {pid for pid in initial_ids if _is_persona_allowed(pid, set(), avatar)}
|
||||
|
||||
selected_personas: List[Persona] = []
|
||||
available_ids = all_persona_ids.copy()
|
||||
|
||||
selected_ids: set[int] = set()
|
||||
|
||||
for i in range(num_personas):
|
||||
# 按当前已选进行二次过滤(互斥 + 条件)
|
||||
available_ids = [pid for pid in initial_ids if pid not in selected_ids and _is_persona_allowed(pid, selected_ids, avatar)]
|
||||
if not available_ids:
|
||||
raise ValueError(f"只能找到{i}个兼容的persona,无法满足需要的{num_personas}个")
|
||||
# 按权重从可用列表中选择一个
|
||||
candidates: List[Persona] = [personas_by_id[i] for i in available_ids]
|
||||
|
||||
candidates: List[Persona] = [personas_by_id[pid] for pid in available_ids]
|
||||
weights: List[float] = [max(0.0, c.weight) for c in candidates]
|
||||
selected_persona = random.choices(candidates, weights=weights, k=1)[0]
|
||||
selected_id = selected_persona.id
|
||||
selected_personas.append(selected_persona)
|
||||
|
||||
# 更新可用列表:移除已选择的和与其互斥的
|
||||
available_ids.discard(selected_id) # 移除自己
|
||||
|
||||
# 移除所有与当前选择互斥的persona
|
||||
for exclusion_id in selected_persona.exclusion_ids:
|
||||
available_ids.discard(exclusion_id)
|
||||
|
||||
# 移除所有将当前选择作为互斥对象的persona
|
||||
for persona_id in list(available_ids):
|
||||
if selected_id in personas_by_id[persona_id].exclusion_ids:
|
||||
available_ids.discard(persona_id)
|
||||
|
||||
selected_ids.add(selected_persona.id)
|
||||
|
||||
return selected_personas
|
||||
|
||||
|
||||
Reference in New Issue
Block a user