diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index 49050a6b..26f6fd49 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -196,6 +196,11 @@ class AgentStreamExecutor: # are never stripped mid-execution (which would cause LLM loops). self._trim_messages() + # Validate after trimming: trimming may leave orphaned tool_use at the + # boundary (e.g. the last kept turn ends with an assistant tool_use whose + # tool_result was in a discarded turn). + self._validate_and_fix_messages() + self._emit_event("agent_start") final_response = "" diff --git a/agent/protocol/message_utils.py b/agent/protocol/message_utils.py index 3215ed48..8c9f4fc7 100644 --- a/agent/protocol/message_utils.py +++ b/agent/protocol/message_utils.py @@ -70,67 +70,71 @@ def sanitize_claude_messages(messages: List[Dict]) -> int: else: break - # 3. Full scan: ensure every tool_result references a known tool_use id - known_ids: Set[str] = set() - i = 0 - while i < len(messages): - msg = messages[i] - role = msg.get("role") - content = msg.get("content", []) + # 3. Iteratively remove unmatched tool_use / tool_result until stable. + # Removing one broken message can orphan others (e.g. an assistant msg + # with both matched and unmatched tool_use — deleting it orphans the + # previously-matched tool_result). Loop until clean. + for _ in range(5): + use_ids: Set[str] = set() + result_ids: Set[str] = set() + for msg in messages: + for block in (msg.get("content") or []): + if not isinstance(block, dict): + continue + if block.get("type") == "tool_use" and block.get("id"): + use_ids.add(block["id"]) + elif block.get("type") == "tool_result" and block.get("tool_use_id"): + result_ids.add(block["tool_use_id"]) - if role == "assistant" and isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "tool_use": - tid = block.get("id", "") - if tid: - known_ids.add(tid) + bad_use = use_ids - result_ids + bad_result = result_ids - use_ids + if not bad_use and not bad_result: + break - elif role == "user" and isinstance(content, list): - if not _has_block_type(content, "tool_result"): + pass_removed = 0 + i = 0 + while i < len(messages): + msg = messages[i] + role = msg.get("role") + content = msg.get("content", []) + if not isinstance(content, list): i += 1 continue - orphaned = [ - b.get("tool_use_id", "") - for b in content - if isinstance(b, dict) - and b.get("type") == "tool_result" - and b.get("tool_use_id", "") - and b.get("tool_use_id", "") not in known_ids - ] - if orphaned: - orphaned_set = set(orphaned) - if not _has_block_type(content, "text"): - logger.warning( - f"⚠️ Removing orphaned tool_result message (tool_ids: {orphaned})" - ) - messages.pop(i) - removed += 1 - # Also remove a preceding broken assistant tool_use message - if i > 0 and messages[i - 1].get("role") == "assistant": - prev = messages[i - 1].get("content", []) - if isinstance(prev, list) and _has_block_type(prev, "tool_use"): - messages.pop(i - 1) - removed += 1 - i -= 1 - continue - else: - new_content = [ - b for b in content - if not ( - isinstance(b, dict) - and b.get("type") == "tool_result" - and b.get("tool_use_id", "") in orphaned_set - ) - ] - delta = len(content) - len(new_content) - if delta: - logger.warning( - f"⚠️ Stripped {delta} orphaned tool_result block(s) from mixed message" - ) - msg["content"] = new_content - removed += delta - i += 1 + if role == "assistant" and bad_use and any( + isinstance(b, dict) and b.get("type") == "tool_use" + and b.get("id") in bad_use for b in content + ): + logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use") + messages.pop(i) + pass_removed += 1 + continue + + if role == "user" and bad_result and _has_block_type(content, "tool_result"): + has_bad = any( + isinstance(b, dict) and b.get("type") == "tool_result" + and b.get("tool_use_id") in bad_result for b in content + ) + if has_bad: + if not _has_block_type(content, "text"): + logger.warning(f"⚠️ Removing user msg with unmatched tool_result") + messages.pop(i) + pass_removed += 1 + continue + else: + before = len(content) + msg["content"] = [ + b for b in content + if not (isinstance(b, dict) and b.get("type") == "tool_result" + and b.get("tool_use_id") in bad_result) + ] + pass_removed += before - len(msg["content"]) + + i += 1 + + removed += pass_removed + if pass_removed == 0: + break if removed: logger.info(f"🔧 Message validation: removed {removed} broken message(s)")