mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-19 21:38:18 +08:00
feat: improve the memory system
This commit is contained in:
@@ -6,5 +6,6 @@ Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
|
||||
@@ -41,6 +41,10 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
return Path(self.workspace_root)
|
||||
|
||||
@@ -4,20 +4,19 @@ Embedding providers for memory
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
@@ -31,7 +30,7 @@ class EmbeddingProvider(ABC):
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider"""
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
@@ -45,87 +44,58 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
# Lazy import to avoid dependency issues
|
||||
try:
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI(api_key=api_key, base_url=api_base)
|
||||
except ImportError:
|
||||
raise ImportError("OpenAI package not installed. Install with: pip install openai")
|
||||
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
response = self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
input=texts,
|
||||
model=self.model
|
||||
)
|
||||
return [item.embedding for item in response.data]
|
||||
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
|
||||
"""Local embedding provider using sentence-transformers"""
|
||||
|
||||
def __init__(self, model: str = "all-MiniLM-L6-v2"):
|
||||
"""
|
||||
Initialize local embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name from sentence-transformers
|
||||
"""
|
||||
self.model_name = model
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer(model)
|
||||
self._dimensions = self.model.get_sentence_embedding_dimension()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"sentence-transformers not installed. "
|
||||
"Install with: pip install sentence-transformers"
|
||||
)
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
embedding = self.model.encode(text, convert_to_numpy=True)
|
||||
return embedding.tolist()
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embeddings = self.model.encode(texts, convert_to_numpy=True)
|
||||
return embeddings.tolist()
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
@@ -156,20 +126,23 @@ def create_embedding_provider(
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name ("openai" or "local")
|
||||
model: Model name (provider-specific)
|
||||
api_key: API key for remote providers
|
||||
api_base: API base URL for remote providers
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider == "openai":
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
elif provider == "local":
|
||||
model = model or "all-MiniLM-L6-v2"
|
||||
return LocalEmbeddingProvider(model=model)
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding provider: {provider}")
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
|
||||
@@ -70,8 +70,9 @@ class MemoryManager:
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
|
||||
print(f"ℹ️ Memory will work with keyword search only (no semantic search)")
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
@@ -135,13 +136,19 @@ class MemoryManager:
|
||||
# Perform vector search (if embedding provider available)
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
keyword_results = self.storage.search_keyword(
|
||||
@@ -150,6 +157,8 @@ class MemoryManager:
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
@@ -356,30 +365,30 @@ class MemoryManager:
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int,
|
||||
context_window: int = 128000,
|
||||
reserve_tokens: int = 20000,
|
||||
soft_threshold: int = 4000
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
context_window: Model's context window size (default: 128K)
|
||||
reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
|
||||
soft_threshold: Trigger N tokens before threshold (default: 4K)
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
soft_threshold=soft_threshold
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
|
||||
@@ -41,46 +41,42 @@ class MemoryFlushManager:
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self.turn_count: int = 0 # 对话轮数计数器
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int,
|
||||
context_window: int,
|
||||
reserve_tokens: int = 20000,
|
||||
soft_threshold: int = 4000
|
||||
current_tokens: int = 0,
|
||||
token_threshold: int = 50000,
|
||||
turn_threshold: int = 20
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
Similar to clawdbot's shouldRunMemoryFlush logic:
|
||||
threshold = contextWindow - reserveTokens - softThreshold
|
||||
独立的 flush 触发机制,不依赖模型 context window:
|
||||
- Token 阈值: 达到 50K tokens 时触发
|
||||
- 轮次阈值: 达到 20 轮对话时触发
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
context_window: Model's context window size
|
||||
reserve_tokens: Reserve tokens for compaction overhead
|
||||
soft_threshold: Trigger flush N tokens before threshold
|
||||
token_threshold: Token threshold to trigger flush (default: 50K)
|
||||
turn_threshold: Turn threshold to trigger flush (default: 20)
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
if current_tokens <= 0:
|
||||
return False
|
||||
# 检查 token 阈值
|
||||
if current_tokens > 0 and current_tokens >= token_threshold:
|
||||
# 避免重复 flush
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + 5000:
|
||||
return False
|
||||
return True
|
||||
|
||||
threshold = max(0, context_window - reserve_tokens - soft_threshold)
|
||||
if threshold <= 0:
|
||||
return False
|
||||
# 检查轮次阈值
|
||||
if self.turn_count >= turn_threshold:
|
||||
return True
|
||||
|
||||
# Check if we've crossed the threshold
|
||||
if current_tokens < threshold:
|
||||
return False
|
||||
|
||||
# Avoid duplicate flush in same compaction cycle
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + soft_threshold:
|
||||
return False
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
@@ -130,7 +126,12 @@ class MemoryFlushManager:
|
||||
f"Pre-compaction memory flush. "
|
||||
f"Store durable memories now (use memory/{today}.md for daily notes; "
|
||||
f"create memory/ if needed). "
|
||||
f"If nothing to store, reply with NO_REPLY."
|
||||
f"\n\n"
|
||||
f"重要提示:\n"
|
||||
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
|
||||
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
|
||||
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
|
||||
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
|
||||
)
|
||||
|
||||
def create_flush_system_prompt(self) -> str:
|
||||
@@ -142,6 +143,20 @@ class MemoryFlushManager:
|
||||
return (
|
||||
"Pre-compaction memory flush turn. "
|
||||
"The session is near auto-compaction; capture durable memories to disk. "
|
||||
"\n\n"
|
||||
"记忆写入原则:\n"
|
||||
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
|
||||
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
|
||||
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
|
||||
"\n"
|
||||
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
|
||||
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
|
||||
"\n"
|
||||
"3. 判断标准:\n"
|
||||
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
|
||||
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
|
||||
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
|
||||
"\n"
|
||||
"You may reply, but usually NO_REPLY is correct."
|
||||
)
|
||||
|
||||
@@ -180,6 +195,7 @@ class MemoryFlushManager:
|
||||
# Track flush
|
||||
self.last_flush_token_count = current_tokens
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
self.turn_count = 0 # 重置轮数计数器
|
||||
|
||||
return True
|
||||
|
||||
@@ -187,6 +203,10 @@ class MemoryFlushManager:
|
||||
print(f"Memory flush failed: {e}")
|
||||
return False
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数"""
|
||||
self.turn_count += 1
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get memory flush status"""
|
||||
return {
|
||||
|
||||
@@ -179,8 +179,8 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
tool_map = {}
|
||||
tool_descriptions = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件内容",
|
||||
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)",
|
||||
"edit": "精确编辑文件(追加、修改、删除部分内容)",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "在文件中搜索内容",
|
||||
"find": "按照模式查找文件",
|
||||
@@ -305,17 +305,18 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 使用 `memory_search` 在 MEMORY.md 和 memory/*.md 中搜索",
|
||||
"2. 然后使用 `memory_get` 只拉取需要的行",
|
||||
"3. 如果搜索后仍然信心不足,告诉用户你已经检查过了",
|
||||
"1. 不确定信息位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件和大致位置 → 直接用 `memory_get` 读取相应的行",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取最近两天的记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆,包含重要的背景信息",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的对话和事件",
|
||||
"- `MEMORY.md`: 长期记忆(已自动加载,无需主动读取)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"",
|
||||
"**使用原则**:",
|
||||
"- 自然使用记忆,就像你本来就知道",
|
||||
"- 不要主动提起或列举记忆,除非用户明确询问",
|
||||
"- 自然使用记忆,就像你本来就知道; 不用刻意提起或列举记忆,除非用户提起相关内容",
|
||||
"- 追加内容到现有记忆文件 → 必须用 `edit` 工具(先 read 读取,再 edit 追加)",
|
||||
"- 创建新的记忆文件 → 可以用 `write` 工具(已有记忆文件不可直接write,会覆盖删除)",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class AgentStreamExecutor:
|
||||
@@ -164,30 +164,24 @@ class AgentStreamExecutor:
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
# 使用独立的 flush 阈值(50K tokens 或 20 轮)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
# Use configured reserve_tokens or calculate based on context window
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
# Use smaller soft_threshold to trigger flush earlier (e.g., at 50K tokens)
|
||||
soft_threshold = 10000 # Trigger 10K tokens before limit
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens,
|
||||
soft_threshold=soft_threshold
|
||||
current_tokens=current_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"threshold": context_window - reserve_tokens - soft_threshold
|
||||
"turn_count": self.agent.memory_manager.flush_manager.turn_count
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(f"Memory flush recommended at {current_tokens} tokens")
|
||||
logger.info(
|
||||
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
|
||||
|
||||
# Call LLM
|
||||
assistant_msg, tool_calls = self._call_llm_stream()
|
||||
@@ -321,6 +315,10 @@ class AgentStreamExecutor:
|
||||
logger.info(f"🏁 完成({turn}轮)")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
# 每轮对话结束后增加计数(用户消息+AI回复=1轮)
|
||||
if self.agent.memory_manager:
|
||||
self.agent.memory_manager.increment_turn()
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
|
||||
@@ -664,9 +662,11 @@ class AgentStreamExecutor:
|
||||
if not self.messages or not self.agent:
|
||||
return
|
||||
|
||||
# Get context window and reserve tokens from agent
|
||||
# Get context window from agent (based on model)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
|
||||
# Reserve 10% for response generation
|
||||
reserve_tokens = int(context_window * 0.1)
|
||||
max_tokens = context_window - reserve_tokens
|
||||
|
||||
# Estimate current tokens
|
||||
|
||||
@@ -2,25 +2,17 @@
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import basic tools (no external dependencies)
|
||||
from agent.tools.calculator.calculator import Calculator
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.grep.grep import Grep
|
||||
from agent.tools.find.find import Find
|
||||
from agent.tools.ls.ls import Ls
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import web tools
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
|
||||
"""Import tools that have optional dependencies"""
|
||||
@@ -80,17 +72,13 @@ BrowserTool = _import_browser_tool()
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Calculator',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Grep',
|
||||
'Find',
|
||||
'Ls',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'WebFetch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
'GoogleSearch',
|
||||
'FileSave',
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import math
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Calculator(BaseTool):
|
||||
name: str = "calculator"
|
||||
description: str = "A tool to perform basic mathematical calculations."
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
|
||||
"Ensure your input is a valid Python expression, it will be evaluated directly."
|
||||
}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
config: dict = {}
|
||||
|
||||
def execute(self, args: dict) -> ToolResult:
|
||||
try:
|
||||
# Get the expression
|
||||
expression = args["expression"]
|
||||
|
||||
# Create a safe local environment containing only basic math functions
|
||||
safe_locals = {
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"max": max,
|
||||
"min": min,
|
||||
"pow": pow,
|
||||
"sqrt": math.sqrt,
|
||||
"sin": math.sin,
|
||||
"cos": math.cos,
|
||||
"tan": math.tan,
|
||||
"pi": math.pi,
|
||||
"e": math.e,
|
||||
"log": math.log,
|
||||
"log10": math.log10,
|
||||
"exp": math.exp,
|
||||
"floor": math.floor,
|
||||
"ceil": math.ceil
|
||||
}
|
||||
|
||||
# Safely evaluate the expression
|
||||
result = eval(expression, {"__builtins__": {}}, safe_locals)
|
||||
|
||||
return ToolResult.success({
|
||||
"result": result,
|
||||
"expression": expression
|
||||
})
|
||||
except Exception as e:
|
||||
return ToolResult.success({
|
||||
"error": str(e),
|
||||
"expression": args.get("expression", "")
|
||||
})
|
||||
@@ -33,7 +33,7 @@ class Edit(BaseTool):
|
||||
},
|
||||
"oldText": {
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)"
|
||||
"description": "Exact text to find and replace (must match exactly, cannot be empty). To append to end of file, include the last few lines as oldText."
|
||||
},
|
||||
"newText": {
|
||||
"type": "string",
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .find import Find
|
||||
|
||||
__all__ = ['Find']
|
||||
@@ -1,177 +0,0 @@
|
||||
"""
|
||||
Find tool - Search for files by glob pattern
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob as glob_module
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 1000
|
||||
|
||||
|
||||
class Find(BaseTool):
|
||||
"""Tool for finding files by pattern"""
|
||||
|
||||
name: str = "find"
|
||||
description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to search in (default: current directory)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file search
|
||||
|
||||
:param args: Search parameters
|
||||
:return: Search results or error
|
||||
"""
|
||||
pattern = args.get("pattern", "").strip()
|
||||
search_path = args.get("path", ".").strip()
|
||||
limit = args.get("limit", DEFAULT_LIMIT)
|
||||
|
||||
if not pattern:
|
||||
return ToolResult.fail("Error: pattern parameter is required")
|
||||
|
||||
# Resolve search path
|
||||
absolute_path = self._resolve_path(search_path)
|
||||
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: Path not found: {search_path}")
|
||||
|
||||
if not os.path.isdir(absolute_path):
|
||||
return ToolResult.fail(f"Error: Not a directory: {search_path}")
|
||||
|
||||
try:
|
||||
# Load .gitignore patterns
|
||||
ignore_patterns = self._load_gitignore(absolute_path)
|
||||
|
||||
# Search for files
|
||||
results = []
|
||||
search_pattern = os.path.join(absolute_path, pattern)
|
||||
|
||||
# Use glob with recursive support
|
||||
for file_path in glob_module.glob(search_pattern, recursive=True):
|
||||
# Skip if matches ignore patterns
|
||||
if self._should_ignore(file_path, absolute_path, ignore_patterns):
|
||||
continue
|
||||
|
||||
# Get relative path
|
||||
relative_path = os.path.relpath(file_path, absolute_path)
|
||||
|
||||
# Add trailing slash for directories
|
||||
if os.path.isdir(file_path):
|
||||
relative_path += '/'
|
||||
|
||||
results.append(relative_path)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
if not results:
|
||||
return ToolResult.success({"message": "No files found matching pattern", "files": []})
|
||||
|
||||
# Sort results
|
||||
results.sort()
|
||||
|
||||
# Format output
|
||||
raw_output = '\n'.join(results)
|
||||
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
|
||||
|
||||
output = truncation.content
|
||||
details = {}
|
||||
notices = []
|
||||
|
||||
result_limit_reached = len(results) >= limit
|
||||
if result_limit_reached:
|
||||
notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
|
||||
details["result_limit_reached"] = limit
|
||||
|
||||
if truncation.truncated:
|
||||
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
|
||||
details["truncation"] = truncation.to_dict()
|
||||
|
||||
if notices:
|
||||
output += f"\n\n[{'. '.join(notices)}]"
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output,
|
||||
"file_count": len(results),
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error executing find: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _load_gitignore(self, directory: str) -> List[str]:
|
||||
"""Load .gitignore patterns from directory"""
|
||||
patterns = []
|
||||
gitignore_path = os.path.join(directory, '.gitignore')
|
||||
|
||||
if os.path.exists(gitignore_path):
|
||||
try:
|
||||
with open(gitignore_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
patterns.append(line)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Add common ignore patterns
|
||||
patterns.extend([
|
||||
'.git',
|
||||
'__pycache__',
|
||||
'*.pyc',
|
||||
'node_modules',
|
||||
'.DS_Store'
|
||||
])
|
||||
|
||||
return patterns
|
||||
|
||||
def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
|
||||
"""Check if file should be ignored based on patterns"""
|
||||
relative_path = os.path.relpath(file_path, base_path)
|
||||
|
||||
for pattern in patterns:
|
||||
# Simple pattern matching
|
||||
if pattern in relative_path:
|
||||
return True
|
||||
|
||||
# Check if it's a directory pattern
|
||||
if pattern.endswith('/'):
|
||||
if relative_path.startswith(pattern.rstrip('/')):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1,3 +0,0 @@
|
||||
from .grep import Grep
|
||||
|
||||
__all__ = ['Grep']
|
||||
@@ -1,248 +0,0 @@
|
||||
"""
|
||||
Grep tool - Search file contents for patterns
|
||||
Uses ripgrep (rg) for fast searching
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import (
|
||||
truncate_head, truncate_line, format_size,
|
||||
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 100
|
||||
|
||||
|
||||
class Grep(BaseTool):
    """Tool for searching file contents with ripgrep.

    Runs ``rg`` with JSON output, collects up to ``limit`` matches, then
    re-reads the matched files to render optional context lines around each
    hit. Overall output is byte-truncated and over-long lines are clipped
    to GREP_MAX_LINE_LENGTH characters.
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        # config may provide "cwd" to anchor relative search paths
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Locate the ripgrep executable on PATH.

        Uses shutil.which instead of spawning `which` in a subprocess:
        cheaper, no bare-except needed, and portable to Windows.

        :return: Absolute path to ``rg``, or None if not installed
        """
        import shutil
        return shutil.which('rg')

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search.

        :param args: Search parameters (pattern, path, glob, ignoreCase,
                     literal, context, limit)
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        pattern = args.get("pattern", "").strip()
        search_path = args.get("path", ".").strip()
        glob_pattern = args.get("glob")  # local renamed so it doesn't shadow the builtin
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = args.get("context", 0)
        limit = args.get("limit", DEFAULT_LIMIT)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command. --json gives structured events we can parse
        # reliably; --hidden includes dotfiles (gitignore is still respected).
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]
        if ignore_case:
            cmd.append('--ignore-case')
        if literal:
            cmd.append('--fixed-strings')
        if glob_pattern:
            cmd.extend(['--glob', glob_pattern])
        cmd.extend([pattern, absolute_path])

        try:
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            matches = self._parse_rg_output(result.stdout, limit)
            match_count = len(matches)
            if match_count == 0:
                return ToolResult.success({"message": "No matches found", "matches": []})

            output_lines, lines_truncated = self._format_matches(matches, absolute_path, context)

            # Apply byte truncation (match count is already capped above)
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _parse_rg_output(self, stdout: str, limit: int) -> list:
        """Parse ripgrep's --json event stream into match dicts.

        :param stdout: Raw JSON-lines output from rg
        :param limit: Maximum number of matches to keep
        :return: List of {'file': path, 'line': number} dicts (at most limit)
        """
        matches = []
        for line in stdout.splitlines():
            if not line.strip():
                continue
            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                # Skip anything that isn't a well-formed JSON event
                continue
            if event.get('type') != 'match':
                continue
            data = event.get('data', {})
            file_path = data.get('path', {}).get('text')
            line_number = data.get('line_number')
            if file_path and line_number:
                matches.append({'file': file_path, 'line': line_number})
                if len(matches) >= limit:
                    break
        return matches

    def _format_matches(self, matches: list, absolute_path: str, context: int):
        """Render matches (with optional context lines) as grep-style text.

        File contents are cached per path so a file with many matches is
        read once instead of once per match.

        :param matches: Output of _parse_rg_output
        :param absolute_path: Resolved search root (for relative display paths)
        :param context: Lines of context before/after each match
        :return: (output_lines, lines_truncated) tuple
        """
        output_lines = []
        lines_truncated = False
        is_directory = os.path.isdir(absolute_path)
        file_cache: Dict[str, list] = {}

        for match in matches:
            file_path = match['file']
            line_number = match['line']

            # Display path relative to the search root
            if is_directory:
                relative_path = os.path.relpath(file_path, absolute_path)
            else:
                relative_path = os.path.basename(file_path)

            try:
                if file_path not in file_cache:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        file_cache[file_path] = f.read().split('\n')
                file_lines = file_cache[file_path]

                # Context window around the match (0-based indices)
                start = max(0, line_number - 1 - context) if context > 0 else line_number - 1
                end = min(len(file_lines), line_number + context) if context > 0 else line_number

                for i in range(start, end):
                    line_text = file_lines[i].replace('\r', '')

                    # Clip over-long lines so one line can't blow the budget
                    truncated_text, was_truncated = truncate_line(line_text)
                    if was_truncated:
                        lines_truncated = True

                    # grep convention: ':' marks the match line, '-' context lines
                    current_line = i + 1
                    if current_line == line_number:
                        output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                    else:
                        output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

            except Exception:
                output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")

        return output_lines, lines_truncated

    def _resolve_path(self, path: str) -> str:
        """Resolve a possibly relative or ~-prefixed path to an absolute path"""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
