mirror of
https://github.com/ILoveBingLu/CipherTalk.git
synced 2026-05-13 15:45:04 +08:00
feat: 实现记忆证据服务与检索引擎
新增 EvidenceService,基于消息引用处理记忆证据扩展 新增 MemoryKeywordSearchOptions、MemoryKeywordSearchHit 类型,支持关键词检索功能 增强 MemoryDatabase,为记忆项添加 FTS 全文索引并实现关键词搜索方法 将记忆库结构版本升级至 2,适配新功能 创建 RetrievalEngine 检索引擎,统一管理关键词、近似最近邻(ANN)搜索 实现倒数排序融合(RRF)算法,合并多来源搜索结果 在 ChatPage 中新增记忆构建流程的 UI 组件与状态管理 扩展 Electron API,支持记忆构建状态获取与预处理
This commit is contained in:
@@ -44,6 +44,18 @@ type SessionVectorIndexProgressEvent = {
|
||||
vectorModel: string
|
||||
}
|
||||
|
||||
type SessionMemoryBuildProgressEvent = {
|
||||
sessionId: string
|
||||
stage: string
|
||||
status: string
|
||||
processedCount: number
|
||||
totalCount: number
|
||||
message: string
|
||||
messageCount: number
|
||||
blockCount: number
|
||||
factCount: number
|
||||
}
|
||||
|
||||
type EmbeddingModelDownloadProgress = {
|
||||
profileId: string
|
||||
displayName: string
|
||||
@@ -603,6 +615,8 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||
getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId),
|
||||
prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options),
|
||||
cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId),
|
||||
getSessionMemoryBuildState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionMemoryBuildState', sessionId),
|
||||
prepareSessionMemory: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionMemory', options),
|
||||
getEmbeddingModelProfiles: () => ipcRenderer.invoke('ai:getEmbeddingModelProfiles'),
|
||||
setEmbeddingModelProfile: (profileId: string) => ipcRenderer.invoke('ai:setEmbeddingModelProfile', profileId),
|
||||
getEmbeddingDeviceStatus: () => ipcRenderer.invoke('ai:getEmbeddingDeviceStatus'),
|
||||
@@ -637,6 +651,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||
ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event))
|
||||
return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress')
|
||||
},
|
||||
onSessionMemoryBuildProgress: (callback: (event: SessionMemoryBuildProgressEvent) => void) => {
|
||||
ipcRenderer.on('ai:sessionMemoryBuildProgress', (_, event) => callback(event))
|
||||
return () => ipcRenderer.removeAllListeners('ai:sessionMemoryBuildProgress')
|
||||
},
|
||||
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => {
|
||||
ipcRenderer.on('ai:embeddingModelDownloadProgress', (_, event) => callback(event))
|
||||
return () => ipcRenderer.removeAllListeners('ai:embeddingModelDownloadProgress')
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import { chatService } from '../chatService'
|
||||
import type { Message } from '../chatService'
|
||||
import type { MemoryEvidenceRef, MemoryItem } from './memorySchema'
|
||||
import type { RetrievalExpandedEvidence } from '../retrieval/retrievalTypes'
|
||||
|
||||
function compareRefAsc(a: MemoryEvidenceRef, b: MemoryEvidenceRef): number {
|
||||
return Number(a.sortSeq || 0) - Number(b.sortSeq || 0)
|
||||
|| Number(a.createTime || 0) - Number(b.createTime || 0)
|
||||
|| Number(a.localId || 0) - Number(b.localId || 0)
|
||||
}
|
||||
|
||||
function messageKey(message: Pick<Message, 'localId' | 'createTime' | 'sortSeq'>): string {
|
||||
return `${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
|
||||
}
|
||||
|
||||
function uniqueMessages(messages: Message[]): Message[] {
|
||||
const seen = new Set<string>()
|
||||
const result: Message[] = []
|
||||
for (const message of messages) {
|
||||
const key = messageKey(message)
|
||||
if (seen.has(key)) continue
|
||||
seen.add(key)
|
||||
result.push(message)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Loads the anchor message an evidence ref points at via chatService.
// Returns null when the lookup fails or the message no longer exists.
// Best-effort: any thrown error also degrades to null.
async function loadAnchor(ref: MemoryEvidenceRef): Promise<Message | null> {
  try {
    const result = await chatService.getMessageByLocalId(ref.sessionId, ref.localId)
    if (!result.success || !result.message) return null
    // When the ref carries a positive createTime that disagrees with the loaded
    // message, overwrite the message's timing fields with the ref's values
    // (presumably guarding against localId reuse — TODO confirm with chatService).
    if (Number(ref.createTime || 0) > 0 && Number(result.message.createTime || 0) !== Number(ref.createTime)) {
      return {
        ...result.message,
        createTime: ref.createTime,
        sortSeq: ref.sortSeq
      }
    }
    return result.message
  } catch {
    // Swallow lookup errors: a missing anchor is an acceptable outcome here.
    return null
  }
}
|
||||
|
||||
// Expands a single evidence ref into a context window: up to `radius`
// messages before and after the anchor message, all fetched concurrently.
// Failed before/after lookups degrade to empty lists; a failed anchor
// lookup yields anchor = null.
async function expandAroundRef(ref: MemoryEvidenceRef, radius: number): Promise<RetrievalExpandedEvidence> {
  const [beforeResult, anchor, afterResult] = await Promise.all([
    chatService.getMessagesBefore(ref.sessionId, ref.sortSeq, radius, ref.createTime, ref.localId),
    loadAnchor(ref),
    chatService.getMessagesAfter(ref.sessionId, ref.sortSeq, radius, ref.createTime, ref.localId)
  ])

  return {
    ref,
    // uniqueMessages guards against the same message appearing twice in a page.
    before: uniqueMessages(beforeResult.success ? beforeResult.messages || [] : []),
    anchor,
    after: uniqueMessages(afterResult.success ? afterResult.messages || [] : [])
  }
}
|
||||
|
||||
// Expands a memory item's evidence refs into surrounding chat-message
// context. The expansion strategy depends on the memory's sourceType.
export class EvidenceService {
  // Returns one expanded-evidence window per (limited) evidence ref:
  // - 'conversation_block': a single window padded 3 messages before the
  //   first ref and 3 after the last; anchor stays null when the block spans
  //   more than one message.
  // - 'message': one ref, radius 6.
  // - 'fact': up to 3 refs, radius 3 each; any other type: 1 ref, radius 3.
  async expandMemoryEvidence(memory: MemoryItem): Promise<RetrievalExpandedEvidence[]> {
    // Copy before sorting so the memory item itself is never mutated.
    const refs = [...memory.sourceRefs].sort(compareRefAsc)
    if (refs.length === 0) return []

    if (memory.sourceType === 'conversation_block') {
      const first = refs[0]
      const last = refs[refs.length - 1]
      // Degenerate block (single message): behave like a point expansion.
      if (messageKey(first) === messageKey(last)) {
        return [await expandAroundRef(first, 3)]
      }

      // Pad around the block's outer edges; the block interior is not
      // re-fetched here (anchor is intentionally null).
      const [before, after] = await Promise.all([
        chatService.getMessagesBefore(first.sessionId, first.sortSeq, 3, first.createTime, first.localId),
        chatService.getMessagesAfter(last.sessionId, last.sortSeq, 3, last.createTime, last.localId)
      ])
      return [{
        ref: first,
        before: uniqueMessages(before.success ? before.messages || [] : []),
        anchor: null,
        after: uniqueMessages(after.success ? after.messages || [] : [])
      }]
    }

    // Point expansions: single messages get a wider window than facts.
    const radius = memory.sourceType === 'message' ? 6 : 3
    const limitedRefs = refs.slice(0, memory.sourceType === 'fact' ? 3 : 1)
    return Promise.all(limitedRefs.map((ref) => expandAroundRef(ref, radius)))
  }
}

export const evidenceService = new EvidenceService()
|
||||
@@ -20,6 +20,22 @@ import {
|
||||
type MemoryVectorStoreName
|
||||
} from './memorySchema'
|
||||
|
||||
// Options accepted by MemoryDatabase.searchMemoryItemsByKeyword.
export type MemoryKeywordSearchOptions = {
  query: string                     // raw query text; blank queries produce no hits
  sessionId?: string                // restrict results to one chat session
  sourceTypes?: MemorySourceType[]  // restrict to specific memory source types
  startTimeMs?: number              // lower time bound (ms or s; normalized by the filter builder)
  endTimeMs?: number                // upper time bound
  limit?: number                    // max hits; the search clamps this to 1..500
}

// One row of a keyword search result.
export type MemoryKeywordSearchHit = {
  item: MemoryItem
  rank: number   // 1-based rank within the final merged list
  score: number  // heuristic score: FTS hits score above 1000, LIKE fallback hits score 500
  retrievalSource: 'memory_fts' | 'memory_like'
}

// Current wall-clock time in milliseconds.
function nowMs(): number {
  return Date.now()
}
|
||||
@@ -101,6 +117,67 @@ function parseEvidenceRefsJson(value: string): MemoryEvidenceRef[] {
|
||||
}
|
||||
}
|
||||
|
||||
function toTimestampSeconds(value?: number): number | null {
|
||||
if (!Number.isFinite(Number(value)) || Number(value) <= 0) return null
|
||||
const numberValue = Number(value)
|
||||
return numberValue > 10_000_000_000 ? Math.floor(numberValue / 1000) : Math.floor(numberValue)
|
||||
}
|
||||
|
||||
function escapeFtsPhrase(value: string): string {
|
||||
return `"${String(value || '').replace(/"/g, '""')}"`
|
||||
}
|
||||
|
||||
function buildMemoryFtsQuery(query: string): string {
|
||||
const normalized = String(query || '')
|
||||
.replace(/[\u200b-\u200f\ufeff]/g, '')
|
||||
.replace(/[,。!?;:、“”‘’()()[\]{}<>《》|\\/+=*_~`#$%^&-]+/g, ' ')
|
||||
.replace(/\s+/g, ' ')
|
||||
.trim()
|
||||
if (!normalized) return ''
|
||||
|
||||
const terms = normalized
|
||||
.split(/\s+/)
|
||||
.map((term) => term.trim())
|
||||
.filter(Boolean)
|
||||
return terms.length > 1
|
||||
? terms.map(escapeFtsPhrase).join(' AND ')
|
||||
: escapeFtsPhrase(normalized)
|
||||
}
|
||||
|
||||
// Builds the optional filter tail for memory_items queries.
// NOTE: mutates `params` in place, adding the named bind values used by the
// returned SQL fragment. The fragment starts with 'AND', so it must be
// appended after an existing WHERE clause (or after a MATCH condition).
function buildMemoryFilterSql(
  options: Pick<MemoryKeywordSearchOptions, 'sessionId' | 'sourceTypes' | 'startTimeMs' | 'endTimeMs'>,
  params: Record<string, unknown>
): string {
  const clauses: string[] = []
  if (options.sessionId) {
    clauses.push('m.session_id = @sessionId')
    params.sessionId = options.sessionId
  }

  // Keep only known source types, deduplicated, and bind each one by index.
  const sourceTypes = Array.from(new Set((options.sourceTypes || []).filter((type) => MEMORY_SOURCE_TYPES.includes(type))))
  if (sourceTypes.length > 0) {
    const placeholders = sourceTypes.map((_, index) => `@sourceType${index}`)
    sourceTypes.forEach((sourceType, index) => {
      params[`sourceType${index}`] = sourceType
    })
    clauses.push(`m.source_type IN (${placeholders.join(', ')})`)
  }

  // Time bounds are normalized to seconds; an item matches the range when its
  // own [time_start, time_end] span overlaps the requested window.
  const startTime = toTimestampSeconds(options.startTimeMs)
  if (startTime) {
    clauses.push('COALESCE(m.time_end, m.time_start, 0) >= @startTime')
    params.startTime = startTime
  }

  const endTime = toTimestampSeconds(options.endTimeMs)
  if (endTime) {
    clauses.push('COALESCE(m.time_start, m.time_end, 0) <= @endTime')
    params.endTime = endTime
  }

  return clauses.length ? `AND ${clauses.join(' AND ')}` : ''
}
|
||||
|
||||
function safeSourceType(value: string): MemorySourceType {
|
||||
return MEMORY_SOURCE_TYPES.includes(value as MemorySourceType)
|
||||
? value as MemorySourceType
|
||||
@@ -265,12 +342,55 @@ export class MemoryDatabase {
|
||||
ON memory_embeddings(vector_store, vector_ref);
|
||||
`)
|
||||
|
||||
db.exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memory_items_fts USING fts5(
|
||||
title,
|
||||
content,
|
||||
entities,
|
||||
tags,
|
||||
tokenize = 'unicode61 remove_diacritics 2'
|
||||
);
|
||||
`)
|
||||
this.syncMemoryFtsIndex(db)
|
||||
|
||||
db.prepare(`
|
||||
INSERT OR REPLACE INTO memory_meta(key, value, updated_at)
|
||||
VALUES ('schema_version', ?, ?)
|
||||
`).run(MEMORY_SCHEMA_VERSION, nowMs())
|
||||
}
|
||||
|
||||
// Backfills the memory_items_fts mirror for any memory_items rows that have
// no FTS row yet (e.g. rows created before the FTS table was introduced).
private syncMemoryFtsIndex(db: Database.Database): void {
  // Cheap existence probe: count rows missing from the FTS table.
  const row = db.prepare(`
    SELECT COUNT(*) AS count
    FROM memory_items m
    LEFT JOIN memory_items_fts f ON f.rowid = m.id
    WHERE f.rowid IS NULL
  `).get() as { count: number } | undefined
  // Nothing missing — skip the backfill insert entirely.
  if (!Number(row?.count || 0)) return

  db.prepare(`
    INSERT INTO memory_items_fts(rowid, title, content, entities, tags)
    SELECT id, title, content, entities_json, tags_json
    FROM memory_items
    WHERE id NOT IN (SELECT rowid FROM memory_items_fts)
  `).run()
}
|
||||
|
||||
// Mirrors one memory item into the FTS table, keyed by rowid = item.id.
// Delete-then-insert acts as an upsert for the fts5 virtual table.
private upsertMemoryFtsRow(item: MemoryItem): void {
  const db = this.getDb()
  db.prepare('DELETE FROM memory_items_fts WHERE rowid = ?').run(item.id)
  db.prepare(`
    INSERT INTO memory_items_fts(rowid, title, content, entities, tags)
    VALUES (@id, @title, @content, @entities, @tags)
  `).run({
    id: item.id,
    title: item.title,
    content: item.content,
    // Arrays are stored as JSON strings so FTS can match their contents.
    entities: safeJsonStringify(item.entities || [], []),
    tags: safeJsonStringify(item.tags || [], [])
  })
}
|
||||
|
||||
upsertMemoryItem(input: MemoryItemInput): MemoryItem {
|
||||
const db = this.getDb()
|
||||
const timestamp = nowMs()
|
||||
@@ -338,6 +458,7 @@ export class MemoryDatabase {
|
||||
|
||||
const item = this.getMemoryItemByUid(memoryUid)
|
||||
if (!item) throw new Error('Failed to load upserted memory item')
|
||||
this.upsertMemoryFtsRow(item)
|
||||
return item
|
||||
}
|
||||
|
||||
@@ -408,8 +529,80 @@ export class MemoryDatabase {
|
||||
return Number(row?.count || 0)
|
||||
}
|
||||
|
||||
// Keyword search over memory items: an FTS5 pass (bm25-ranked) merged with a
// LIKE fallback pass that catches items FTS tokenization misses. FTS hits
// always outrank LIKE hits via their score offset (1000+ vs 500).
searchMemoryItemsByKeyword(options: MemoryKeywordSearchOptions): MemoryKeywordSearchHit[] {
  const query = String(options.query || '').trim()
  if (!query) return []

  const db = this.getDb()
  // Clamp the caller's limit into 1..500 (default 80).
  const limit = Math.max(1, Math.min(Math.floor(options.limit || 80), 500))
  const rowsById = new Map<number, MemoryKeywordSearchHit>()
  const params: Record<string, unknown> = { limit }
  // buildMemoryFilterSql adds its bind values into `params` as a side effect.
  const filterSql = buildMemoryFilterSql(options, params)
  const ftsQuery = buildMemoryFtsQuery(query)

  if (ftsQuery) {
    // Pass 1: FTS5 MATCH. Lower bm25() is a better match, hence ASC order;
    // ties fall back to recency and then id.
    const ftsRows = db.prepare(`
      SELECT m.*, bm25(memory_items_fts) AS fts_rank
      FROM memory_items_fts
      JOIN memory_items m ON m.id = memory_items_fts.rowid
      WHERE memory_items_fts MATCH @ftsQuery
      ${filterSql}
      ORDER BY fts_rank ASC, COALESCE(m.time_end, m.time_start, m.updated_at) DESC, m.id DESC
      LIMIT @limit
    `).all({
      ...params,
      ftsQuery
    }) as Array<MemoryItemRow & { fts_rank: number }>

    ftsRows.forEach((row, index) => {
      rowsById.set(Number(row.id), {
        item: toMemoryItem(row),
        rank: index + 1,
        // Map bm25 rank into a descending score above 1000 so FTS hits
        // always sort ahead of LIKE hits (fixed score 500).
        score: Number((1000 + Math.max(0, 100 - Number(row.fts_rank || 0))).toFixed(4)),
        retrievalSource: 'memory_fts'
      })
    })
  }

  // Pass 2: substring LIKE over title/content and the JSON-encoded arrays.
  // Uses a fresh params object so pass 1's bind values are not disturbed.
  const likeParams: Record<string, unknown> = { ...params, likeQuery: `%${query}%` }
  const likeFilterSql = buildMemoryFilterSql(options, likeParams)
  const likeRows = db.prepare(`
    SELECT m.*
    FROM memory_items m
    WHERE (
      m.title LIKE @likeQuery
      OR m.content LIKE @likeQuery
      OR m.entities_json LIKE @likeQuery
      OR m.tags_json LIKE @likeQuery
    )
    ${likeFilterSql}
    ORDER BY COALESCE(m.time_end, m.time_start, m.updated_at) DESC, m.id DESC
    LIMIT @limit
  `).all(likeParams) as MemoryItemRow[]

  // Only add LIKE hits that FTS did not already find.
  let likeRank = 1
  for (const row of likeRows) {
    const id = Number(row.id)
    if (rowsById.has(id)) continue
    rowsById.set(id, {
      item: toMemoryItem(row),
      rank: likeRank,
      score: 500,
      retrievalSource: 'memory_like'
    })
    likeRank += 1
  }

  // Final ordering: score, then importance, then recency; ranks re-assigned
  // 1-based over the merged, truncated list.
  return Array.from(rowsById.values())
    .sort((a, b) => b.score - a.score || b.item.importance - a.item.importance || b.item.updatedAt - a.item.updatedAt)
    .slice(0, limit)
    .map((hit, index) => ({ ...hit, rank: index + 1 }))
}
|
||||
|
||||
deleteMemoryItem(id: number): boolean {
|
||||
const result = this.getDb().prepare('DELETE FROM memory_items WHERE id = ?').run(id)
|
||||
const db = this.getDb()
|
||||
db.prepare('DELETE FROM memory_items_fts WHERE rowid = ?').run(id)
|
||||
const result = db.prepare('DELETE FROM memory_items WHERE id = ?').run(id)
|
||||
return result.changes > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
export const MEMORY_DB_NAME = 'agent_memory.db'
|
||||
export const MEMORY_SCHEMA_VERSION = '1'
|
||||
export const MEMORY_SCHEMA_VERSION = '2'
|
||||
|
||||
export const MEMORY_SOURCE_TYPES = [
|
||||
'message',
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
import { memoryDatabase } from '../memory/memoryDatabase'
|
||||
import type { MemoryKeywordSearchHit } from '../memory/memoryDatabase'
|
||||
import { evidenceService } from '../memory/evidenceService'
|
||||
import type { MemoryItem } from '../memory/memorySchema'
|
||||
import { chatSearchIndexService } from '../search/chatSearchIndexService'
|
||||
import { localRerankerService, type RerankDocument } from './rerankerService'
|
||||
import { reciprocalRankFusion } from './rrf'
|
||||
import type {
|
||||
RetrievalCandidate,
|
||||
RetrievalEngineOptions,
|
||||
RetrievalEngineResult,
|
||||
RetrievalHit,
|
||||
RetrievalRerankStats,
|
||||
RetrievalSourceName,
|
||||
RetrievalSourceStats
|
||||
} from './retrievalTypes'
|
||||
|
||||
type SourceHit = {
|
||||
source: RetrievalSourceName
|
||||
memory: MemoryItem
|
||||
rank: number
|
||||
score: number
|
||||
}
|
||||
|
||||
const DEFAULT_LIMIT = 20
|
||||
const DEFAULT_KEYWORD_LIMIT = 80
|
||||
const DEFAULT_ANN_LIMIT = 80
|
||||
const DEFAULT_RERANK_LIMIT = 120
|
||||
const DEFAULT_RRF_K = 60
|
||||
|
||||
function compactText(value: string, limit: number): string {
|
||||
const normalized = String(value || '').replace(/\s+/g, ' ').trim()
|
||||
if (!normalized) return ''
|
||||
return normalized.length > limit ? `${normalized.slice(0, limit - 1)}...` : normalized
|
||||
}
|
||||
|
||||
function uniqueQueries(values: string[]): string[] {
|
||||
const seen = new Set<string>()
|
||||
const result: string[] = []
|
||||
for (const value of values) {
|
||||
const query = String(value || '').replace(/\s+/g, ' ').trim()
|
||||
const key = query.toLowerCase()
|
||||
if (!query || seen.has(key)) continue
|
||||
seen.add(key)
|
||||
result.push(query)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
function memoryKey(memory: MemoryItem): string {
|
||||
return String(memory.id)
|
||||
}
|
||||
|
||||
function messageMemoryUid(sessionId: string, message: { localId: number; createTime: number; sortSeq: number }): string {
|
||||
return `message:${sessionId}:${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
|
||||
}
|
||||
|
||||
// Builds the plain-text document handed to the local reranker for one
// candidate: memory type/title/time/content plus up to three evidence
// excerpts, compacted to at most 4000 characters.
function buildRerankDocument(candidate: RetrievalCandidate): RerankDocument {
  const memory = candidate.memory
  // Up to three evidence lines: "<sender> <createTime>: <excerpt>".
  const sourceRefs = memory.sourceRefs
    .slice(0, 3)
    .map((ref) => `${ref.senderUsername || 'unknown'} ${ref.createTime}: ${ref.excerpt || ''}`)
    .filter(Boolean)
    .join('\n')
  // Assemble labeled sections, skipping empty ones via filter(Boolean).
  const text = [
    `type: ${memory.sourceType}`,
    memory.title ? `title: ${memory.title}` : '',
    memory.timeStart || memory.timeEnd ? `time: ${memory.timeStart || ''}-${memory.timeEnd || ''}` : '',
    `content: ${memory.content}`,
    sourceRefs ? `evidence:\n${sourceRefs}` : ''
  ].filter(Boolean).join('\n')

  return {
    id: candidate.key,
    text: compactText(text, 4000),
    // RRF score is preserved so the reranker can blend it into combinedScore.
    originalScore: candidate.rrfScore,
    metadata: {
      memoryId: memory.id,
      sourceType: memory.sourceType,
      sources: candidate.sources
    }
  }
}
||||
|
||||
function toCandidate(hit: SourceHit): RetrievalCandidate {
|
||||
return {
|
||||
key: memoryKey(hit.memory),
|
||||
memory: hit.memory,
|
||||
sources: [hit.source],
|
||||
sourceRanks: { [hit.source]: hit.rank },
|
||||
sourceScores: { [hit.source]: hit.score },
|
||||
rrfScore: 0
|
||||
}
|
||||
}
|
||||
|
||||
function mergeSourceDetails(candidate: RetrievalCandidate, hits: SourceHit[]): RetrievalCandidate {
|
||||
const sources: RetrievalSourceName[] = []
|
||||
const sourceRanks: RetrievalCandidate['sourceRanks'] = {}
|
||||
const sourceScores: RetrievalCandidate['sourceScores'] = {}
|
||||
|
||||
for (const hit of hits) {
|
||||
if (!sources.includes(hit.source)) sources.push(hit.source)
|
||||
sourceRanks[hit.source] = Math.min(sourceRanks[hit.source] || Number.MAX_SAFE_INTEGER, hit.rank)
|
||||
sourceScores[hit.source] = Math.max(sourceScores[hit.source] || 0, hit.score)
|
||||
}
|
||||
|
||||
return {
|
||||
...candidate,
|
||||
sources,
|
||||
sourceRanks,
|
||||
sourceScores
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeLimit(value: unknown, fallback: number, max: number): number {
|
||||
const numberValue = Math.floor(Number(value || fallback))
|
||||
return Math.max(1, Math.min(Number.isFinite(numberValue) ? numberValue : fallback, max))
|
||||
}
|
||||
|
||||
// Hybrid memory retrieval engine: fans a query out to keyword (FTS + LIKE)
// and message-ANN sources, fuses the per-source rankings with reciprocal
// rank fusion, optionally reranks with the local reranker, and finally
// expands each selected memory into surrounding chat-message evidence.
export class RetrievalEngine {
  // Runs the full retrieval pipeline for one query. Per-source failures are
  // recorded in sourceStats / rerank stats instead of being thrown, so the
  // remaining sources still contribute.
  async search(options: RetrievalEngineOptions): Promise<RetrievalEngineResult> {
    const startedAt = Date.now()
    const query = String(options.query || '').trim()
    if (!query) {
      // Empty query: return an empty but well-formed result.
      return {
        query,
        semanticQuery: '',
        hits: [],
        sourceStats: [],
        rerank: { attempted: false, applied: false, skippedReason: 'empty_query' },
        latencyMs: Date.now() - startedAt
      }
    }

    // Clamp all caller-supplied limits into sane ranges.
    const limit = normalizeLimit(options.limit, DEFAULT_LIMIT, 100)
    const keywordLimit = normalizeLimit(options.keywordLimit, DEFAULT_KEYWORD_LIMIT, 500)
    const annLimit = normalizeLimit(options.annLimit, DEFAULT_ANN_LIMIT, 500)
    const rerankLimit = normalizeLimit(options.rerankLimit, DEFAULT_RERANK_LIMIT, 500)
    // The primary query always participates; extra variants are deduplicated.
    const keywordQueries = uniqueQueries([query, ...(options.keywordQueries || [])])
    const semanticQueries = uniqueQueries([
      options.semanticQuery || query,
      ...(options.semanticQueries || [])
    ])
    const semanticQuery = semanticQueries[0] || query
    const sourceStats: RetrievalSourceStats[] = []

    const keywordHits = this.collectKeywordHits(options, keywordQueries, keywordLimit, sourceStats)
    const annHits = await this.collectAnnHits(options, semanticQueries, annLimit, sourceStats)
    const candidates = this.fuseCandidates([...keywordHits, ...annHits], options.rrfK)
    const rerankStats: RetrievalRerankStats = { attempted: false, applied: false }
    // Rerank and evidence expansion both default to ON; callers opt out with
    // an explicit `false`.
    const ranked = await this.applyRerank(query, semanticQuery, candidates, rerankLimit, rerankStats, options.rerank !== false)
    const selected = ranked.slice(0, limit)
    const hits = await this.expandHits(selected, options.expandEvidence !== false)

    return {
      query,
      semanticQuery,
      hits,
      sourceStats,
      rerank: rerankStats,
      latencyMs: Date.now() - startedAt
    }
  }

  // Runs every keyword query variant against the memory keyword search and
  // collects all hits. A failing variant only records its error; the other
  // variants still run.
  private collectKeywordHits(
    options: RetrievalEngineOptions,
    queries: string[],
    limit: number,
    sourceStats: RetrievalSourceStats[]
  ): SourceHit[] {
    const hits: SourceHit[] = []
    let error: string | undefined

    for (const query of queries) {
      try {
        const rows = memoryDatabase.searchMemoryItemsByKeyword({
          query,
          sessionId: options.sessionId,
          sourceTypes: options.sourceTypes,
          startTimeMs: options.startTimeMs,
          endTimeMs: options.endTimeMs,
          limit
        })
        hits.push(...rows.map((row) => this.keywordRowToSourceHit(row)))
      } catch (searchError) {
        // Only the last error is kept; it is attached to both stat entries.
        error = String(searchError)
      }
    }

    // Keyword search fans out into two logical sources (FTS and LIKE), so a
    // stats entry is reported for each.
    const ftsCount = hits.filter((hit) => hit.source === 'memory_fts').length
    const likeCount = hits.filter((hit) => hit.source === 'memory_like').length
    sourceStats.push({ name: 'memory_fts', attempted: true, hitCount: ftsCount, ...(error ? { error } : {}) })
    sourceStats.push({ name: 'memory_like', attempted: true, hitCount: likeCount, ...(error ? { error } : {}) })
    return this.dedupeSourceHits(hits)
  }

  // Adapts a keyword search row to the engine's internal SourceHit shape.
  private keywordRowToSourceHit(row: MemoryKeywordSearchHit): SourceHit {
    return {
      source: row.retrievalSource,
      memory: row.item,
      rank: row.rank,
      score: row.score
    }
  }

  // Semantic (vector ANN) retrieval over the session's message index, mapping
  // each matched message back to its memory item. Skips cleanly — with a
  // recorded reason — when no session is given or the vector index isn't ready.
  private async collectAnnHits(
    options: RetrievalEngineOptions,
    queries: string[],
    limit: number,
    sourceStats: RetrievalSourceStats[]
  ): Promise<SourceHit[]> {
    if (!options.sessionId) {
      sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'session_required' })
      return []
    }

    const vectorState = chatSearchIndexService.getSessionVectorIndexState(options.sessionId)
    if (!vectorState.vectorProviderAvailable) {
      sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'vector_provider_unavailable' })
      return []
    }
    if (!vectorState.isVectorComplete) {
      sourceStats.push({ name: 'message_ann', attempted: false, hitCount: 0, skippedReason: 'vector_index_incomplete' })
      return []
    }

    const hits: SourceHit[] = []
    let error: string | undefined
    for (const query of queries) {
      try {
        const result = await chatSearchIndexService.searchSessionByVector({
          sessionId: options.sessionId,
          query,
          limit,
          startTimeMs: options.startTimeMs,
          endTimeMs: options.endTimeMs,
          direction: options.direction,
          senderUsername: options.senderUsername
        })
        result.hits.forEach((hit, index) => {
          // ANN hits are raw messages; keep only those that were ingested
          // into the memory store (looked up by deterministic memory UID).
          const uid = messageMemoryUid(hit.sessionId, hit.message)
          const memory = memoryDatabase.getMemoryItemByUid(uid)
          if (!memory) return
          hits.push({
            source: 'message_ann',
            memory,
            rank: index + 1,
            score: hit.score
          })
        })
      } catch (searchError) {
        error = String(searchError)
      }
    }

    sourceStats.push({
      name: 'message_ann',
      attempted: true,
      hitCount: hits.length,
      ...(error ? { error } : {})
    })
    return this.dedupeSourceHits(hits)
  }

  // Collapses duplicate (source, memory) pairs, keeping the entry with the
  // better rank or score, then orders by rank ascending (score breaks ties).
  private dedupeSourceHits(hits: SourceHit[]): SourceHit[] {
    const byKey = new Map<string, SourceHit>()
    for (const hit of hits) {
      const key = `${hit.source}:${memoryKey(hit.memory)}`
      const existing = byKey.get(key)
      if (!existing || hit.rank < existing.rank || hit.score > existing.score) {
        byKey.set(key, hit)
      }
    }
    return Array.from(byKey.values()).sort((a, b) => a.rank - b.rank || b.score - a.score)
  }

  // Merges the per-source ranked lists into one candidate list via
  // reciprocal rank fusion, then re-attaches per-source rank/score details.
  private fuseCandidates(sourceHits: SourceHit[], rrfK?: number): RetrievalCandidate[] {
    const hitsByMemory = new Map<string, SourceHit[]>()
    const listsBySource = new Map<RetrievalSourceName, SourceHit[]>()

    // Index hits two ways: by memory (for detail merging) and by source
    // (to form the ranked lists RRF consumes).
    for (const hit of sourceHits) {
      const key = memoryKey(hit.memory)
      const grouped = hitsByMemory.get(key) || []
      grouped.push(hit)
      hitsByMemory.set(key, grouped)

      const list = listsBySource.get(hit.source) || []
      list.push(hit)
      listsBySource.set(hit.source, list)
    }

    const fused = reciprocalRankFusion(
      Array.from(listsBySource.values()).map((list) => list
        .sort((a, b) => a.rank - b.rank || b.score - a.score)
        .map((hit, index) => ({ item: hit, rank: hit.rank || index + 1, score: hit.score }))),
      (hit) => memoryKey(hit.memory),
      rrfK || DEFAULT_RRF_K
    )

    return fused.map((item) => {
      const candidate = toCandidate(item.item)
      candidate.rrfScore = Number(item.rrfScore.toFixed(8))
      return mergeSourceDetails(candidate, hitsByMemory.get(item.key) || [item.item])
    })
  }

  // Optionally reranks the top `limit` fused candidates with the local
  // reranker. Candidates outside the rerank window (or dropped by the
  // reranker) keep their RRF order and are appended after the reranked ones.
  // Any reranker failure falls back to the unmodified RRF ordering.
  private async applyRerank(
    query: string,
    semanticQuery: string,
    candidates: RetrievalCandidate[],
    limit: number,
    stats: RetrievalRerankStats,
    enabled: boolean
  ): Promise<RetrievalCandidate[]> {
    if (!enabled) {
      stats.skippedReason = 'disabled'
      return candidates
    }
    if (candidates.length === 0) {
      stats.skippedReason = 'no_candidates'
      return candidates
    }
    if (!localRerankerService.isEnabled()) {
      stats.skippedReason = 'config_disabled'
      return candidates
    }

    stats.attempted = true
    const rerankInput = candidates.slice(0, limit)
    try {
      const reranked = await localRerankerService.rerank(
        [query, semanticQuery].filter(Boolean).join('\n'),
        rerankInput.map(buildRerankDocument),
        { limit }
      )
      const byKey = new Map(candidates.map((candidate) => [candidate.key, candidate]))
      const rerankedKeys = new Set<string>()
      const rankedCandidates = reranked
        .map((result) => {
          const candidate = byKey.get(result.id)
          if (!candidate) return null
          rerankedKeys.add(result.id)
          return {
            ...candidate,
            rrfScore: candidate.rrfScore,
            rerankScore: result.rerankScore,
            finalScore: result.combinedScore
          }
        })
        .filter((item): item is RetrievalCandidate & { rerankScore: number; finalScore: number } => Boolean(item))
        .sort((a, b) => b.finalScore - a.finalScore)

      stats.applied = rankedCandidates.length > 0
      // Preserve original order for everything the reranker did not score.
      const rest = candidates.filter((candidate) => !rerankedKeys.has(candidate.key))
      return [...rankedCandidates, ...rest]
    } catch (error) {
      stats.error = String(error)
      stats.skippedReason = 'rerank_failed'
      return candidates
    }
  }

  // Turns the final candidate list into RetrievalHits: assigns 1-based ranks,
  // picks the best available score (final > rerank > RRF), and optionally
  // expands each memory's evidence refs into surrounding chat messages.
  // Expansion is sequential — presumably to avoid hammering chatService with
  // parallel lookups (TODO confirm).
  private async expandHits(candidates: RetrievalCandidate[], expandEvidence: boolean): Promise<RetrievalHit[]> {
    const hits: RetrievalHit[] = []
    for (let index = 0; index < candidates.length; index += 1) {
      const candidate = candidates[index] as RetrievalCandidate & { rerankScore?: number; finalScore?: number }
      const evidence = expandEvidence ? await evidenceService.expandMemoryEvidence(candidate.memory) : []
      hits.push({
        ...candidate,
        rank: index + 1,
        score: Number((candidate.finalScore ?? candidate.rerankScore ?? candidate.rrfScore).toFixed(8)),
        ...(candidate.rerankScore != null ? { rerankScore: candidate.rerankScore } : {}),
        evidence
      })
    }
    return hits
  }
}

