mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-18 04:25:14 +08:00
fix: validate tool_call_id pairing #2690
This commit is contained in:
@@ -8,6 +8,7 @@ import time
|
||||
from typing import List, Dict, Any, Optional, Callable, Tuple
|
||||
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.message_utils import sanitize_claude_messages
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
@@ -475,6 +476,10 @@ class AgentStreamExecutor:
|
||||
# Trim messages if needed (using agent's context management)
|
||||
self._trim_messages()
|
||||
|
||||
# Re-validate after trimming: trimming may produce new orphaned
|
||||
# tool_result messages when it removes turns at the boundary.
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
# Prepare messages
|
||||
messages = self._prepare_messages()
|
||||
turns = self._identify_complete_turns()
|
||||
@@ -900,56 +905,8 @@ class AgentStreamExecutor:
|
||||
return error_result
|
||||
|
||||
def _validate_and_fix_messages(self):
|
||||
"""
|
||||
Validate message history and fix broken tool_use/tool_result pairs.
|
||||
|
||||
Historical messages restored from DB are text-only (no tool calls),
|
||||
so this method only needs to handle edge cases in the current session:
|
||||
- Trailing assistant message with tool_use but no following tool_result
|
||||
(e.g. process was interrupted mid-execution)
|
||||
- Orphaned tool_result at the start of messages (e.g. after context
|
||||
trimming removed the preceding assistant tool_use)
|
||||
"""
|
||||
if not self.messages:
|
||||
return
|
||||
|
||||
removed = 0
|
||||
|
||||
# Remove trailing incomplete tool_use assistant messages
|
||||
while self.messages:
|
||||
last_msg = self.messages[-1]
|
||||
if last_msg.get("role") == "assistant":
|
||||
content = last_msg.get("content", [])
|
||||
if isinstance(content, list) and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
for b in content
|
||||
):
|
||||
logger.warning("⚠️ Removing trailing incomplete tool_use assistant message")
|
||||
self.messages.pop()
|
||||
removed += 1
|
||||
continue
|
||||
break
|
||||
|
||||
# Remove leading orphaned tool_result user messages
|
||||
while self.messages:
|
||||
first_msg = self.messages[0]
|
||||
if first_msg.get("role") == "user":
|
||||
content = first_msg.get("content", [])
|
||||
if isinstance(content, list) and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
) and not any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
self.messages.pop(0)
|
||||
removed += 1
|
||||
continue
|
||||
break
|
||||
|
||||
if removed > 0:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
"""Delegate to the shared sanitizer (see message_sanitizer.py)."""
|
||||
sanitize_claude_messages(self.messages)
|
||||
|
||||
def _identify_complete_turns(self) -> List[Dict]:
|
||||
"""
|
||||
|
||||
179
agent/protocol/message_utils.py
Normal file
179
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Trailing assistant message with tool_use but no following tool_result
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns the number of messages / blocks removed.
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Remove trailing incomplete tool_use assistant messages
|
||||
while messages:
|
||||
last = messages[-1]
|
||||
if last.get("role") != "assistant":
|
||||
break
|
||||
content = last.get("content", [])
|
||||
if isinstance(content, list) and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
for b in content
|
||||
):
|
||||
logger.warning("⚠️ Removing trailing incomplete tool_use assistant message")
|
||||
messages.pop()
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Full scan: ensure every tool_result references a known tool_use id
|
||||
known_ids: Set[str] = set()
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tid = block.get("id", "")
|
||||
if tid:
|
||||
known_ids.add(tid)
|
||||
|
||||
elif role == "user" and isinstance(content, list):
|
||||
if not _has_block_type(content, "tool_result"):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
orphaned = [
|
||||
b.get("tool_use_id", "")
|
||||
for b in content
|
||||
if isinstance(b, dict)
|
||||
and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id", "")
|
||||
and b.get("tool_use_id", "") not in known_ids
|
||||
]
|
||||
if orphaned:
|
||||
orphaned_set = set(orphaned)
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(
|
||||
f"⚠️ Removing orphaned tool_result message (tool_ids: {orphaned})"
|
||||
)
|
||||
messages.pop(i)
|
||||
removed += 1
|
||||
# Also remove a preceding broken assistant tool_use message
|
||||
if i > 0 and messages[i - 1].get("role") == "assistant":
|
||||
prev = messages[i - 1].get("content", [])
|
||||
if isinstance(prev, list) and _has_block_type(prev, "tool_use"):
|
||||
messages.pop(i - 1)
|
||||
removed += 1
|
||||
i -= 1
|
||||
continue
|
||||
else:
|
||||
new_content = [
|
||||
b for b in content
|
||||
if not (
|
||||
isinstance(b, dict)
|
||||
and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id", "") in orphaned_set
|
||||
)
|
||||
]
|
||||
delta = len(content) - len(new_content)
|
||||
if delta:
|
||||
logger.warning(
|
||||
f"⚠️ Stripped {delta} orphaned tool_result block(s) from mixed message"
|
||||
)
|
||||
msg["content"] = new_content
|
||||
removed += delta
|
||||
i += 1
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
return removed
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
@@ -13,6 +13,7 @@ from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from common import const
|
||||
from agent.protocol.message_utils import drop_orphaned_tool_results_openai
|
||||
|
||||
|
||||
# MiniMax对话模型API
|
||||
@@ -356,7 +357,7 @@ class MinimaxBot(Bot):
|
||||
|
||||
converted.append(openai_msg)
|
||||
|
||||
return converted
|
||||
return drop_orphaned_tool_results_openai(converted)
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools):
|
||||
"""
|
||||
|
||||
@@ -10,6 +10,7 @@ This includes: OpenAI, LinkAI, Azure OpenAI, and many third-party providers.
|
||||
import json
|
||||
import openai
|
||||
from common.log import logger
|
||||
from agent.protocol.message_utils import drop_orphaned_tool_results_openai
|
||||
|
||||
|
||||
class OpenAICompatibleBot:
|
||||
@@ -300,5 +301,5 @@ class OpenAICompatibleBot:
|
||||
else:
|
||||
# Other formats, keep as is
|
||||
openai_messages.append(msg)
|
||||
|
||||
return openai_messages
|
||||
|
||||
return drop_orphaned_tool_results_openai(openai_messages)
|
||||
|
||||
Reference in New Issue
Block a user