add new personas and weights

This commit is contained in:
bridge
2025-10-04 15:22:42 +08:00
parent cf4e7a89c3
commit 757bcdebc0
3 changed files with 41 additions and 31 deletions

View File

@@ -15,7 +15,8 @@ class Persona:
id: int
name: str
prompt: str
exclusion_ids: List[int]
exclusion_ids: List[int]
weight: float
def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
"""从配表加载persona数据"""
@@ -29,12 +30,17 @@ def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
exclusion_ids = []
if exclusion_ids_str:
exclusion_ids = [int(x.strip()) for x in exclusion_ids_str.split(ids_separator) if x.strip()]
# 解析权重(缺失或为 NaN 时默认为 1.0),避免不必要的异常
weight_val = row.get("weight", 1)
weight_str = str(weight_val).strip()
weight = float(weight_str) if weight_str and weight_str.lower() != "nan" else 1.0
persona = Persona(
id=int(row["id"]),
name=str(row["name"]),
prompt=str(row["prompt"]),
exclusion_ids=exclusion_ids
exclusion_ids=exclusion_ids,
weight=weight,
)
personas_by_id[persona.id] = persona
personas_by_name[persona.name] = persona
@@ -59,16 +65,17 @@ def get_random_compatible_personas(num_personas: int = 2) -> List[Persona]:
"""
all_persona_ids = set(personas_by_id.keys())
selected_personas = []
selected_personas: List[Persona] = []
available_ids = all_persona_ids.copy()
for i in range(num_personas):
if not available_ids:
raise ValueError(f"只能找到{i}个兼容的persona无法满足需要的{num_personas}")
# 从可用列表中随机选择一个
selected_id = random.choice(list(available_ids))
selected_persona = personas_by_id[selected_id]
# 按权重从可用列表中选择一个
candidates: List[Persona] = [personas_by_id[i] for i in available_ids]
weights: List[float] = [max(0.0, c.weight) for c in candidates]
selected_persona = random.choices(candidates, weights=weights, k=1)[0]
selected_id = selected_persona.id
selected_personas.append(selected_persona)
# 更新可用列表:移除已选择的和与其互斥的

View File

@@ -12,6 +12,7 @@ def load_csv(path: Path) -> pd.DataFrame:
"name": str,
"description": str,
"prompt": str,
"weight": float,
}
for column, dtype in row_types.items():
if column in df.columns: