add new personas and weights
This commit is contained in:
@@ -15,7 +15,8 @@ class Persona:
|
||||
id: int
|
||||
name: str
|
||||
prompt: str
|
||||
exclusion_ids: List[int]
|
||||
exclusion_ids: List[int]
|
||||
weight: float
|
||||
|
||||
def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
"""从配表加载persona数据"""
|
||||
@@ -29,12 +30,17 @@ def _load_personas() -> tuple[dict[int, Persona], dict[str, Persona]]:
|
||||
exclusion_ids = []
|
||||
if exclusion_ids_str:
|
||||
exclusion_ids = [int(x.strip()) for x in exclusion_ids_str.split(ids_separator) if x.strip()]
|
||||
# 解析权重(缺失或为 NaN 时默认为 1.0),避免不必要的异常
|
||||
weight_val = row.get("weight", 1)
|
||||
weight_str = str(weight_val).strip()
|
||||
weight = float(weight_str) if weight_str and weight_str.lower() != "nan" else 1.0
|
||||
|
||||
persona = Persona(
|
||||
id=int(row["id"]),
|
||||
name=str(row["name"]),
|
||||
prompt=str(row["prompt"]),
|
||||
exclusion_ids=exclusion_ids
|
||||
exclusion_ids=exclusion_ids,
|
||||
weight=weight,
|
||||
)
|
||||
personas_by_id[persona.id] = persona
|
||||
personas_by_name[persona.name] = persona
|
||||
@@ -59,16 +65,17 @@ def get_random_compatible_personas(num_personas: int = 2) -> List[Persona]:
|
||||
"""
|
||||
all_persona_ids = set(personas_by_id.keys())
|
||||
|
||||
selected_personas = []
|
||||
selected_personas: List[Persona] = []
|
||||
available_ids = all_persona_ids.copy()
|
||||
|
||||
for i in range(num_personas):
|
||||
if not available_ids:
|
||||
raise ValueError(f"只能找到{i}个兼容的persona,无法满足需要的{num_personas}个")
|
||||
|
||||
# 从可用列表中随机选择一个
|
||||
selected_id = random.choice(list(available_ids))
|
||||
selected_persona = personas_by_id[selected_id]
|
||||
# 按权重从可用列表中选择一个
|
||||
candidates: List[Persona] = [personas_by_id[i] for i in available_ids]
|
||||
weights: List[float] = [max(0.0, c.weight) for c in candidates]
|
||||
selected_persona = random.choices(candidates, weights=weights, k=1)[0]
|
||||
selected_id = selected_persona.id
|
||||
selected_personas.append(selected_persona)
|
||||
|
||||
# 更新可用列表:移除已选择的和与其互斥的
|
||||
|
||||
@@ -12,6 +12,7 @@ def load_csv(path: Path) -> pd.DataFrame:
|
||||
"name": str,
|
||||
"description": str,
|
||||
"prompt": str,
|
||||
"weight": float,
|
||||
}
|
||||
for column, dtype in row_types.items():
|
||||
if column in df.columns:
|
||||
|
||||
Reference in New Issue
Block a user