fix: long-term memory bug

This commit is contained in:
saboteur7
2026-01-30 11:31:13 +08:00
parent bb850bb6c5
commit 5a466d0ff6
12 changed files with 202 additions and 295 deletions

View File

@@ -50,11 +50,45 @@ class MemoryStorage:
def _init_db(self):
"""Initialize database with schema"""
self.conn = sqlite3.connect(str(self.db_path))
self.conn.row_factory = sqlite3.Row
# Enable JSON support
self.conn.execute("PRAGMA journal_mode=WAL")
try:
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Check database integrity
try:
result = self.conn.execute("PRAGMA integrity_check").fetchone()
if result[0] != 'ok':
print(f"⚠️ Database integrity check failed: {result[0]}")
print(f" Recreating database...")
self.conn.close()
self.conn = None
# Remove corrupted database
self.db_path.unlink(missing_ok=True)
# Remove WAL files
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
# Reconnect to create new database
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
except sqlite3.DatabaseError:
# Database is corrupted, recreate it
print(f"⚠️ Database is corrupted, recreating...")
if self.conn:
self.conn.close()
self.conn = None
self.db_path.unlink(missing_ok=True)
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Enable WAL mode for better concurrency
self.conn.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to avoid "database is locked" errors
self.conn.execute("PRAGMA busy_timeout=5000")
except Exception as e:
print(f"⚠️ Unexpected error during database initialization: {e}")
raise
# Create chunks table with embeddings
self.conn.execute("""
@@ -92,6 +126,8 @@ class MemoryStorage:
""")
# Create FTS5 virtual table for keyword search
# Use default unicode61 tokenizer (stable and compatible)
# For CJK support, we'll use LIKE queries as fallback
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
@@ -261,13 +297,37 @@ class MemoryStorage:
scopes: List[str] = None,
limit: int = 10
) -> List[SearchResult]:
"""Keyword search using FTS5"""
"""
Keyword search using FTS5 + LIKE fallback
Strategy:
1. Try FTS5 search first (good for English and word-based languages)
2. If no results and query contains CJK characters, use LIKE search
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
# Build FTS query
# Try FTS5 search first
fts_results = self._search_fts5(query, user_id, scopes, limit)
if fts_results:
return fts_results
# Fallback to LIKE search for CJK characters
if MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit)
return []
def _search_fts5(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""FTS5 full-text search"""
fts_query = self._build_fts_query(query)
if not fts_query:
return []
@@ -299,20 +359,83 @@ class MemoryStorage:
"""
params.append(limit)
rows = self.conn.execute(sql_query, params).fetchall()
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=self._bm25_rank_to_score(row['rank']),
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def _search_like(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""LIKE-based search for CJK characters"""
import re
# Extract CJK words (2+ characters)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
if not cjk_words:
return []
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=self._bm25_rank_to_score(row['rank']),
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
scope_placeholders = ','.join('?' * len(scopes))
# Build LIKE conditions for each word
like_conditions = []
params = []
for word in cjk_words:
like_conditions.append("text LIKE ?")
params.append(f'%{word}%')
where_clause = ' OR '.join(like_conditions)
params.extend(scopes)
if user_id:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=0.5, # Fixed score for LIKE search
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def delete_by_path(self, path: str):
"""Delete all chunks from a file"""
@@ -354,7 +477,19 @@ class MemoryStorage:
def close(self):
"""Close database connection"""
if self.conn:
self.conn.close()
try:
self.conn.commit() # Ensure all changes are committed
self.conn.close()
self.conn = None # Mark as closed
except Exception as e:
print(f"⚠️ Error closing database connection: {e}")
def __del__(self):
"""Destructor to ensure connection is closed"""
try:
self.close()
except:
pass # Ignore errors during cleanup
# Helper methods
@@ -390,14 +525,29 @@ class MemoryStorage:
return dot_product / (norm1 * norm2)
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""Build FTS5 query from raw text"""
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
import re
tokens = re.findall(r'[A-Za-z0-9_\u4e00-\u9fff]+', raw_query)
return bool(re.search(r'[\u4e00-\u9fff]', text))
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""
Build FTS5 query from raw text
Works best for English and word-based languages.
For CJK characters, LIKE search will be used as fallback.
"""
import re
# Extract words (primarily English words and numbers)
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
if not tokens:
return None
# Quote tokens for exact matching
quoted = [f'"{t}"' for t in tokens]
return ' AND '.join(quoted)
# Use OR for more flexible matching
return ' OR '.join(quoted)
@staticmethod
def _bm25_rank_to_score(rank: float) -> float: