mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-04-17 17:43:01 +08:00
fix: tool call match
This commit is contained in:
@@ -196,6 +196,11 @@ class AgentStreamExecutor:
|
||||
# are never stripped mid-execution (which would cause LLM loops).
|
||||
self._trim_messages()
|
||||
|
||||
# Validate after trimming: trimming may leave orphaned tool_use at the
|
||||
# boundary (e.g. the last kept turn ends with an assistant tool_use whose
|
||||
# tool_result was in a discarded turn).
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
self._emit_event("agent_start")
|
||||
|
||||
final_response = ""
|
||||
|
||||
@@ -70,67 +70,71 @@ def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Full scan: ensure every tool_result references a known tool_use id
|
||||
known_ids: Set[str] = set()
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tid = block.get("id", "")
|
||||
if tid:
|
||||
known_ids.add(tid)
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
elif role == "user" and isinstance(content, list):
|
||||
if not _has_block_type(content, "tool_result"):
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
orphaned = [
|
||||
b.get("tool_use_id", "")
|
||||
for b in content
|
||||
if isinstance(b, dict)
|
||||
and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id", "")
|
||||
and b.get("tool_use_id", "") not in known_ids
|
||||
]
|
||||
if orphaned:
|
||||
orphaned_set = set(orphaned)
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(
|
||||
f"⚠️ Removing orphaned tool_result message (tool_ids: {orphaned})"
|
||||
)
|
||||
messages.pop(i)
|
||||
removed += 1
|
||||
# Also remove a preceding broken assistant tool_use message
|
||||
if i > 0 and messages[i - 1].get("role") == "assistant":
|
||||
prev = messages[i - 1].get("content", [])
|
||||
if isinstance(prev, list) and _has_block_type(prev, "tool_use"):
|
||||
messages.pop(i - 1)
|
||||
removed += 1
|
||||
i -= 1
|
||||
continue
|
||||
else:
|
||||
new_content = [
|
||||
b for b in content
|
||||
if not (
|
||||
isinstance(b, dict)
|
||||
and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id", "") in orphaned_set
|
||||
)
|
||||
]
|
||||
delta = len(content) - len(new_content)
|
||||
if delta:
|
||||
logger.warning(
|
||||
f"⚠️ Stripped {delta} orphaned tool_result block(s) from mixed message"
|
||||
)
|
||||
msg["content"] = new_content
|
||||
removed += delta
|
||||
i += 1
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
|
||||
Reference in New Issue
Block a user