feat: 实现记忆证据服务与检索引擎

新增 EvidenceService,基于消息引用处理记忆证据扩展
新增 MemoryKeywordSearchOptions、MemoryKeywordSearchHit 类型,支持关键词检索功能
增强 MemoryDatabase,为记忆项添加 FTS 全文索引并实现关键词搜索方法
将记忆库结构版本升级至 2,适配新功能
创建 RetrievalEngine 检索引擎,统一管理关键词、近似最近邻(ANN)搜索
实现倒数排序融合(RRF)算法,合并多来源搜索结果
在 ChatPage 中新增记忆构建流程的 UI 组件与状态管理
扩展 Electron API,支持记忆构建状态获取与预处理
This commit is contained in:
ILoveBingLu
2026-04-27 22:25:04 +08:00
parent f24014e6d6
commit 4bd3d6b9a0
11 changed files with 1019 additions and 5 deletions
+18
View File
@@ -44,6 +44,18 @@ type SessionVectorIndexProgressEvent = {
vectorModel: string
}
type SessionMemoryBuildProgressEvent = {
sessionId: string
stage: string
status: string
processedCount: number
totalCount: number
message: string
messageCount: number
blockCount: number
factCount: number
}
type EmbeddingModelDownloadProgress = {
profileId: string
displayName: string
@@ -603,6 +615,8 @@ contextBridge.exposeInMainWorld('electronAPI', {
getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId),
prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options),
cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId),
getSessionMemoryBuildState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionMemoryBuildState', sessionId),
prepareSessionMemory: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionMemory', options),
getEmbeddingModelProfiles: () => ipcRenderer.invoke('ai:getEmbeddingModelProfiles'),
setEmbeddingModelProfile: (profileId: string) => ipcRenderer.invoke('ai:setEmbeddingModelProfile', profileId),
getEmbeddingDeviceStatus: () => ipcRenderer.invoke('ai:getEmbeddingDeviceStatus'),
@@ -637,6 +651,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event))
return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress')
},
onSessionMemoryBuildProgress: (callback: (event: SessionMemoryBuildProgressEvent) => void) => {
ipcRenderer.on('ai:sessionMemoryBuildProgress', (_, event) => callback(event))
return () => ipcRenderer.removeAllListeners('ai:sessionMemoryBuildProgress')
},
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => {
ipcRenderer.on('ai:embeddingModelDownloadProgress', (_, event) => callback(event))
return () => ipcRenderer.removeAllListeners('ai:embeddingModelDownloadProgress')
@@ -0,0 +1,90 @@
import { chatService } from '../chatService'
import type { Message } from '../chatService'
import type { MemoryEvidenceRef, MemoryItem } from './memorySchema'
import type { RetrievalExpandedEvidence } from '../retrieval/retrievalTypes'
function compareRefAsc(a: MemoryEvidenceRef, b: MemoryEvidenceRef): number {
return Number(a.sortSeq || 0) - Number(b.sortSeq || 0)
|| Number(a.createTime || 0) - Number(b.createTime || 0)
|| Number(a.localId || 0) - Number(b.localId || 0)
}
function messageKey(message: Pick<Message, 'localId' | 'createTime' | 'sortSeq'>): string {
return `${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
}
function uniqueMessages(messages: Message[]): Message[] {
const seen = new Set<string>()
const result: Message[] = []
for (const message of messages) {
const key = messageKey(message)
if (seen.has(key)) continue
seen.add(key)
result.push(message)
}
return result
}
async function loadAnchor(ref: MemoryEvidenceRef): Promise<Message | null> {
try {
const result = await chatService.getMessageByLocalId(ref.sessionId, ref.localId)
if (!result.success || !result.message) return null
if (Number(ref.createTime || 0) > 0 && Number(result.message.createTime || 0) !== Number(ref.createTime)) {
return {
...result.message,
createTime: ref.createTime,
sortSeq: ref.sortSeq
}
}
return result.message
} catch {
return null
}
}
async function expandAroundRef(ref: MemoryEvidenceRef, radius: number): Promise<RetrievalExpandedEvidence> {
const [beforeResult, anchor, afterResult] = await Promise.all([
chatService.getMessagesBefore(ref.sessionId, ref.sortSeq, radius, ref.createTime, ref.localId),
loadAnchor(ref),
chatService.getMessagesAfter(ref.sessionId, ref.sortSeq, radius, ref.createTime, ref.localId)
])
return {
ref,
before: uniqueMessages(beforeResult.success ? beforeResult.messages || [] : []),
anchor,
after: uniqueMessages(afterResult.success ? afterResult.messages || [] : [])
}
}
export class EvidenceService {
async expandMemoryEvidence(memory: MemoryItem): Promise<RetrievalExpandedEvidence[]> {
const refs = [...memory.sourceRefs].sort(compareRefAsc)
if (refs.length === 0) return []
if (memory.sourceType === 'conversation_block') {
const first = refs[0]
const last = refs[refs.length - 1]
if (messageKey(first) === messageKey(last)) {
return [await expandAroundRef(first, 3)]
}
const [before, after] = await Promise.all([
chatService.getMessagesBefore(first.sessionId, first.sortSeq, 3, first.createTime, first.localId),
chatService.getMessagesAfter(last.sessionId, last.sortSeq, 3, last.createTime, last.localId)
])
return [{
ref: first,
before: uniqueMessages(before.success ? before.messages || [] : []),
anchor: null,
after: uniqueMessages(after.success ? after.messages || [] : [])
}]
}
const radius = memory.sourceType === 'message' ? 6 : 3
const limitedRefs = refs.slice(0, memory.sourceType === 'fact' ? 3 : 1)
return Promise.all(limitedRefs.map((ref) => expandAroundRef(ref, radius)))
}
}
export const evidenceService = new EvidenceService()
+194 -1
View File
@@ -20,6 +20,22 @@ import {
type MemoryVectorStoreName
} from './memorySchema'
export type MemoryKeywordSearchOptions = {
query: string
sessionId?: string
sourceTypes?: MemorySourceType[]
startTimeMs?: number
endTimeMs?: number
limit?: number
}
export type MemoryKeywordSearchHit = {
item: MemoryItem
rank: number
score: number
retrievalSource: 'memory_fts' | 'memory_like'
}
function nowMs(): number {
return Date.now()
}
@@ -101,6 +117,67 @@ function parseEvidenceRefsJson(value: string): MemoryEvidenceRef[] {
}
}
function toTimestampSeconds(value?: number): number | null {
if (!Number.isFinite(Number(value)) || Number(value) <= 0) return null
const numberValue = Number(value)
return numberValue > 10_000_000_000 ? Math.floor(numberValue / 1000) : Math.floor(numberValue)
}
function escapeFtsPhrase(value: string): string {
return `"${String(value || '').replace(/"/g, '""')}"`
}
function buildMemoryFtsQuery(query: string): string {
const normalized = String(query || '')
.replace(/[\u200b-\u200f\ufeff]/g, '')
.replace(/[,。!?;:、“”‘’()()[\]{}<>《》|\\/+=*_~`#$%^&-]+/g, ' ')
.replace(/\s+/g, ' ')
.trim()
if (!normalized) return ''
const terms = normalized
.split(/\s+/)
.map((term) => term.trim())
.filter(Boolean)
return terms.length > 1
? terms.map(escapeFtsPhrase).join(' AND ')
: escapeFtsPhrase(normalized)
}
function buildMemoryFilterSql(
options: Pick<MemoryKeywordSearchOptions, 'sessionId' | 'sourceTypes' | 'startTimeMs' | 'endTimeMs'>,
params: Record<string, unknown>
): string {
const clauses: string[] = []
if (options.sessionId) {
clauses.push('m.session_id = @sessionId')
params.sessionId = options.sessionId
}
const sourceTypes = Array.from(new Set((options.sourceTypes || []).filter((type) => MEMORY_SOURCE_TYPES.includes(type))))
if (sourceTypes.length > 0) {
const placeholders = sourceTypes.map((_, index) => `@sourceType${index}`)
sourceTypes.forEach((sourceType, index) => {
params[`sourceType${index}`] = sourceType
})
clauses.push(`m.source_type IN (${placeholders.join(', ')})`)
}
const startTime = toTimestampSeconds(options.startTimeMs)
if (startTime) {
clauses.push('COALESCE(m.time_end, m.time_start, 0) >= @startTime')
params.startTime = startTime
}
const endTime = toTimestampSeconds(options.endTimeMs)
if (endTime) {
clauses.push('COALESCE(m.time_start, m.time_end, 0) <= @endTime')
params.endTime = endTime
}
return clauses.length ? `AND ${clauses.join(' AND ')}` : ''
}
function safeSourceType(value: string): MemorySourceType {
return MEMORY_SOURCE_TYPES.includes(value as MemorySourceType)
? value as MemorySourceType
@@ -265,12 +342,55 @@ export class MemoryDatabase {
ON memory_embeddings(vector_store, vector_ref);
`)
db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS memory_items_fts USING fts5(
title,
content,
entities,
tags,
tokenize = 'unicode61 remove_diacritics 2'
);
`)
this.syncMemoryFtsIndex(db)
db.prepare(`
INSERT OR REPLACE INTO memory_meta(key, value, updated_at)
VALUES ('schema_version', ?, ?)
`).run(MEMORY_SCHEMA_VERSION, nowMs())
}
private syncMemoryFtsIndex(db: Database.Database): void {
const row = db.prepare(`
SELECT COUNT(*) AS count
FROM memory_items m
LEFT JOIN memory_items_fts f ON f.rowid = m.id
WHERE f.rowid IS NULL
`).get() as { count: number } | undefined
if (!Number(row?.count || 0)) return
db.prepare(`
INSERT INTO memory_items_fts(rowid, title, content, entities, tags)
SELECT id, title, content, entities_json, tags_json
FROM memory_items
WHERE id NOT IN (SELECT rowid FROM memory_items_fts)
`).run()
}
private upsertMemoryFtsRow(item: MemoryItem): void {
const db = this.getDb()
db.prepare('DELETE FROM memory_items_fts WHERE rowid = ?').run(item.id)
db.prepare(`
INSERT INTO memory_items_fts(rowid, title, content, entities, tags)
VALUES (@id, @title, @content, @entities, @tags)
`).run({
id: item.id,
title: item.title,
content: item.content,
entities: safeJsonStringify(item.entities || [], []),
tags: safeJsonStringify(item.tags || [], [])
})
}
upsertMemoryItem(input: MemoryItemInput): MemoryItem {
const db = this.getDb()
const timestamp = nowMs()
@@ -338,6 +458,7 @@ export class MemoryDatabase {
const item = this.getMemoryItemByUid(memoryUid)
if (!item) throw new Error('Failed to load upserted memory item')
this.upsertMemoryFtsRow(item)
return item
}
@@ -408,8 +529,80 @@ export class MemoryDatabase {
return Number(row?.count || 0)
}
searchMemoryItemsByKeyword(options: MemoryKeywordSearchOptions): MemoryKeywordSearchHit[] {
const query = String(options.query || '').trim()
if (!query) return []
const db = this.getDb()
const limit = Math.max(1, Math.min(Math.floor(options.limit || 80), 500))
const rowsById = new Map<number, MemoryKeywordSearchHit>()
const params: Record<string, unknown> = { limit }
const filterSql = buildMemoryFilterSql(options, params)
const ftsQuery = buildMemoryFtsQuery(query)
if (ftsQuery) {
const ftsRows = db.prepare(`
SELECT m.*, bm25(memory_items_fts) AS fts_rank
FROM memory_items_fts
JOIN memory_items m ON m.id = memory_items_fts.rowid
WHERE memory_items_fts MATCH @ftsQuery
${filterSql}
ORDER BY fts_rank ASC, COALESCE(m.time_end, m.time_start, m.updated_at) DESC, m.id DESC
LIMIT @limit
`).all({
...params,
ftsQuery
}) as Array<MemoryItemRow & { fts_rank: number }>
ftsRows.forEach((row, index) => {
rowsById.set(Number(row.id), {
item: toMemoryItem(row),
rank: index + 1,
score: Number((1000 + Math.max(0, 100 - Number(row.fts_rank || 0))).toFixed(4)),
retrievalSource: 'memory_fts'
})
})
}
const likeParams: Record<string, unknown> = { ...params, likeQuery: `%${query}%` }
const likeFilterSql = buildMemoryFilterSql(options, likeParams)
const likeRows = db.prepare(`
SELECT m.*
FROM memory_items m
WHERE (
m.title LIKE @likeQuery
OR m.content LIKE @likeQuery
OR m.entities_json LIKE @likeQuery
OR m.tags_json LIKE @likeQuery
)
${likeFilterSql}
ORDER BY COALESCE(m.time_end, m.time_start, m.updated_at) DESC, m.id DESC
LIMIT @limit
`).all(likeParams) as MemoryItemRow[]
let likeRank = 1
for (const row of likeRows) {
const id = Number(row.id)
if (rowsById.has(id)) continue
rowsById.set(id, {
item: toMemoryItem(row),
rank: likeRank,
score: 500,
retrievalSource: 'memory_like'
})
likeRank += 1
}
return Array.from(rowsById.values())
.sort((a, b) => b.score - a.score || b.item.importance - a.item.importance || b.item.updatedAt - a.item.updatedAt)
.slice(0, limit)
.map((hit, index) => ({ ...hit, rank: index + 1 }))
}
deleteMemoryItem(id: number): boolean {
const result = this.getDb().prepare('DELETE FROM memory_items WHERE id = ?').run(id)
const db = this.getDb()
db.prepare('DELETE FROM memory_items_fts WHERE rowid = ?').run(id)
const result = db.prepare('DELETE FROM memory_items WHERE id = ?').run(id)
return result.changes > 0
}
+1 -1
View File
@@ -1,5 +1,5 @@
export const MEMORY_DB_NAME = 'agent_memory.db'
export const MEMORY_SCHEMA_VERSION = '1'
export const MEMORY_SCHEMA_VERSION = '2'
export const MEMORY_SOURCE_TYPES = [
'message',
@@ -0,0 +1,380 @@
import { memoryDatabase } from '../memory/memoryDatabase'
import type { MemoryKeywordSearchHit } from '../memory/memoryDatabase'
import { evidenceService } from '../memory/evidenceService'
import type { MemoryItem } from '../memory/memorySchema'
import { chatSearchIndexService } from '../search/chatSearchIndexService'
import { localRerankerService, type RerankDocument } from './rerankerService'
import { reciprocalRankFusion } from './rrf'
import type {
RetrievalCandidate,
RetrievalEngineOptions,
RetrievalEngineResult,
RetrievalHit,
RetrievalRerankStats,
RetrievalSourceName,
RetrievalSourceStats
} from './retrievalTypes'
type SourceHit = {
source: RetrievalSourceName
memory: MemoryItem
rank: number
score: number
}
const DEFAULT_LIMIT = 20
const DEFAULT_KEYWORD_LIMIT = 80
const DEFAULT_ANN_LIMIT = 80
const DEFAULT_RERANK_LIMIT = 120
const DEFAULT_RRF_K = 60
function compactText(value: string, limit: number): string {
const normalized = String(value || '').replace(/\s+/g, ' ').trim()
if (!normalized) return ''
return normalized.length > limit ? `${normalized.slice(0, limit - 1)}...` : normalized
}
function uniqueQueries(values: string[]): string[] {
const seen = new Set<string>()
const result: string[] = []
for (const value of values) {
const query = String(value || '').replace(/\s+/g, ' ').trim()
const key = query.toLowerCase()
if (!query || seen.has(key)) continue
seen.add(key)
result.push(query)
}
return result
}
function memoryKey(memory: MemoryItem): string {
return String(memory.id)
}
function messageMemoryUid(sessionId: string, message: { localId: number; createTime: number; sortSeq: number }): string {
return `message:${sessionId}:${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
}
function buildRerankDocument(candidate: RetrievalCandidate): RerankDocument {
const memory = candidate.memory
const sourceRefs = memory.sourceRefs
.slice(0, 3)
.map((ref) => `${ref.senderUsername || 'unknown'} ${ref.createTime}: ${ref.excerpt || ''}`)
.filter(Boolean)
.join('\n')
const text = [
`type: ${memory.sourceType}`,
memory.title ? `title: ${memory.title}` : '',
memory.timeStart || memory.timeEnd ? `time: ${memory.timeStart || ''}-${memory.timeEnd || ''}` : '',
`content: ${memory.content}`,
sourceRefs ? `evidence:\n${sourceRefs}` : ''
].filter(Boolean).join('\n')
return {
id: candidate.key,
text: compactText(text, 4000),
originalScore: candidate.rrfScore,
metadata: {
memoryId: memory.id,
sourceType: memory.sourceType,
sources: candidate.sources
}
}
}
function toCandidate(hit: SourceHit): RetrievalCandidate {
return {
key: memoryKey(hit.memory),
memory: hit.memory,
sources: [hit.source],
sourceRanks: { [hit.source]: hit.rank },
sourceScores: { [hit.source]: hit.score },
rrfScore: 0
}
}
function mergeSourceDetails(candidate: RetrievalCandidate, hits: SourceHit[]): RetrievalCandidate {
const sources: RetrievalSourceName[] = []
const sourceRanks: RetrievalCandidate['sourceRanks'] = {}
const sourceScores: RetrievalCandidate['sourceScores'] = {}
for (const hit of hits) {
if (!sources.includes(hit.source)) sources.push(hit.source)
sourceRanks[hit.source] = Math.min(sourceRanks[hit.source] || Number.MAX_SAFE_INTEGER, hit.rank)
sourceScores[hit.source] = Math.max(sourceScores[hit.source] || 0, hit.score)
}
return {
...candidate,
sources,
sourceRanks,
sourceScores
}
}
function normalizeLimit(value: unknown, fallback: number, max: number): number {
const numberValue = Math.floor(Number(value || fallback))
return Math.max(1, Math.min(Number.isFinite(numberValue) ? numberValue : fallback, max))
}
export class RetrievalEngine {
async search(options: RetrievalEngineOptions): Promise<RetrievalEngineResult> {
const startedAt = Date.now()
const query = String(options.query || '').trim()
if (!query) {
return {
query,
semanticQuery: '',
hits: [],
sourceStats: [],
rerank: { attempted: false, applied: false, skippedReason: 'empty_query' },
latencyMs: Date.now() - startedAt
}
}
const limit = normalizeLimit(options.limit, DEFAULT_LIMIT, 100)
const keywordLimit = normalizeLimit(options.keywordLimit, DEFAULT_KEYWORD_LIMIT, 500)
const annLimit = normalizeLimit(options.annLimit, DEFAULT_ANN_LIMIT, 500)
const rerankLimit = normalizeLimit(options.rerankLimit, DEFAULT_RERANK_LIMIT, 500)
const keywordQueries = uniqueQueries([query, ...(options.keywordQueries || [])])
const semanticQueries = uniqueQueries([
options.semanticQuery || query,
...(options.semanticQueries || [])
])
const semanticQuery = semanticQueries[0] || query
const sourceStats: RetrievalSourceStats[] = []
const keywordHits = this.collectKeywordHits(options, keywordQueries, keywordLimit, sourceStats)
const annHits = await this.collectAnnHits(options, semanticQueries, annLimit, sourceStats)
const candidates = this.fuseCandidates([...keywordHits, ...annHits], options.rrfK)
const rerankStats: RetrievalRerankStats = { attempted: false, applied: false }
const ranked = await this.applyRerank(query, semanticQuery, candidates, rerankLimit, rerankStats, options.rerank !== false)
const selected = ranked.slice(0, limit)
const hits = await this.expandHits(selected, options.expandEvidence !== false)
return {
query,
semanticQuery,
hits,
sourceStats,
rerank: rerankStats,
latencyMs: Date.now() - startedAt
}
}
private collectKeywordHits(
options: RetrievalEngineOptions,
queries: string[],
limit: number,
sourceStats: RetrievalSourceStats[]
): SourceHit[] {
const hits: SourceHit[] = []
let error: string | undefined
for (const query of queries) {
try {
const rows = memoryDatabase.searchMemoryItemsByKeyword({
query,
sessionId: options.sessionId,
sourceTypes: options.sourceTypes,
startTimeMs: options.startTimeMs,
endTimeMs: options.endTimeMs,
limit
})
hits.push(...rows.map((row) => this.keywordRowToSourceHit(row)))
} catch (searchError) {
error = String(searchError)
}
}
const ftsCount = hits.filter((hit) => hit.source === 'memory_fts').length
const likeCount = hits.filter((hit) => hit.source === 'memory_like').length
sourceStats.push({ name: 'memory_fts', attempted: true, hitCount: ftsCount, ...(error ? { error } : {}) })
sourceStats.push({ name: 'memory_like', attempted: true, hitCount: likeCount, ...(error ? { error } : {}) })
return this.dedupeSourceHits(hits)
}
private keywordRowToSourceHit(row: MemoryKeywordSearchHit): SourceHit {
return {
source: row.retrievalSource,
memory: row.item,
rank: row.rank,
score: row.score
}
}
private async collectAnnHits(
options: RetrievalEngineOptions,
queries: string[],
limit: number,
sourceStats: RetrievalSourceStats[]
): Promise<SourceHit[]> {
if (!options.sessionId) {
sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'session_required' })
return []
}
const vectorState = chatSearchIndexService.getSessionVectorIndexState(options.sessionId)
if (!vectorState.vectorProviderAvailable) {
sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'vector_provider_unavailable' })
return []
}
if (!vectorState.isVectorComplete) {
sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'vector_index_incomplete' })
return []
}
const hits: SourceHit[] = []
let error: string | undefined
for (const query of queries) {
try {
const result = await chatSearchIndexService.searchSessionByVector({
sessionId: options.sessionId,
query,
limit,
startTimeMs: options.startTimeMs,
endTimeMs: options.endTimeMs,
direction: options.direction,
senderUsername: options.senderUsername
})
result.hits.forEach((hit, index) => {
const uid = messageMemoryUid(hit.sessionId, hit.message)
const memory = memoryDatabase.getMemoryItemByUid(uid)
if (!memory) return
hits.push({
source: 'message_ann',
memory,
rank: index + 1,
score: hit.score
})
})
} catch (searchError) {
error = String(searchError)
}
}
sourceStats.push({
name: 'message_ann',
attempted: true,
hitCount: hits.length,
...(error ? { error } : {})
})
return this.dedupeSourceHits(hits)
}
private dedupeSourceHits(hits: SourceHit[]): SourceHit[] {
const byKey = new Map<string, SourceHit>()
for (const hit of hits) {
const key = `${hit.source}:${memoryKey(hit.memory)}`
const existing = byKey.get(key)
if (!existing || hit.rank < existing.rank || hit.score > existing.score) {
byKey.set(key, hit)
}
}
return Array.from(byKey.values()).sort((a, b) => a.rank - b.rank || b.score - a.score)
}
private fuseCandidates(sourceHits: SourceHit[], rrfK?: number): RetrievalCandidate[] {
const hitsByMemory = new Map<string, SourceHit[]>()
const listsBySource = new Map<RetrievalSourceName, SourceHit[]>()
for (const hit of sourceHits) {
const key = memoryKey(hit.memory)
const grouped = hitsByMemory.get(key) || []
grouped.push(hit)
hitsByMemory.set(key, grouped)
const list = listsBySource.get(hit.source) || []
list.push(hit)
listsBySource.set(hit.source, list)
}
const fused = reciprocalRankFusion(
Array.from(listsBySource.values()).map((list) => list
.sort((a, b) => a.rank - b.rank || b.score - a.score)
.map((hit, index) => ({ item: hit, rank: hit.rank || index + 1, score: hit.score }))),
(hit) => memoryKey(hit.memory),
rrfK || DEFAULT_RRF_K
)
return fused.map((item) => {
const candidate = toCandidate(item.item)
candidate.rrfScore = Number(item.rrfScore.toFixed(8))
return mergeSourceDetails(candidate, hitsByMemory.get(item.key) || [item.item])
})
}
private async applyRerank(
query: string,
semanticQuery: string,
candidates: RetrievalCandidate[],
limit: number,
stats: RetrievalRerankStats,
enabled: boolean
): Promise<RetrievalCandidate[]> {
if (!enabled) {
stats.skippedReason = 'disabled'
return candidates
}
if (candidates.length === 0) {
stats.skippedReason = 'no_candidates'
return candidates
}
if (!localRerankerService.isEnabled()) {
stats.skippedReason = 'config_disabled'
return candidates
}
stats.attempted = true
const rerankInput = candidates.slice(0, limit)
try {
const reranked = await localRerankerService.rerank(
[query, semanticQuery].filter(Boolean).join('\n'),
rerankInput.map(buildRerankDocument),
{ limit }
)
const byKey = new Map(candidates.map((candidate) => [candidate.key, candidate]))
const rerankedKeys = new Set<string>()
const rankedCandidates = reranked
.map((result) => {
const candidate = byKey.get(result.id)
if (!candidate) return null
rerankedKeys.add(result.id)
return {
...candidate,
rrfScore: candidate.rrfScore,
rerankScore: result.rerankScore,
finalScore: result.combinedScore
}
})
.filter((item): item is RetrievalCandidate & { rerankScore: number; finalScore: number } => Boolean(item))
.sort((a, b) => b.finalScore - a.finalScore)
stats.applied = rankedCandidates.length > 0
const rest = candidates.filter((candidate) => !rerankedKeys.has(candidate.key))
return [...rankedCandidates, ...rest]
} catch (error) {
stats.error = String(error)
stats.skippedReason = 'rerank_failed'
return candidates
}
}
private async expandHits(candidates: RetrievalCandidate[], expandEvidence: boolean): Promise<RetrievalHit[]> {
const hits: RetrievalHit[] = []
for (let index = 0; index < candidates.length; index += 1) {
const candidate = candidates[index] as RetrievalCandidate & { rerankScore?: number; finalScore?: number }
const evidence = expandEvidence ? await evidenceService.expandMemoryEvidence(candidate.memory) : []
hits.push({
...candidate,
rank: index + 1,
score: Number((candidate.finalScore ?? candidate.rerankScore ?? candidate.rrfScore).toFixed(8)),
...(candidate.rerankScore != null ? { rerankScore: candidate.rerankScore } : {}),
evidence
})
}
return hits
}
}
export const retrievalEngine = new RetrievalEngine()
@@ -1,3 +1,75 @@
import type { Message } from '../chatService'
import type { MemoryEvidenceRef, MemoryItem, MemorySourceType } from '../memory/memorySchema'
export type RetrievalSourceName = 'memory_fts' | 'memory_like' | 'message_ann'
export type RetrievalEngineOptions = {
query: string
semanticQuery?: string
keywordQueries?: string[]
semanticQueries?: string[]
sessionId?: string
sourceTypes?: MemorySourceType[]
startTimeMs?: number
endTimeMs?: number
direction?: 'in' | 'out'
senderUsername?: string
limit?: number
keywordLimit?: number
annLimit?: number
rrfK?: number
rerank?: boolean
rerankLimit?: number
expandEvidence?: boolean
}
export type RetrievalCandidate = {
key: string
memory: MemoryItem
sources: RetrievalSourceName[]
sourceRanks: Partial<Record<RetrievalSourceName, number>>
sourceScores: Partial<Record<RetrievalSourceName, number>>
rrfScore: number
}
export type RetrievalExpandedEvidence = {
ref: MemoryEvidenceRef
before: Message[]
anchor: Message | null
after: Message[]
}
export type RetrievalHit = RetrievalCandidate & {
rank: number
score: number
rerankScore?: number
evidence: RetrievalExpandedEvidence[]
}
export type RetrievalSourceStats = {
name: RetrievalSourceName
attempted: boolean
hitCount: number
skippedReason?: string
error?: string
}
export type RetrievalRerankStats = {
attempted: boolean
applied: boolean
skippedReason?: string
error?: string
}
export type RetrievalEngineResult = {
query: string
semanticQuery: string
hits: RetrievalHit[]
sourceStats: RetrievalSourceStats[]
rerank: RetrievalRerankStats
latencyMs: number
}
export type RetrievalEvalEvidenceRef = {
localId: number
createTime: number
+49
View File
@@ -0,0 +1,49 @@
export type RrfRankedItem<T> = {
item: T
rank: number
score?: number
}
export type RrfMergedItem<T> = {
key: string
item: T
rrfScore: number
ranks: number[]
scores: number[]
}
export function reciprocalRankFusion<T>(
rankedLists: Array<Array<RrfRankedItem<T>>>,
getKey: (item: T) => string,
k: number = 60
): Array<RrfMergedItem<T>> {
const safeK = Math.max(1, Math.floor(k || 60))
const byKey = new Map<string, RrfMergedItem<T>>()
for (const list of rankedLists) {
for (let index = 0; index < list.length; index += 1) {
const entry = list[index]
const rank = Math.max(1, Math.floor(entry.rank || index + 1))
const key = getKey(entry.item)
const existing = byKey.get(key)
const contribution = 1 / (safeK + rank)
if (existing) {
existing.rrfScore += contribution
existing.ranks.push(rank)
if (Number.isFinite(entry.score)) existing.scores.push(Number(entry.score))
} else {
byKey.set(key, {
key,
item: entry.item,
rrfScore: contribution,
ranks: [rank],
scores: Number.isFinite(entry.score) ? [Number(entry.score)] : []
})
}
}
}
return Array.from(byKey.values())
.sort((a, b) => b.rrfScore - a.rrfScore || Math.min(...a.ranks) - Math.min(...b.ranks))
}
+13 -2
View File
@@ -341,7 +341,8 @@
animation: spin 1s linear infinite;
}
&.vector-index-btn {
&.vector-index-btn,
&.memory-build-btn {
position: relative;
&.complete {
@@ -386,6 +387,11 @@
background: #d97706;
}
}
&.memory-build-btn.complete {
color: #0891b2;
background: rgba(8, 145, 178, 0.12);
}
}
// 日期选择器包装器
@@ -1194,7 +1200,8 @@
}
}
.vector-index-btn {
.vector-index-btn,
.memory-build-btn {
overflow: visible;
&.complete {
@@ -1236,6 +1243,10 @@
background: #d97706;
}
}
.memory-build-btn.complete {
color: #0891b2;
}
}
.message-list {
+151 -1
View File
@@ -10,7 +10,12 @@ import { getImageXorKey, getImageAesKey, getQuoteStyle } from '../services/confi
import { LRUCache } from '../utils/lruCache'
import { LivePhotoIcon } from '../components/LivePhotoIcon'
import type { ChatSession, Message } from '../types/models'
import type { SessionVectorIndexProgressEvent, SessionVectorIndexState } from '../types/ai'
import type {
SessionMemoryBuildProgressEvent,
SessionMemoryBuildState,
SessionVectorIndexProgressEvent,
SessionVectorIndexState
} from '../types/ai'
import { List, RowComponentProps } from 'react-window'
import './ChatPage.scss'
@@ -358,6 +363,9 @@ function ChatPage(_props: ChatPageProps) {
const [vectorIndexState, setVectorIndexState] = useState<SessionVectorIndexState | null>(null)
const [vectorIndexProgress, setVectorIndexProgress] = useState<SessionVectorIndexProgressEvent | null>(null)
const [isPreparingVectorIndex, setIsPreparingVectorIndex] = useState(false)
const [memoryBuildState, setMemoryBuildState] = useState<SessionMemoryBuildState | null>(null)
const [memoryBuildProgress, setMemoryBuildProgress] = useState<SessionMemoryBuildProgressEvent | null>(null)
const [isPreparingMemoryBuild, setIsPreparingMemoryBuild] = useState(false)
const showTopToast = useCallback((text: string, success = true) => {
setTopToast({ text, success })
@@ -377,6 +385,19 @@ function ChatPage(_props: ChatPageProps) {
}
}, [])
const refreshMemoryBuildState = useCallback(async (sessionId: string) => {
try {
const result = await window.electronAPI.ai.getSessionMemoryBuildState(sessionId)
if (currentSessionIdRef.current !== sessionId) return
if (result.success && result.result) {
setMemoryBuildState(result.result)
setIsPreparingMemoryBuild(result.result.isRunning)
}
} catch (error) {
console.error('获取会话记忆状态失败:', error)
}
}, [])
useEffect(() => {
let cancelled = false
setVectorIndexState(null)
@@ -401,6 +422,29 @@ function ChatPage(_props: ChatPageProps) {
}
}, [currentSessionId])
useEffect(() => {
let cancelled = false
setMemoryBuildState(null)
setMemoryBuildProgress(null)
setIsPreparingMemoryBuild(false)
if (!currentSessionId) return
window.electronAPI.ai.getSessionMemoryBuildState(currentSessionId).then((result) => {
if (cancelled) return
if (result.success && result.result) {
setMemoryBuildState(result.result)
setIsPreparingMemoryBuild(result.result.isRunning)
}
}).catch((error) => {
if (!cancelled) console.error('加载会话记忆状态失败:', error)
})
return () => {
cancelled = true
}
}, [currentSessionId])
useEffect(() => {
return window.electronAPI.ai.onSessionVectorIndexProgress((event) => {
const activeSessionId = currentSessionIdRef.current
@@ -447,6 +491,35 @@ function ChatPage(_props: ChatPageProps) {
})
}, [refreshVectorIndexState, showTopToast])
useEffect(() => {
return window.electronAPI.ai.onSessionMemoryBuildProgress((event) => {
const activeSessionId = currentSessionIdRef.current
if (!event || !activeSessionId || event.sessionId !== activeSessionId) return
setMemoryBuildProgress(event)
setIsPreparingMemoryBuild(event.status === 'running')
setMemoryBuildState((prev) => ({
sessionId: event.sessionId,
messageCount: event.messageCount,
blockCount: event.blockCount,
factCount: event.factCount,
totalCount: event.totalCount || prev?.totalCount || 0,
processedCount: event.processedCount,
isRunning: event.status === 'running',
updatedAt: Date.now(),
completedAt: event.status === 'completed' ? Date.now() : prev?.completedAt,
lastError: event.status === 'failed' ? event.message : prev?.lastError
}))
if (event.status === 'completed') {
void refreshMemoryBuildState(event.sessionId)
} else if (event.status === 'failed') {
showTopToast(event.message || '会话记忆构建失败', false)
void refreshMemoryBuildState(event.sessionId)
}
})
}, [refreshMemoryBuildState, showTopToast])
const vectorIndexTotal = vectorIndexProgress?.totalCount || vectorIndexState?.indexedCount || 0
const vectorIndexDone = vectorIndexProgress?.processedCount ?? vectorIndexState?.vectorizedCount ?? 0
const vectorIndexPercent = vectorIndexTotal > 0
@@ -469,6 +542,23 @@ function ChatPage(_props: ChatPageProps) {
? `增量向量化:待处理 ${vectorIndexState?.pendingCount || 0}`
: '增量向量化当前聊天'
const memoryBuildTotal = memoryBuildProgress?.totalCount || memoryBuildState?.totalCount || 0
const memoryBuildDone = memoryBuildProgress?.processedCount ?? memoryBuildState?.processedCount ?? 0
const memoryBuildPercent = memoryBuildTotal > 0
? Math.min(100, Math.max(0, Math.round((memoryBuildDone / memoryBuildTotal) * 100)))
: memoryBuildState && memoryBuildState.totalCount > 0 ? 100 : 0
const memoryBuildCount = (memoryBuildState?.messageCount || 0) + (memoryBuildState?.blockCount || 0) + (memoryBuildState?.factCount || 0)
const memoryBuildBadgeLabel = isPreparingMemoryBuild
? `${memoryBuildPercent}%`
: memoryBuildCount > 0
? memoryBuildCount > 99 ? '99+' : String(memoryBuildCount)
: ''
const memoryButtonTitle = isPreparingMemoryBuild
? `正在构建会话记忆:${memoryBuildProgress?.message || `${memoryBuildDone}/${memoryBuildTotal}`}`
: memoryBuildCount > 0
? `重建会话记忆:消息 ${memoryBuildState?.messageCount || 0},片段 ${memoryBuildState?.blockCount || 0},事实 ${memoryBuildState?.factCount || 0}`
: '构建当前聊天三层记忆'
const handleVectorIndexClick = useCallback(async () => {
if (!currentSessionId) return
@@ -553,6 +643,50 @@ function ChatPage(_props: ChatPageProps) {
vectorIndexState?.vectorModel
])
const handleMemoryBuildClick = useCallback(async () => {
if (!currentSessionId || isPreparingMemoryBuild) return
setIsPreparingMemoryBuild(true)
setMemoryBuildProgress({
sessionId: currentSessionId,
stage: 'preparing',
status: 'running',
processedCount: memoryBuildState?.processedCount || 0,
totalCount: memoryBuildState?.totalCount || 0,
message: '正在准备会话记忆构建',
messageCount: memoryBuildState?.messageCount || 0,
blockCount: memoryBuildState?.blockCount || 0,
factCount: memoryBuildState?.factCount || 0
})
try {
const result = await window.electronAPI.ai.prepareSessionMemory({ sessionId: currentSessionId })
if (currentSessionIdRef.current !== currentSessionId) return
if (result.success && result.result) {
setMemoryBuildState(result.result)
setIsPreparingMemoryBuild(result.result.isRunning)
setMemoryBuildProgress(null)
showTopToast(`会话记忆已构建:${result.result.totalCount}`, true)
} else {
setIsPreparingMemoryBuild(false)
showTopToast(result.error || '会话记忆构建失败', false)
}
} catch (error) {
if (currentSessionIdRef.current !== currentSessionId) return
setIsPreparingMemoryBuild(false)
showTopToast(`会话记忆构建失败: ${String(error)}`, false)
}
}, [
currentSessionId,
isPreparingMemoryBuild,
memoryBuildState?.blockCount,
memoryBuildState?.factCount,
memoryBuildState?.messageCount,
memoryBuildState?.processedCount,
memoryBuildState?.totalCount,
showTopToast
])
useEffect(() => {
isLoadingMoreRef.current = isLoadingMore
}, [isLoadingMore])
@@ -2001,6 +2135,22 @@ function ChatPage(_props: ChatPageProps) {
<span className="vector-index-badge">{vectorIndexBadgeLabel}</span>
)}
</button>
<button
className={`icon-btn memory-build-btn ${isPreparingMemoryBuild ? 'running active' : ''} ${memoryBuildCount > 0 ? 'complete' : ''}`}
onClick={handleMemoryBuildClick}
disabled={!currentSessionId || isPreparingMemoryBuild}
title={memoryButtonTitle}
aria-label="构建当前聊天三层记忆"
>
{isPreparingMemoryBuild ? (
<Radar size={18} className="vector-index-radar" />
) : (
<Database size={18} />
)}
{memoryBuildBadgeLabel && (
<span className="vector-index-badge">{memoryBuildBadgeLabel}</span>
)}
</button>
{!isGroupChat(currentSession.username) && (
<button
className="icon-btn moments-btn"
+38
View File
@@ -288,6 +288,44 @@ export interface SessionVectorIndexProgressEvent {
vectorModel: string
}
export type SessionMemoryBuildProgressStage =
| 'preparing'
| 'indexing_messages'
| 'building_messages'
| 'building_blocks'
| 'building_facts'
| 'completed'
export type SessionMemoryBuildProgressStatus =
| 'running'
| 'completed'
| 'failed'
export interface SessionMemoryBuildState {
sessionId: string
messageCount: number
blockCount: number
factCount: number
totalCount: number
processedCount: number
isRunning: boolean
updatedAt: number
completedAt?: number
lastError?: string
}
export interface SessionMemoryBuildProgressEvent {
sessionId: string
stage: SessionMemoryBuildProgressStage
status: SessionMemoryBuildProgressStatus
processedCount: number
totalCount: number
message: string
messageCount: number
blockCount: number
factCount: number
}
export interface EmbeddingModelProfile {
id: string
displayName: string
+13
View File
@@ -13,6 +13,8 @@ import type {
SessionQACancelResult,
SessionQAStartResult,
SessionQAResult,
SessionMemoryBuildProgressEvent,
SessionMemoryBuildState,
SessionVectorIndexProgressEvent,
SessionVectorIndexState,
SummaryResult,
@@ -1137,6 +1139,16 @@ export interface ElectronAPI {
result?: SessionVectorIndexState
error?: string
}>
getSessionMemoryBuildState: (sessionId: string) => Promise<{
success: boolean
result?: SessionMemoryBuildState
error?: string
}>
prepareSessionMemory: (options: { sessionId: string }) => Promise<{
success: boolean
result?: SessionMemoryBuildState
error?: string
}>
getEmbeddingModelProfiles: () => Promise<{
success: boolean
result?: EmbeddingModelProfile[]
@@ -1185,6 +1197,7 @@ export interface ElectronAPI {
onSessionQAEvent: (callback: (event: SessionQAJobEvent) => void) => () => void
onSessionQAConversationUpdated: (callback: (event: SessionQAConversationDetail) => void) => () => void
onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => () => void
onSessionMemoryBuildProgress: (callback: (event: SessionMemoryBuildProgressEvent) => void) => () => void
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => () => void
}