export const retrievalEngine = new RetrievalEngine()
|
||||
@@ -1,3 +1,75 @@
|
||||
import type { Message } from '../chatService'
|
||||
import type { MemoryEvidenceRef, MemoryItem, MemorySourceType } from '../memory/memorySchema'
|
||||
|
||||
export type RetrievalSourceName = 'memory_fts' | 'memory_like' | 'message_ann'
|
||||
|
||||
export type RetrievalEngineOptions = {
|
||||
query: string
|
||||
semanticQuery?: string
|
||||
keywordQueries?: string[]
|
||||
semanticQueries?: string[]
|
||||
sessionId?: string
|
||||
sourceTypes?: MemorySourceType[]
|
||||
startTimeMs?: number
|
||||
endTimeMs?: number
|
||||
direction?: 'in' | 'out'
|
||||
senderUsername?: string
|
||||
limit?: number
|
||||
keywordLimit?: number
|
||||
annLimit?: number
|
||||
rrfK?: number
|
||||
rerank?: boolean
|
||||
rerankLimit?: number
|
||||
expandEvidence?: boolean
|
||||
}
|
||||
|
||||
export type RetrievalCandidate = {
|
||||
key: string
|
||||
memory: MemoryItem
|
||||
sources: RetrievalSourceName[]
|
||||
sourceRanks: Partial<Record<RetrievalSourceName, number>>
|
||||
sourceScores: Partial<Record<RetrievalSourceName, number>>
|
||||
rrfScore: number
|
||||
}
|
||||
|
||||
export type RetrievalExpandedEvidence = {
|
||||
ref: MemoryEvidenceRef
|
||||
before: Message[]
|
||||
anchor: Message | null
|
||||
after: Message[]
|
||||
}
|
||||
|
||||
export type RetrievalHit = RetrievalCandidate & {
|
||||
rank: number
|
||||
score: number
|
||||
rerankScore?: number
|
||||
evidence: RetrievalExpandedEvidence[]
|
||||
}
|
||||
|
||||
export type RetrievalSourceStats = {
|
||||
name: RetrievalSourceName
|
||||
attempted: boolean
|
||||
hitCount: number
|
||||
skippedReason?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
export type RetrievalRerankStats = {
|
||||
attempted: boolean
|
||||
applied: boolean
|
||||
skippedReason?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
export type RetrievalEngineResult = {
|
||||
query: string
|
||||
semanticQuery: string
|
||||
hits: RetrievalHit[]
|
||||
sourceStats: RetrievalSourceStats[]
|
||||
rerank: RetrievalRerankStats
|
||||
latencyMs: number
|
||||
}
|
||||
|
||||
export type RetrievalEvalEvidenceRef = {
|
||||
localId: number
|
||||
createTime: number
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
export type RrfRankedItem<T> = {
|
||||
item: T
|
||||
rank: number
|
||||
score?: number
|
||||
}
|
||||
|
||||
export type RrfMergedItem<T> = {
|
||||
key: string
|
||||
item: T
|
||||
rrfScore: number
|
||||
ranks: number[]
|
||||
scores: number[]
|
||||
}
|
||||
|
||||
export function reciprocalRankFusion<T>(
|
||||
rankedLists: Array<Array<RrfRankedItem<T>>>,
|
||||
getKey: (item: T) => string,
|
||||
k: number = 60
|
||||
): Array<RrfMergedItem<T>> {
|
||||
const safeK = Math.max(1, Math.floor(k || 60))
|
||||
const byKey = new Map<string, RrfMergedItem<T>>()
|
||||
|
||||
for (const list of rankedLists) {
|
||||
for (let index = 0; index < list.length; index += 1) {
|
||||
const entry = list[index]
|
||||
const rank = Math.max(1, Math.floor(entry.rank || index + 1))
|
||||
const key = getKey(entry.item)
|
||||
const existing = byKey.get(key)
|
||||
const contribution = 1 / (safeK + rank)
|
||||
|
||||
if (existing) {
|
||||
existing.rrfScore += contribution
|
||||
existing.ranks.push(rank)
|
||||
if (Number.isFinite(entry.score)) existing.scores.push(Number(entry.score))
|
||||
} else {
|
||||
byKey.set(key, {
|
||||
key,
|
||||
item: entry.item,
|
||||
rrfScore: contribution,
|
||||
ranks: [rank],
|
||||
scores: Number.isFinite(entry.score) ? [Number(entry.score)] : []
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(byKey.values())
|
||||
.sort((a, b) => b.rrfScore - a.rrfScore || Math.min(...a.ranks) - Math.min(...b.ranks))
|
||||
}
|
||||
+13
-2
@@ -341,7 +341,8 @@
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
&.vector-index-btn {
|
||||
&.vector-index-btn,
|
||||
&.memory-build-btn {
|
||||
position: relative;
|
||||
|
||||
&.complete {
|
||||
@@ -386,6 +387,11 @@
|
||||
background: #d97706;
|
||||
}
|
||||
}
|
||||
|
||||
&.memory-build-btn.complete {
|
||||
color: #0891b2;
|
||||
background: rgba(8, 145, 178, 0.12);
|
||||
}
|
||||
}
|
||||
|
||||
// 日期选择器包装器
|
||||
@@ -1194,7 +1200,8 @@
|
||||
}
|
||||
}
|
||||
|
||||
.vector-index-btn {
|
||||
.vector-index-btn,
|
||||
.memory-build-btn {
|
||||
overflow: visible;
|
||||
|
||||
&.complete {
|
||||
@@ -1236,6 +1243,10 @@
|
||||
background: #d97706;
|
||||
}
|
||||
}
|
||||
|
||||
.memory-build-btn.complete {
|
||||
color: #0891b2;
|
||||
}
|
||||
}
|
||||
|
||||
.message-list {
|
||||
|
||||
+151
-1
@@ -10,7 +10,12 @@ import { getImageXorKey, getImageAesKey, getQuoteStyle } from '../services/confi
|
||||
import { LRUCache } from '../utils/lruCache'
|
||||
import { LivePhotoIcon } from '../components/LivePhotoIcon'
|
||||
import type { ChatSession, Message } from '../types/models'
|
||||
import type { SessionVectorIndexProgressEvent, SessionVectorIndexState } from '../types/ai'
|
||||
import type {
|
||||
SessionMemoryBuildProgressEvent,
|
||||
SessionMemoryBuildState,
|
||||
SessionVectorIndexProgressEvent,
|
||||
SessionVectorIndexState
|
||||
} from '../types/ai'
|
||||
import { List, RowComponentProps } from 'react-window'
|
||||
import './ChatPage.scss'
|
||||
|
||||
@@ -358,6 +363,9 @@ function ChatPage(_props: ChatPageProps) {
|
||||
const [vectorIndexState, setVectorIndexState] = useState<SessionVectorIndexState | null>(null)
|
||||
const [vectorIndexProgress, setVectorIndexProgress] = useState<SessionVectorIndexProgressEvent | null>(null)
|
||||
const [isPreparingVectorIndex, setIsPreparingVectorIndex] = useState(false)
|
||||
const [memoryBuildState, setMemoryBuildState] = useState<SessionMemoryBuildState | null>(null)
|
||||
const [memoryBuildProgress, setMemoryBuildProgress] = useState<SessionMemoryBuildProgressEvent | null>(null)
|
||||
const [isPreparingMemoryBuild, setIsPreparingMemoryBuild] = useState(false)
|
||||
|
||||
const showTopToast = useCallback((text: string, success = true) => {
|
||||
setTopToast({ text, success })
|
||||
@@ -377,6 +385,19 @@ function ChatPage(_props: ChatPageProps) {
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Re-fetch the memory build state for a session from the main process and
// sync it into component state. Results are dropped if the user has switched
// to a different session while the request was in flight.
const refreshMemoryBuildState = useCallback(async (sessionId: string) => {
  try {
    const result = await window.electronAPI.ai.getSessionMemoryBuildState(sessionId)
    // Bail out if the active session changed during the await.
    if (currentSessionIdRef.current !== sessionId) return
    if (result.success && result.result) {
      setMemoryBuildState(result.result)
      setIsPreparingMemoryBuild(result.result.isRunning)
    }
  } catch (error) {
    console.error('获取会话记忆状态失败:', error)
  }
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
setVectorIndexState(null)
|
||||
@@ -401,6 +422,29 @@ function ChatPage(_props: ChatPageProps) {
|
||||
}
|
||||
}, [currentSessionId])
|
||||
|
||||
// Reset and (re)load the memory build state whenever the active session
// changes; the cancelled flag prevents stale async results from landing
// after the effect has been cleaned up.
useEffect(() => {
  let cancelled = false
  setMemoryBuildState(null)
  setMemoryBuildProgress(null)
  setIsPreparingMemoryBuild(false)

  if (!currentSessionId) return

  window.electronAPI.ai.getSessionMemoryBuildState(currentSessionId).then((result) => {
    // Ignore responses arriving after cleanup (session switched/unmounted).
    if (cancelled) return
    if (result.success && result.result) {
      setMemoryBuildState(result.result)
      setIsPreparingMemoryBuild(result.result.isRunning)
    }
  }).catch((error) => {
    if (!cancelled) console.error('加载会话记忆状态失败:', error)
  })

  return () => {
    cancelled = true
  }
}, [currentSessionId])
|
||||
|
||||
useEffect(() => {
|
||||
return window.electronAPI.ai.onSessionVectorIndexProgress((event) => {
|
||||
const activeSessionId = currentSessionIdRef.current
|
||||
@@ -447,6 +491,35 @@ function ChatPage(_props: ChatPageProps) {
|
||||
})
|
||||
}, [refreshVectorIndexState, showTopToast])
|
||||
|
||||
// Subscribe to memory build progress events from the main process; the
// subscription's unsubscribe function doubles as the effect cleanup.
useEffect(() => {
  return window.electronAPI.ai.onSessionMemoryBuildProgress((event) => {
    const activeSessionId = currentSessionIdRef.current
    // Only react to events belonging to the session currently on screen.
    if (!event || !activeSessionId || event.sessionId !== activeSessionId) return

    setMemoryBuildProgress(event)
    setIsPreparingMemoryBuild(event.status === 'running')
    // Project the streamed event into the locally cached build state,
    // preserving prior totals/errors where the event carries none.
    setMemoryBuildState((prev) => ({
      sessionId: event.sessionId,
      messageCount: event.messageCount,
      blockCount: event.blockCount,
      factCount: event.factCount,
      totalCount: event.totalCount || prev?.totalCount || 0,
      processedCount: event.processedCount,
      isRunning: event.status === 'running',
      updatedAt: Date.now(),
      completedAt: event.status === 'completed' ? Date.now() : prev?.completedAt,
      lastError: event.status === 'failed' ? event.message : prev?.lastError
    }))

    if (event.status === 'completed') {
      // Pull the authoritative final state once the build finishes.
      void refreshMemoryBuildState(event.sessionId)
    } else if (event.status === 'failed') {
      showTopToast(event.message || '会话记忆构建失败', false)
      void refreshMemoryBuildState(event.sessionId)
    }
  })
}, [refreshMemoryBuildState, showTopToast])
|
||||
|
||||
const vectorIndexTotal = vectorIndexProgress?.totalCount || vectorIndexState?.indexedCount || 0
|
||||
const vectorIndexDone = vectorIndexProgress?.processedCount ?? vectorIndexState?.vectorizedCount ?? 0
|
||||
const vectorIndexPercent = vectorIndexTotal > 0
|
||||
@@ -469,6 +542,23 @@ function ChatPage(_props: ChatPageProps) {
|
||||
? `增量向量化:待处理 ${vectorIndexState?.pendingCount || 0} 条`
|
||||
: '增量向量化当前聊天'
|
||||
|
||||
// Derived display values for the memory build button: live progress numbers
// take precedence over the cached state snapshot.
const memoryBuildTotal = memoryBuildProgress?.totalCount || memoryBuildState?.totalCount || 0
const memoryBuildDone = memoryBuildProgress?.processedCount ?? memoryBuildState?.processedCount ?? 0
// Percent complete, clamped to [0, 100]; a finished state with no live
// progress shows 100.
const memoryBuildPercent = memoryBuildTotal > 0
  ? Math.min(100, Math.max(0, Math.round((memoryBuildDone / memoryBuildTotal) * 100)))
  : memoryBuildState && memoryBuildState.totalCount > 0 ? 100 : 0
// Total memory items across all three layers (messages + blocks + facts).
const memoryBuildCount = (memoryBuildState?.messageCount || 0) + (memoryBuildState?.blockCount || 0) + (memoryBuildState?.factCount || 0)
// Badge shows a percentage while building, otherwise the item count (capped at "99+").
const memoryBuildBadgeLabel = isPreparingMemoryBuild
  ? `${memoryBuildPercent}%`
  : memoryBuildCount > 0
    ? memoryBuildCount > 99 ? '99+' : String(memoryBuildCount)
    : ''
// Tooltip text for the memory build button, varying by build phase.
const memoryButtonTitle = isPreparingMemoryBuild
  ? `正在构建会话记忆:${memoryBuildProgress?.message || `${memoryBuildDone}/${memoryBuildTotal}`}`
  : memoryBuildCount > 0
    ? `重建会话记忆:消息 ${memoryBuildState?.messageCount || 0},片段 ${memoryBuildState?.blockCount || 0},事实 ${memoryBuildState?.factCount || 0}`
    : '构建当前聊天三层记忆'
|
||||
|
||||
const handleVectorIndexClick = useCallback(async () => {
|
||||
if (!currentSessionId) return
|
||||
|
||||
@@ -553,6 +643,50 @@ function ChatPage(_props: ChatPageProps) {
|
||||
vectorIndexState?.vectorModel
|
||||
])
|
||||
|
||||
// Kick off a (re)build of the current session's memory: show optimistic
// progress immediately, invoke the main process over IPC, then reconcile
// local state with the returned result.
const handleMemoryBuildClick = useCallback(async () => {
  if (!currentSessionId || isPreparingMemoryBuild) return

  setIsPreparingMemoryBuild(true)
  // Optimistic placeholder progress until the first real event arrives,
  // seeded from the last known counts.
  setMemoryBuildProgress({
    sessionId: currentSessionId,
    stage: 'preparing',
    status: 'running',
    processedCount: memoryBuildState?.processedCount || 0,
    totalCount: memoryBuildState?.totalCount || 0,
    message: '正在准备会话记忆构建',
    messageCount: memoryBuildState?.messageCount || 0,
    blockCount: memoryBuildState?.blockCount || 0,
    factCount: memoryBuildState?.factCount || 0
  })

  try {
    const result = await window.electronAPI.ai.prepareSessionMemory({ sessionId: currentSessionId })
    // Discard the result if the user switched sessions mid-flight.
    if (currentSessionIdRef.current !== currentSessionId) return
    if (result.success && result.result) {
      setMemoryBuildState(result.result)
      setIsPreparingMemoryBuild(result.result.isRunning)
      setMemoryBuildProgress(null)
      showTopToast(`会话记忆已构建:${result.result.totalCount} 条`, true)
    } else {
      setIsPreparingMemoryBuild(false)
      showTopToast(result.error || '会话记忆构建失败', false)
    }
  } catch (error) {
    if (currentSessionIdRef.current !== currentSessionId) return
    setIsPreparingMemoryBuild(false)
    showTopToast(`会话记忆构建失败: ${String(error)}`, false)
  }
}, [
  currentSessionId,
  isPreparingMemoryBuild,
  memoryBuildState?.blockCount,
  memoryBuildState?.factCount,
  memoryBuildState?.messageCount,
  memoryBuildState?.processedCount,
  memoryBuildState?.totalCount,
  showTopToast
])
|
||||
|
||||
useEffect(() => {
|
||||
isLoadingMoreRef.current = isLoadingMore
|
||||
}, [isLoadingMore])
|
||||
@@ -2001,6 +2135,22 @@ function ChatPage(_props: ChatPageProps) {
|
||||
<span className="vector-index-badge">{vectorIndexBadgeLabel}</span>
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
className={`icon-btn memory-build-btn ${isPreparingMemoryBuild ? 'running active' : ''} ${memoryBuildCount > 0 ? 'complete' : ''}`}
|
||||
onClick={handleMemoryBuildClick}
|
||||
disabled={!currentSessionId || isPreparingMemoryBuild}
|
||||
title={memoryButtonTitle}
|
||||
aria-label="构建当前聊天三层记忆"
|
||||
>
|
||||
{isPreparingMemoryBuild ? (
|
||||
<Radar size={18} className="vector-index-radar" />
|
||||
) : (
|
||||
<Database size={18} />
|
||||
)}
|
||||
{memoryBuildBadgeLabel && (
|
||||
<span className="vector-index-badge">{memoryBuildBadgeLabel}</span>
|
||||
)}
|
||||
</button>
|
||||
{!isGroupChat(currentSession.username) && (
|
||||
<button
|
||||
className="icon-btn moments-btn"
|
||||
|
||||
@@ -288,6 +288,44 @@ export interface SessionVectorIndexProgressEvent {
|
||||
vectorModel: string
|
||||
}
|
||||
|
||||
// Pipeline stage reported while building a session's memory.
export type SessionMemoryBuildProgressStage =
  | 'preparing'
  | 'indexing_messages'
  | 'building_messages'
  | 'building_blocks'
  | 'building_facts'
  | 'completed'
|
||||
|
||||
// Overall status carried by a memory build progress event.
export type SessionMemoryBuildProgressStatus =
  | 'running'
  | 'completed'
  | 'failed'
|
||||
|
||||
// Queryable snapshot of a session's memory build, as returned by
// getSessionMemoryBuildState / prepareSessionMemory.
export interface SessionMemoryBuildState {
  sessionId: string
  // Number of message-layer memory items.
  messageCount: number
  // Number of block (fragment) layer memory items.
  blockCount: number
  // Number of fact-layer memory items.
  factCount: number
  // Total units to process for the current/last build.
  totalCount: number
  // Units processed so far.
  processedCount: number
  // True while a build is in flight.
  isRunning: boolean
  // Timestamp (ms since epoch) of the last state update.
  updatedAt: number
  // Timestamp (ms since epoch) of the last successful completion, if any.
  completedAt?: number
  // Message from the most recent failure, if any.
  lastError?: string
}
|
||||
|
||||
// Progress event streamed from the main process while a session's memory
// is being built.
export interface SessionMemoryBuildProgressEvent {
  sessionId: string
  // Current pipeline stage.
  stage: SessionMemoryBuildProgressStage
  // Whether the build is running, finished, or failed.
  status: SessionMemoryBuildProgressStatus
  // Units processed so far.
  processedCount: number
  // Total units to process.
  totalCount: number
  // Human-readable progress (or error) description.
  message: string
  // Message-layer item count at the time of the event.
  messageCount: number
  // Block-layer item count at the time of the event.
  blockCount: number
  // Fact-layer item count at the time of the event.
  factCount: number
}
|
||||
|
||||
export interface EmbeddingModelProfile {
|
||||
id: string
|
||||
displayName: string
|
||||
|
||||
Vendored
+13
@@ -13,6 +13,8 @@ import type {
|
||||
SessionQACancelResult,
|
||||
SessionQAStartResult,
|
||||
SessionQAResult,
|
||||
SessionMemoryBuildProgressEvent,
|
||||
SessionMemoryBuildState,
|
||||
SessionVectorIndexProgressEvent,
|
||||
SessionVectorIndexState,
|
||||
SummaryResult,
|
||||
@@ -1137,6 +1139,16 @@ export interface ElectronAPI {
|
||||
result?: SessionVectorIndexState
|
||||
error?: string
|
||||
}>
|
||||
getSessionMemoryBuildState: (sessionId: string) => Promise<{
|
||||
success: boolean
|
||||
result?: SessionMemoryBuildState
|
||||
error?: string
|
||||
}>
|
||||
prepareSessionMemory: (options: { sessionId: string }) => Promise<{
|
||||
success: boolean
|
||||
result?: SessionMemoryBuildState
|
||||
error?: string
|
||||
}>
|
||||
getEmbeddingModelProfiles: () => Promise<{
|
||||
success: boolean
|
||||
result?: EmbeddingModelProfile[]
|
||||
@@ -1185,6 +1197,7 @@ export interface ElectronAPI {
|
||||
onSessionQAEvent: (callback: (event: SessionQAJobEvent) => void) => () => void
|
||||
onSessionQAConversationUpdated: (callback: (event: SessionQAConversationDetail) => void) => () => void
|
||||
onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => () => void
|
||||
onSessionMemoryBuildProgress: (callback: (event: SessionMemoryBuildProgressEvent) => void) => () => void
|
||||
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => () => void
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user