fix: tool call match

This commit is contained in:
zhayujie
2026-03-12 17:05:27 +08:00
parent 153c9e3565
commit d78105d57c
2 changed files with 65 additions and 56 deletions

View File

@@ -196,6 +196,11 @@ class AgentStreamExecutor:
# are never stripped mid-execution (which would cause LLM loops).
self._trim_messages()
# Validate after trimming: trimming may leave orphaned tool_use at the
# boundary (e.g. the last kept turn ends with an assistant tool_use whose
# tool_result was in a discarded turn).
self._validate_and_fix_messages()
self._emit_event("agent_start")
final_response = ""

View File

@@ -70,67 +70,71 @@ def sanitize_claude_messages(messages: List[Dict]) -> int:
else:
break
# 3. Full scan: ensure every tool_result references a known tool_use id
known_ids: Set[str] = set()
i = 0
while i < len(messages):
msg = messages[i]
role = msg.get("role")
content = msg.get("content", [])
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
# Removing one broken message can orphan others (e.g. an assistant msg
# with both matched and unmatched tool_use — deleting it orphans the
# previously-matched tool_result). Loop until clean.
for _ in range(5):
use_ids: Set[str] = set()
result_ids: Set[str] = set()
for msg in messages:
for block in (msg.get("content") or []):
if not isinstance(block, dict):
continue
if block.get("type") == "tool_use" and block.get("id"):
use_ids.add(block["id"])
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
result_ids.add(block["tool_use_id"])
if role == "assistant" and isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tid = block.get("id", "")
if tid:
known_ids.add(tid)
bad_use = use_ids - result_ids
bad_result = result_ids - use_ids
if not bad_use and not bad_result:
break
elif role == "user" and isinstance(content, list):
if not _has_block_type(content, "tool_result"):
pass_removed = 0
i = 0
while i < len(messages):
msg = messages[i]
role = msg.get("role")
content = msg.get("content", [])
if not isinstance(content, list):
i += 1
continue
orphaned = [
b.get("tool_use_id", "")
for b in content
if isinstance(b, dict)
and b.get("type") == "tool_result"
and b.get("tool_use_id", "")
and b.get("tool_use_id", "") not in known_ids
]
if orphaned:
orphaned_set = set(orphaned)
if not _has_block_type(content, "text"):
logger.warning(
f"⚠️ Removing orphaned tool_result message (tool_ids: {orphaned})"
)
messages.pop(i)
removed += 1
# Also remove a preceding broken assistant tool_use message
if i > 0 and messages[i - 1].get("role") == "assistant":
prev = messages[i - 1].get("content", [])
if isinstance(prev, list) and _has_block_type(prev, "tool_use"):
messages.pop(i - 1)
removed += 1
i -= 1
continue
else:
new_content = [
b for b in content
if not (
isinstance(b, dict)
and b.get("type") == "tool_result"
and b.get("tool_use_id", "") in orphaned_set
)
]
delta = len(content) - len(new_content)
if delta:
logger.warning(
f"⚠️ Stripped {delta} orphaned tool_result block(s) from mixed message"
)
msg["content"] = new_content
removed += delta
i += 1
if role == "assistant" and bad_use and any(
isinstance(b, dict) and b.get("type") == "tool_use"
and b.get("id") in bad_use for b in content
):
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
messages.pop(i)
pass_removed += 1
continue
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
has_bad = any(
isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result for b in content
)
if has_bad:
if not _has_block_type(content, "text"):
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
messages.pop(i)
pass_removed += 1
continue
else:
before = len(content)
msg["content"] = [
b for b in content
if not (isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result)
]
pass_removed += before - len(msg["content"])
i += 1
removed += pass_removed
if pass_removed == 0:
break
if removed:
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")