# Text placed in every synthetic tool_result that this module fabricates when
# a tool_use has no adjacent result; it is sent to the model as an error result.
_SYNTH_TOOL_ERR = (
    "Error: Missing tool_result adjacent to tool_use (session repair). "
    "The conversation history was inconsistent; continue from here."
)


def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
    """
    Ensure every assistant tool_use is directly followed by its tool_result,
    fixing the list **in-place**.

    Anthropic requires: after assistant content with tool_use, the next message
    must be user content listing tool_result for every tool_use id (same user msg).

    Valid histories satisfy this at every such assistant; the loop only mutates
    when that condition fails (broken persistence, bad trims, etc.).

    Repairs performed:
    - trailing assistant tool_use: a synthetic user message is appended
    - next message is not a user message, or its content is not a block list:
      a synthetic user message is inserted right after the assistant message
    - next user message lacks some results: synthetic tool_result blocks for
      the missing ids are prepended to its content

    Args:
        messages: Claude-format message dicts; entries that are not dicts are
            skipped rather than raising.

    Returns:
        Number of repair operations: 1 per appended/inserted user message,
        plus 1 per synthetic block prepended to an existing user message.
    """

    def _synth_block(tid: str) -> Dict:
        return {
            "type": "tool_result",
            "tool_use_id": tid,
            "content": _SYNTH_TOOL_ERR,
            "is_error": True,
        }

    def _synth_user(ids: List[str]) -> Dict:
        return {"role": "user", "content": [_synth_block(tid) for tid in ids]}

    repairs = 0
    i = 0
    # Index-based loop: the list grows while we walk it (insert/append).
    while i < len(messages):
        msg = messages[i]
        if not isinstance(msg, dict) or msg.get("role") != "assistant":
            i += 1
            continue

        content = msg.get("content", [])
        if not isinstance(content, list):
            i += 1
            continue

        # Order-preserving de-duplication: a repeated tool_use id must yield
        # a single synthetic tool_result, not one per occurrence.
        required = list(dict.fromkeys(
            b.get("id")
            for b in content
            if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
        ))
        if not required:
            i += 1
            continue

        if i + 1 >= len(messages):
            messages.append(_synth_user(required))
            logger.warning(
                "⚠️ Appended synthetic tool_result after trailing assistant tool_use"
            )
            repairs += 1
            break

        nxt = messages[i + 1]
        nxt_role = nxt.get("role") if isinstance(nxt, dict) else None
        if nxt_role != "user":
            messages.insert(i + 1, _synth_user(required))
            logger.warning(
                "⚠️ Inserted synthetic tool_result user after tool_use "
                f"(next role={nxt_role!r})"
            )
            repairs += 1
            # Skip the message just inserted; resume at the original successor.
            i += 2
            continue

        nc = nxt.get("content", [])
        if not isinstance(nc, list):
            # Plain-text (or otherwise non-block) user content cannot carry
            # tool_result blocks, so a separate synthetic user message is
            # inserted in front of it.
            messages.insert(i + 1, _synth_user(required))
            logger.warning(
                "⚠️ Inserted synthetic tool_result user after tool_use "
                "(next user content is not a block list)"
            )
            repairs += 1
            i += 2
            continue

        present = {
            b.get("tool_use_id")
            for b in nc
            if isinstance(b, dict)
            and b.get("type") == "tool_result"
            and b.get("tool_use_id")
        }
        missing = [tid for tid in required if tid not in present]
        if not missing:
            i += 1
            continue

        # Build a new list instead of mutating nc, in case it is shared.
        nxt["content"] = [_synth_block(tid) for tid in missing] + nc
        logger.warning(
            "⚠️ Prepended synthetic tool_result for Anthropic adjacency "
            f"(missing_ids={missing})"
        )
        repairs += len(missing)
        i += 1

    return repairs
Adjacency repair (Anthropic: tool_result must be in the next user message) + adj_repairs = _repair_tool_use_adjacency(messages) # 2. Remove leading orphaned tool_result user messages while messages: @@ -136,9 +225,15 @@ def sanitize_claude_messages(messages: List[Dict]) -> int: if pass_removed == 0: break + # 4. Removals above can break adjacency; re-run repair only if something was removed. + if removed: + adj_repairs += _repair_tool_use_adjacency(messages) + if removed: logger.info(f"🔧 Message validation: removed {removed} broken message(s)") - return removed + if adj_repairs: + logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}") + return removed + adj_repairs # ------------------------------------------------------------------ #