# File: chatgpt-on-wechat/agent/protocol/message_utils.py
# Snapshot: 2026-03-10 14:52:07 +08:00
# Size: 180 lines, 6.4 KiB
# Language: Python

"""
Message sanitizer — fix broken tool_use / tool_result pairs.

Provides two public helpers that can be reused across agent_stream.py
and any bot that converts messages to OpenAI format:

1. sanitize_claude_messages(messages)
   Operates on the internal Claude-format message list (in-place).

2. drop_orphaned_tool_results_openai(messages)
   Operates on an already-converted OpenAI-format message list,
   returning a cleaned copy.
"""
from __future__ import annotations
from typing import Dict, List, Set
from common.log import logger
# ------------------------------------------------------------------ #
# Claude-format sanitizer (used by agent_stream)
# ------------------------------------------------------------------ #
def sanitize_claude_messages(messages: List[Dict]) -> int:
    """
    Validate and fix a Claude-format message list **in-place**.

    Fixes handled:
    - Trailing assistant message with tool_use but no following tool_result
    - Leading orphaned tool_result user messages
    - Mid-list tool_result blocks whose tool_use_id has no matching
      tool_use in any preceding assistant message that is still in the list

    A user message made up only of tool_result blocks is dropped whole when
    any of its ids is orphaned, together with an immediately preceding
    assistant message that contains tool_use blocks. A user message that
    also carries a text block keeps its text and only loses the orphaned
    tool_result blocks.

    Args:
        messages: Message dicts with ``role`` and ``content`` keys; mutated
            in place. Only list-valued ``content`` is inspected.

    Returns:
        The number of messages / blocks removed.
    """
    if not messages:
        return 0

    removed = 0

    # 1. Remove trailing incomplete tool_use assistant messages
    while messages:
        last = messages[-1]
        if last.get("role") != "assistant":
            break
        if "tool_use" not in _block_kinds(last.get("content", [])):
            break
        logger.warning("⚠️ Removing trailing incomplete tool_use assistant message")
        messages.pop()
        removed += 1

    # 2. Remove leading orphaned tool_result user messages
    while messages:
        first = messages[0]
        if first.get("role") != "user":
            break
        kinds = _block_kinds(first.get("content", []))
        if "tool_result" not in kinds or "text" in kinds:
            break
        logger.warning("⚠️ Removing leading orphaned tool_result user message")
        messages.pop(0)
        removed += 1

    # 3. Full scan: ensure every tool_result references a known tool_use id.
    #    Invariant: known_ids holds exactly the tool_use ids declared by
    #    messages[:i] as the list currently stands.
    known_ids: Set[str] = set()
    i = 0
    while i < len(messages):
        msg = messages[i]
        role = msg.get("role")
        content = msg.get("content", [])

        if role == "assistant":
            known_ids |= _tool_use_ids(msg)
            i += 1
            continue

        if role != "user" or not isinstance(content, list):
            i += 1
            continue

        kinds = _block_kinds(content)
        if "tool_result" not in kinds:
            i += 1
            continue

        orphaned = [
            b.get("tool_use_id", "")
            for b in content
            if isinstance(b, dict)
            and b.get("type") == "tool_result"
            and b.get("tool_use_id", "")
            and b.get("tool_use_id", "") not in known_ids
        ]
        if not orphaned:
            i += 1
            continue

        if "text" not in kinds:
            logger.warning(
                f"⚠️ Removing orphaned tool_result message (tool_ids: {orphaned})"
            )
            messages.pop(i)
            removed += 1
            # Also remove a preceding broken assistant tool_use message
            if (
                i > 0
                and messages[i - 1].get("role") == "assistant"
                and "tool_use" in _block_kinds(messages[i - 1].get("content", []))
            ):
                messages.pop(i - 1)
                removed += 1
                i -= 1
                # The ids declared by the message just removed must no longer
                # count as known; otherwise a later tool_result pointing at
                # them would survive without any tool_use left in the list.
                known_ids = set()
                for earlier in messages[:i]:
                    known_ids |= _tool_use_ids(earlier)
            # Do not advance: index i now holds the next unvisited message.
            continue

        # Mixed message: keep the text, strip only the orphaned blocks.
        orphaned_set = set(orphaned)
        new_content = [
            b for b in content
            if not (
                isinstance(b, dict)
                and b.get("type") == "tool_result"
                and b.get("tool_use_id", "") in orphaned_set
            )
        ]
        delta = len(content) - len(new_content)
        if delta:
            logger.warning(
                f"⚠️ Stripped {delta} orphaned tool_result block(s) from mixed message"
            )
            msg["content"] = new_content
            removed += delta
        i += 1

    if removed:
        logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
    return removed


def _block_kinds(content) -> List:
    """Return the ``type`` of every dict block in *content* ([] if not a list)."""
    if not isinstance(content, list):
        return []
    return [b.get("type") for b in content if isinstance(b, dict)]


def _tool_use_ids(msg: Dict) -> Set[str]:
    """Return the non-empty tool_use ids declared by an assistant message."""
    if msg.get("role") != "assistant":
        return set()
    content = msg.get("content", [])
    if not isinstance(content, list):
        return set()
    return {
        b.get("id")
        for b in content
        if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
    }
# ------------------------------------------------------------------ #
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
# ------------------------------------------------------------------ #
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
    """
    Return a copy of *messages* (OpenAI format) with any ``role=tool``
    messages removed if their ``tool_call_id`` does not match a
    ``tool_calls[].id`` in a preceding assistant message.

    The input list is not modified; the returned list is a new list that
    holds the same message dict objects (shallow copy). ``role=tool``
    messages with a missing or empty ``tool_call_id`` are kept as-is.

    Args:
        messages: OpenAI-format message dicts. ``None`` or an empty list
            yields an empty list.

    Returns:
        A new list without the orphaned tool messages.
    """
    known_ids: Set[str] = set()
    cleaned: List[Dict] = []

    for msg in messages or []:
        role = msg.get("role")

        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                # Skip malformed entries instead of crashing on ``.get``.
                if not isinstance(tc, dict):
                    continue
                tc_id = tc.get("id", "")
                if tc_id:
                    known_ids.add(tc_id)
        elif role == "tool":
            ref_id = msg.get("tool_call_id", "")
            if ref_id and ref_id not in known_ids:
                logger.warning(
                    f"[MessageSanitizer] Dropping orphaned tool result "
                    f"(tool_call_id={ref_id} not in known ids)"
                )
                continue

        cleaned.append(msg)

    return cleaned
# ------------------------------------------------------------------ #
# Internal helpers
# ------------------------------------------------------------------ #
def _has_block_type(content: list, block_type: str) -> bool:
return any(
isinstance(b, dict) and b.get("type") == block_type
for b in content
)