@@ -4,8 +4,6 @@ Memory get tool
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@@ -22,7 +20,7 @@ class MemoryGetTool(BaseTool):
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2026-01-01.md')"
|
||||
"description": "Relative path to the memory file (e.g. 'memory/2026-01-01.md')"
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
@@ -70,7 +68,8 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
if not path.startswith('memory/') and not path.startswith('/'):
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
# WebFetch Tool
|
||||
|
||||
免费的网页抓取工具,无需 API Key,可直接抓取网页内容并提取可读文本。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- ✅ **完全免费** - 无需任何 API Key
|
||||
- 🌐 **智能提取** - 自动提取网页主要内容
|
||||
- 📝 **格式转换** - 支持 HTML → Markdown/Text
|
||||
- 🚀 **高性能** - 内置请求重试和超时控制
|
||||
- 🎯 **智能降级** - 优先使用 Readability,可降级到基础提取
|
||||
|
||||
## 安装依赖
|
||||
|
||||
### 基础功能(必需)
|
||||
```bash
|
||||
pip install requests
|
||||
```
|
||||
|
||||
### 增强功能(推荐)
|
||||
```bash
|
||||
# 安装 readability-lxml 以获得更好的内容提取效果
|
||||
pip install readability-lxml
|
||||
|
||||
# 安装 html2text 以获得更好的 Markdown 转换
|
||||
pip install html2text
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 在代码中使用
|
||||
|
||||
```python
|
||||
from agent.tools.web_fetch import WebFetch
|
||||
|
||||
# 创建工具实例
|
||||
tool = WebFetch()
|
||||
|
||||
# 抓取网页(默认返回 Markdown 格式)
|
||||
result = tool.execute({
|
||||
"url": "https://example.com"
|
||||
})
|
||||
|
||||
# 抓取并转换为纯文本
|
||||
result = tool.execute({
|
||||
"url": "https://example.com",
|
||||
"extract_mode": "text",
|
||||
"max_chars": 5000
|
||||
})
|
||||
|
||||
if result.status == "success":
|
||||
data = result.result
|
||||
print(f"标题: {data['title']}")
|
||||
print(f"内容: {data['text']}")
|
||||
```
|
||||
|
||||
### 2. 在 Agent 中使用
|
||||
|
||||
工具会自动加载到 Agent 的工具列表中:
|
||||
|
||||
```python
|
||||
from agent.tools import WebFetch
|
||||
|
||||
tools = [
|
||||
WebFetch(),
|
||||
# ... 其他工具
|
||||
]
|
||||
|
||||
agent = create_agent(tools=tools)
|
||||
```
|
||||
|
||||
### 3. 通过 Skills 使用
|
||||
|
||||
创建一个 skill 文件 `skills/web-fetch/SKILL.md`:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: web-fetch
|
||||
emoji: 🌐
|
||||
always: true
|
||||
---
|
||||
|
||||
# 网页内容获取
|
||||
|
||||
使用 web_fetch 工具获取网页内容。
|
||||
|
||||
## 使用场景
|
||||
|
||||
- 需要读取某个网页的内容
|
||||
- 需要提取文章正文
|
||||
- 需要获取网页信息
|
||||
|
||||
## 示例
|
||||
|
||||
<example>
|
||||
用户: 帮我看看 https://example.com 这个网页讲了什么
|
||||
助手: <tool_use name="web_fetch">
|
||||
<url>https://example.com</url>
|
||||
<extract_mode>markdown</extract_mode>
|
||||
</tool_use>
|
||||
</example>
|
||||
```
|
||||
|
||||
## 参数说明
|
||||
|
||||
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|
||||
|------|------|------|--------|------|
|
||||
| `url` | string | ✅ | - | 要抓取的 URL(http/https) |
|
||||
| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown` 或 `text` |
|
||||
| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100) |
|
||||
|
||||
## 返回结果
|
||||
|
||||
```python
|
||||
{
|
||||
"url": "https://example.com", # 最终 URL(处理重定向后)
|
||||
"status": 200, # HTTP 状态码
|
||||
"content_type": "text/html", # 内容类型
|
||||
"title": "Example Domain", # 页面标题
|
||||
"extractor": "readability", # 提取器:readability/basic/raw
|
||||
"extract_mode": "markdown", # 提取模式
|
||||
"text": "# Example Domain\n\n...", # 提取的文本内容
|
||||
"length": 1234, # 文本长度
|
||||
"truncated": false, # 是否被截断
|
||||
"warning": "..." # 警告信息(如果有)
|
||||
}
|
||||
```
|
||||
|
||||
## 与其他搜索工具的对比
|
||||
|
||||
| 工具 | 需要 API Key | 功能 | 成本 |
|
||||
|------|-------------|------|------|
|
||||
| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 |
|
||||
| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 |
|
||||
| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 |
|
||||
| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 |
|
||||
| `google_search` | ✅ 需要 | Google 搜索 API | 付费 |
|
||||
|
||||
## 技术细节
|
||||
|
||||
### 内容提取策略
|
||||
|
||||
1. **Readability 模式**(推荐)
|
||||
- 使用 Mozilla 的 Readability 算法
|
||||
- 自动识别文章主体内容
|
||||
- 过滤广告、导航栏等噪音
|
||||
|
||||
2. **Basic 模式**(降级)
|
||||
- 简单的 HTML 标签清理
|
||||
- 正则表达式提取文本
|
||||
- 适用于简单页面
|
||||
|
||||
3. **Raw 模式**
|
||||
- 用于非 HTML 内容
|
||||
- 直接返回原始内容
|
||||
|
||||
### 错误处理
|
||||
|
||||
工具会自动处理以下情况:
|
||||
- ✅ HTTP 重定向(最多 3 次)
|
||||
- ✅ 请求超时(默认 20 秒,可通过 `timeout` 配置)
|
||||
- ✅ 网络错误自动重试
|
||||
- ✅ 内容提取失败降级
|
||||
|
||||
## 测试
|
||||
|
||||
运行测试脚本:
|
||||
|
||||
```bash
|
||||
cd agent/tools/web_fetch
|
||||
python test_web_fetch.py
|
||||
```
|
||||
|
||||
## 配置选项
|
||||
|
||||
在创建工具时可以传入配置:
|
||||
|
||||
```python
|
||||
tool = WebFetch(config={
|
||||
"timeout": 30, # 请求超时时间(秒)
|
||||
"max_redirects": 3, # 最大重定向次数
|
||||
"user_agent": "..." # 自定义 User-Agent
|
||||
})
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 为什么推荐安装 readability-lxml?
|
||||
|
||||
A: readability-lxml 提供更好的内容提取质量,能够:
|
||||
- 自动识别文章主体
|
||||
- 过滤广告和导航栏
|
||||
- 保留文章结构
|
||||
|
||||
没有它也能工作,但提取质量会下降。
|
||||
|
||||
### Q: 与 clawdbot 的 web_fetch 有什么区别?
|
||||
|
||||
A: 本实现参考了 clawdbot 的设计,主要区别:
|
||||
- Python 实现(clawdbot 是 TypeScript)
|
||||
- 简化了一些高级特性(如 Firecrawl 集成)
|
||||
- 保留了核心的免费功能
|
||||
- 更容易集成到现有项目
|
||||
|
||||
### Q: 可以抓取需要登录的页面吗?
|
||||
|
||||
A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。
|
||||
|
||||
## 参考
|
||||
|
||||
- [Mozilla Readability](https://github.com/mozilla/readability)
|
||||
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .web_fetch import WebFetch
|
||||
|
||||
__all__ = ['WebFetch']
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/bin/bash
# Dependency installer for the WebFetch tool.
# Installs the required `requests` package, then the optional
# readability-lxml / html2text extras that improve extraction quality.

echo "=================================="
echo "WebFetch 工具依赖安装"
echo "=================================="
echo ""

# Report the Python version in use
python_version=$(python3 --version 2>&1 | awk '{print $2}')
echo "✓ Python 版本: $python_version"
echo ""

# Required dependency. Test the command's exit status directly in `if`
# rather than the fragile `[ $? -eq 0 ]` pattern, which silently breaks
# if any command is ever inserted in between.
echo "📦 安装基础依赖..."
if python3 -m pip install requests; then
    echo "✅ requests 安装成功"
else
    echo "❌ requests 安装失败"
    exit 1
fi

echo ""

# Optional dependencies — failure here is non-fatal.
echo "📦 安装推荐依赖(提升内容提取质量)..."
if python3 -m pip install readability-lxml html2text; then
    echo "✅ readability-lxml 和 html2text 安装成功"
else
    echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
fi

echo ""
echo "=================================="
echo "安装完成!"
echo "=================================="
echo ""
echo "运行测试:"
echo "  python3 agent/tools/web_fetch/test_web_fetch.py"
echo ""
|
||||
@@ -1,365 +0,0 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from URLs
|
||||
Supports HTML to Markdown/Text conversion using Mozilla's Readability
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
    """Tool for fetching and extracting readable content from web pages.

    Fetches a URL with a retrying requests session, then extracts the
    readable body via readability-lxml when available, falling back to a
    regex-based extractor. Non-HTML responses are returned raw (truncated
    to max_chars).
    """

    name: str = "web_fetch"
    description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."

    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "HTTP or HTTPS URL to fetch"
            },
            "extract_mode": {
                "type": "string",
                "description": "Extraction mode: 'markdown' (default) or 'text'",
                "enum": ["markdown", "text"],
                "default": "markdown"
            },
            "max_chars": {
                "type": "integer",
                "description": "Maximum characters to return (default: 50000)",
                "minimum": 100,
                "default": 50000
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: dict = None):
        # Optional config keys: timeout (seconds), max_redirects, user_agent
        self.config = config or {}
        self.timeout = self.config.get("timeout", 20)
        self.max_redirects = self.config.get("max_redirects", 3)
        self.user_agent = self.config.get(
            "user_agent",
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
        )

        # Session with retry strategy, reused across calls
        self.session = self._create_session()

        # Detect the optional readability-lxml dependency once up front
        self.readability_available = self._check_readability()

    def _create_session(self) -> requests.Session:
        """Create a requests session with a retry strategy.

        Retries cover transient failures (429/5xx) on idempotent methods.
        Redirects are NOT handled here: requests follows them itself,
        capped by ``session.max_redirects`` set below.
        """
        session = requests.Session()

        # Retry strategy - handles failed requests, not redirects
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["GET", "HEAD"]
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        # Redirect cap lives on the session itself (not the adapter)
        session.max_redirects = self.max_redirects

        return session

    def _check_readability(self) -> bool:
        """Check if readability-lxml is available"""
        try:
            from readability import Document
            return True
        except ImportError:
            logger.warning(
                "readability-lxml not installed. Install with: pip install readability-lxml\n"
                "Falling back to basic HTML extraction."
            )
            return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute web fetch operation.

        :param args: Contains url, extract_mode, and max_chars parameters
        :return: Extracted content or error message
        """
        url = args.get("url", "").strip()
        extract_mode = args.get("extract_mode", "markdown").lower()
        max_chars = args.get("max_chars", 50000)

        if not url:
            return ToolResult.fail("Error: url parameter is required")

        # Validate URL scheme before touching the network
        if not self._is_valid_url(url):
            return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")

        # Invalid extract_mode falls back to the default rather than failing
        if extract_mode not in ["markdown", "text"]:
            extract_mode = "markdown"

        # Invalid max_chars (non-int or below the schema minimum of 100)
        # falls back to the default
        if not isinstance(max_chars, int) or max_chars < 100:
            max_chars = 50000

        try:
            # Fetch the URL (raises on HTTP error status)
            response = self._fetch_url(url)

            # Extract readable content from the response body
            result = self._extract_content(
                html=response.text,
                url=response.url,
                status_code=response.status_code,
                content_type=response.headers.get("content-type", ""),
                extract_mode=extract_mode,
                max_chars=max_chars
            )

            return ToolResult.success(result)

        except requests.exceptions.Timeout:
            return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
        except requests.exceptions.TooManyRedirects:
            return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
        except requests.exceptions.RequestException as e:
            return ToolResult.fail(f"Error fetching URL: {str(e)}")
        except Exception as e:
            logger.error(f"Web fetch error: {e}", exc_info=True)
            return ToolResult.fail(f"Error: {str(e)}")

    def _is_valid_url(self, url: str) -> bool:
        """Validate URL format (http/https scheme plus a host)"""
        try:
            result = urlparse(url)
            return result.scheme in ["http", "https"] and bool(result.netloc)
        except Exception:
            return False

    def _fetch_url(self, url: str) -> requests.Response:
        """
        Fetch URL with proper headers and error handling.

        :param url: URL to fetch
        :return: Response object (raises for HTTP error statuses)
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }

        # requests follows redirects automatically; the redirect cap comes
        # from session.max_redirects (set in _create_session)
        response = self.session.get(
            url,
            headers=headers,
            timeout=self.timeout,
            allow_redirects=True
        )

        response.raise_for_status()
        return response

    def _extract_content(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """
        Extract readable content from the response body.

        :param html: Response body text
        :param url: Final URL (after redirects)
        :param status_code: HTTP status code
        :param content_type: Content type header
        :param extract_mode: 'markdown' or 'text'
        :param max_chars: Maximum characters to return
        :return: Extracted content and metadata
        """
        # Non-HTML content: return raw text, truncated
        if "text/html" not in content_type.lower():
            text = html[:max_chars]
            truncated = len(html) > max_chars

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "extractor": "raw",
                # Included for schema consistency with the HTML paths
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated,
                "message": f"Non-HTML content (type: {content_type})"
            }

        # HTML: prefer readability, fall back to basic tag stripping
        if self.readability_available:
            return self._extract_with_readability(
                html, url, status_code, content_type, extract_mode, max_chars
            )
        else:
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_with_readability(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Extract content using Mozilla's Readability (readability-lxml)"""
        try:
            from readability import Document

            # Parse with Readability: title + cleaned article HTML
            doc = Document(html)
            title = doc.title()
            content_html = doc.summary()

            # Convert to the requested output format
            if extract_mode == "markdown":
                text = self._html_to_markdown(content_html)
            else:
                text = self._html_to_text(content_html)

            # Truncate if needed
            truncated = len(text) > max_chars
            if truncated:
                text = text[:max_chars]

            return {
                "url": url,
                "status": status_code,
                "content_type": content_type,
                "title": title,
                "extractor": "readability",
                "extract_mode": extract_mode,
                "text": text,
                "length": len(text),
                "truncated": truncated
            }

        except Exception as e:
            logger.warning(f"Readability extraction failed: {e}")
            # Fallback to basic extraction rather than failing the call
            return self._extract_basic(
                html, url, status_code, content_type, extract_mode, max_chars
            )

    def _extract_basic(
        self,
        html: str,
        url: str,
        status_code: int,
        content_type: str,
        extract_mode: str,
        max_chars: int
    ) -> Dict[str, Any]:
        """Basic HTML extraction without Readability (regex tag stripping)"""
        # Extract title
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        title = title_match.group(1).strip() if title_match else "Untitled"

        # Remove script and style tags (including their contents)
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Remove remaining HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Collapse whitespace
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        # Truncate if needed
        truncated = len(text) > max_chars
        if truncated:
            text = text[:max_chars]

        return {
            "url": url,
            "status": status_code,
            "content_type": content_type,
            "title": title,
            "extractor": "basic",
            "extract_mode": extract_mode,
            "text": text,
            "length": len(text),
            "truncated": truncated,
            "warning": "Using basic extraction. Install readability-lxml for better results."
        }

    def _html_to_markdown(self, html: str) -> str:
        """Convert HTML to Markdown (html2text if available, else plain text)"""
        try:
            import html2text
            h = html2text.HTML2Text()
            h.ignore_links = False
            h.ignore_images = False
            h.body_width = 0  # Don't wrap lines
            return h.handle(html)
        except ImportError:
            # Fallback to basic conversion when html2text isn't installed
            return self._html_to_text(html)

    def _html_to_text(self, html: str) -> str:
        """Convert HTML to plain text via regex tag stripping"""
        # Remove script and style tags (including their contents)
        text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)

        # Convert common block/break tags to newlines before stripping
        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
        text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
        text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
        text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)

        # Remove all other HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Decode HTML entities (&amp; etc.)
        import html
        text = html.unescape(text)

        # Collapse excess whitespace
        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
        text = re.sub(r' +', ' ', text)
        text = text.strip()

        return text

    def close(self):
        """Close the underlying HTTP session"""
        if hasattr(self, 'session'):
            self.session.close()
|
||||
Reference in New Issue
Block a user