// IPC handlers for the per-session local vector index (registered inside registerIpcHandlers()).
// Every handler resolves to { success: true, result } or { success: false, error } so the
// renderer never sees a rejected invoke() promise.

// Returns the current vectorization state of one session.
ipcMain.handle('ai:getSessionVectorIndexState', async (_, sessionId: string) => {
  try {
    const { chatSearchIndexService } = await import('./services/search/chatSearchIndexService')
    return {
      success: true,
      result: chatSearchIndexService.getSessionVectorIndexState(sessionId)
    }
  } catch (e) {
    console.error('[AI] 获取会话向量索引状态失败:', e)
    logService?.error('AI', '获取会话向量索引状态失败', { error: String(e) })
    return { success: false, error: String(e) }
  }
})

// Builds (or resumes) the vector index of one session and streams progress events back to
// the window that issued the request on the 'ai:sessionVectorIndexProgress' channel.
ipcMain.handle('ai:prepareSessionVectorIndex', async (event, options: { sessionId: string }) => {
  try {
    const { chatSearchIndexService } = await import('./services/search/chatSearchIndexService')
    const result = await chatSearchIndexService.prepareSessionVectorIndex(options.sessionId, (progress) => {
      // Indexing can outlive the requesting window. Sending to a destroyed webContents throws,
      // and a throw inside this callback would propagate into the indexing service and be
      // recorded as an indexing failure, so drop the event instead.
      if (event.sender.isDestroyed()) return
      event.sender.send('ai:sessionVectorIndexProgress', progress)
    })
    return { success: true, result }
  } catch (e) {
    console.error('[AI] 准备会话向量索引失败:', e)
    logService?.error('AI', '准备会话向量索引失败', { error: String(e) })
    return { success: false, error: String(e) }
  }
})

// Requests cancellation of a running vectorization task and returns the resulting state.
ipcMain.handle('ai:cancelSessionVectorIndex', async (_, sessionId: string) => {
  try {
    const { chatSearchIndexService } = await import('./services/search/chatSearchIndexService')
    return {
      success: true,
      result: chatSearchIndexService.cancelSessionVectorIndex(sessionId)
    }
  } catch (e) {
    console.error('[AI] 取消会话向量索引失败:', e)
    logService?.error('AI', '取消会话向量索引失败', { error: String(e) })
    return { success: false, error: String(e) }
  }
})
'../src/types/account' -import type { SessionQAProgressEvent } from '../src/types/ai' + +type SessionQAProgressEvent = { + stage: string + status: string + message: string + toolName?: string + createdAt?: number + [key: string]: unknown +} + +type SessionVectorIndexProgressEvent = { + sessionId: string + stage: string + status: string + processedCount: number + totalCount: number + message: string + vectorModel: string +} function getMcpLaunchConfigSafe(): Promise<{ command: string @@ -522,6 +540,9 @@ contextBridge.exposeInMainWorld('electronAPI', { model: string enableThinking?: boolean }) => ipcRenderer.invoke('ai:askSessionQuestion', options), + getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId), + prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options), + cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId), onSummaryChunk: (callback: (chunk: string) => void) => { ipcRenderer.on('ai:summaryChunk', (_, chunk) => callback(chunk)) return () => ipcRenderer.removeAllListeners('ai:summaryChunk') @@ -533,6 +554,10 @@ contextBridge.exposeInMainWorld('electronAPI', { onSessionQAProgress: (callback: (event: SessionQAProgressEvent) => void) => { ipcRenderer.on('ai:sessionQaProgress', (_, event) => callback(event)) return () => ipcRenderer.removeAllListeners('ai:sessionQaProgress') + }, + onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => { + ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event)) + return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress') } } }) diff --git a/electron/services/ai-agent/qa/sessionQaAgent.ts b/electron/services/ai-agent/qa/sessionQaAgent.ts index ee6d32b..66724a2 100644 --- a/electron/services/ai-agent/qa/sessionQaAgent.ts +++ 
b/electron/services/ai-agent/qa/sessionQaAgent.ts @@ -33,6 +33,7 @@ export type SessionQAToolName = | 'aggregate_messages' | 'answer' | 'get_session_context' + | 'prepare_vector_index' export type SessionQAProgressStage = 'intent' | 'tool' | 'context' | 'answer' export type SessionQAProgressStatus = 'running' | 'completed' | 'failed' @@ -997,6 +998,7 @@ async function loadLatestContext(sessionId: string, limit = MAX_CONTEXT_MESSAGES } async function searchSessionMessages(sessionId: string, query: string, filters: { + semanticQuery?: string senderUsername?: string startTime?: number endTime?: number @@ -1008,6 +1010,7 @@ async function searchSessionMessages(sessionId: string, query: string, filters: const args = { sessionId, query, + ...(filters.semanticQuery ? { semanticQuery: filters.semanticQuery } : {}), limit: filters.limit || MAX_SEARCH_HITS, matchMode: 'substring', includeRaw: false, @@ -1920,6 +1923,7 @@ export async function answerSessionQuestionWithAgent( try { const search = await searchSessionMessages(options.sessionId, query, { + semanticQuery: `${query} ${options.question}`, startTime: route.timeRange?.startTime, endTime: route.timeRange?.endTime, senderUsername: route.intent === 'participant_focus' diff --git a/electron/services/mcp/readService.ts b/electron/services/mcp/readService.ts index 1180144..4f0d580 100644 --- a/electron/services/mcp/readService.ts +++ b/electron/services/mcp/readService.ts @@ -118,6 +118,7 @@ const listContactsArgsSchema = z.object({ const searchMessagesArgsSchema = z.object({ query: z.string().trim().min(1), + semanticQuery: z.string().trim().min(1).optional(), sessionId: z.string().trim().min(1).optional(), sessionIds: z.array(z.string().trim().min(1)).max(MAX_SEARCH_SESSIONS).optional(), startTime: z.number().int().positive().optional(), @@ -1957,8 +1958,18 @@ export class McpReadService { if (exhaustiveTargetedSearch) { try { const indexedRawHits: SearchRawHit[] = [] + const indexedRawHitMap = new Map() let 
indexedMessages = 0 let indexedTruncated = false + const semanticQuery = args.data.semanticQuery || args.data.query + const hitKey = (hit: Pick) => `${hit.session.sessionId}:${hit.message.localId}:${hit.message.createTime}:${hit.message.sortSeq}` + const addIndexedRawHit = (hit: SearchRawHit) => { + const key = hitKey(hit) + const existing = indexedRawHitMap.get(key) + if (!existing || hit.score > existing.score) { + indexedRawHitMap.set(key, hit) + } + } await reportProgress(reporter, { stage: 'scanning_messages', @@ -1990,7 +2001,39 @@ export class McpReadService { indexedMessages += indexed.indexedCount indexedTruncated = indexedTruncated || indexed.truncated - for (const hit of indexed.hits) { + const hybridHits = [...indexed.hits] + const vectorState = chatSearchIndexService.getSessionVectorIndexState(session.sessionId) + const shouldRunVectorSearch = matchMode !== 'exact' + && vectorState.isVectorComplete + if (shouldRunVectorSearch) { + try { + const vectorIndexed = await chatSearchIndexService.searchSessionByVector({ + sessionId: session.sessionId, + query: semanticQuery, + limit: Math.max(limit * 4, limit + 20), + matchMode, + startTimeMs, + endTimeMs, + direction: args.data.direction, + senderUsername: args.data.senderUsername, + onProgress: async (progress) => { + await reportProgress(reporter, { + stage: progress.stage === 'searching_index' ? 'streaming_hits' : 'scanning_messages', + message: progress.message, + sessionsScanned: targetSessions.indexOf(session) + 1, + messagesScanned: progress.indexedCount ?? 
progress.messagesScanned + }) + } + }) + + indexedTruncated = indexedTruncated || vectorIndexed.truncated + hybridHits.push(...vectorIndexed.hits) + } catch (error) { + console.warn('[McpReadService] Local vector search failed, keeping keyword results:', error) + } + } + + for (const hit of hybridHits) { if (!messageMatchesFilters(hit.message, { startTimeMs, endTimeMs, @@ -2001,7 +2044,7 @@ export class McpReadService { continue } - indexedRawHits.push({ + addIndexedRawHit({ session, message: hit.message, matchedField: hit.matchedField, @@ -2011,6 +2054,7 @@ export class McpReadService { } } + indexedRawHits.push(...indexedRawHitMap.values()) indexedRawHits.sort((a, b) => b.score - a.score || compareMessageCursorDesc(a.message, b.message)) const hits = await Promise.all(indexedRawHits.slice(0, limit).map(async (hit): Promise => ({ session: hit.session, diff --git a/electron/services/mcp/tools.ts b/electron/services/mcp/tools.ts index 88e33f5..0c36c10 100644 --- a/electron/services/mcp/tools.ts +++ b/electron/services/mcp/tools.ts @@ -218,6 +218,7 @@ export function registerCipherTalkMcpTools(server: any) { description: 'Search messages across one or more sessions and return agent-friendly hits. Use for broad clue hunting when the target session or keyword is still uncertain. Hit text is in hits[].message.text and hits[].excerpt.', inputSchema: { query: z.string().trim().min(1).describe('Required full-text query.'), + semanticQuery: z.string().trim().min(1).optional().describe('Optional natural-language query for local vector search. Defaults to query.'), sessionId: z.string().trim().min(1).optional().describe('Single session identifier to search. Accepts sessionId, contactId, display name, remark, or nickname when uniquely resolvable.'), sessionIds: z.array(z.string().trim().min(1)).max(20).optional().describe('Multiple session identifiers to search. 
/** Phase reported while a session's local vector index is being prepared. */
export type ChatVectorIndexProgressStage =
  | 'preparing'
  | 'indexing_messages'
  | 'vectorizing_messages'
  | 'completed'

/** Outcome attached to every vector index progress event. */
export type ChatVectorIndexProgressStatus =
  | 'running'
  | 'completed'
  | 'cancelled'
  | 'failed'

/** Progress event emitted while preparing a session's vector index. */
export interface ChatVectorIndexProgress {
  sessionId: string
  stage: ChatVectorIndexProgressStage
  status: ChatVectorIndexProgressStatus
  processedCount: number
  totalCount: number
  message: string
  vectorModel: string
}

/** Snapshot of how far a session's vector index has progressed. */
export interface ChatVectorIndexState {
  sessionId: string
  indexedCount: number
  vectorizedCount: number
  pendingCount: number
  isVectorComplete: boolean
  isVectorRunning: boolean
  vectorModel: string
}

/** Result of a vector search over a single session. */
export interface ChatVectorSearchSessionResult {
  hits: ChatSearchIndexHit[]
  indexedCount: number
  vectorizedCount: number
  truncated: boolean
  model: string
}

/** A message_index row joined with its serialized vector. */
type MessageVectorRow = MessageIndexRow & {
  vector_json: string
}

/** Sparse vector stored as [dimension, weight] pairs. */
type SparseVector = Array<[number, number]>

/** Strategy object bundling everything needed to build, parse and score vectors. */
interface LocalVectorProvider {
  id: string
  buildVector(text: string): SparseVector
  parseVector(value: string): SparseVector
  toWeightMap(vector: SparseVector): Map<number, number>
  dot(queryWeights: Map<number, number>, vector: SparseVector): number
}

/** Row shape of the session_vector_state table. */
type SessionVectorStateRow = {
  session_id: string
  vector_model: string
  confirmed_at: number | null
  completed_at: number | null
  updated_at: number
  is_complete: number
  last_error: string | null
}

/** In-flight vectorization task; cancelRequested is polled cooperatively by the runner. */
type VectorTask = {
  promise: Promise<ChatVectorIndexState>
  cancelRequested: boolean
}

const VECTOR_MODEL_ID = 'local-chargram-hash-v1'
// Number of hash buckets every feature is folded into.
const VECTOR_DIMENSIONS = 2048
const VECTOR_BATCH_SIZE = 800
// Text longer than this is truncated before feature extraction.
const MAX_VECTOR_TEXT_CHARS = 2400
const MAX_VECTOR_SCAN_ROWS = 120000
const VECTOR_MIN_SCORE = 0.055
// Vector hits are recall supplements, so keep them below high-confidence keyword hits.
const VECTOR_SCORE_BASE = 560
const VECTOR_SCORE_SCALE = 420
const VECTOR_INDEX_CANCELLED_ERROR = 'VECTOR_INDEX_CANCELLED'
// Question/filler phrases removed from text before it is vectorized.
const VECTOR_STOP_PHRASES = [
  '有没有',
  '是不是',
  '是否',
  '什么',
  '哪个',
  '哪些',
  '什么时候',
  '为什么',
  '怎么',
  '如何',
  '帮我',
  '看看',
  '请问',
  '问一下',
  '聊天记录',
  '聊天',
  '消息',
  '记录',
  '内容',
  '这个',
  '那个',
  '我们',
  '他们',
  '对方',
  '最近',
  '说过',
  '提到',
  '关于',
  '一下'
]
// Whole tokens / n-grams that never become vector features.
const VECTOR_STOP_WORDS = new Set([
  'the',
  'and',
  'for',
  'with',
  'that',
  'this',
  'what',
  'when',
  'where',
  'which',
  'why',
  'how',
  'have',
  'has',
  '是否',
  '什么',
  '哪个',
  '哪些',
  '怎么',
  '如何',
  '我们',
  '他们',
  '对方',
  '消息',
  '聊天',
  '记录',
  '内容'
])

/** Escapes regex metacharacters so a phrase can be embedded in a pattern literally. */
function escapeRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
}

// One precompiled alternation, longest phrase first. Stripping phrases one by one in list
// order let '什么' consume part of '什么时候' and '为什么' before those longer entries could
// match, leaving '时候' / '为' behind as noise features; it also rebuilt every RegExp on each
// call, which runs once per indexed message.
const VECTOR_STOP_PHRASE_PATTERN = new RegExp(
  [...VECTOR_STOP_PHRASES]
    .sort((a, b) => b.length - a.length)
    .map(escapeRegExp)
    .join('|'),
  'gi'
)

/** Orders message_index rows ascending by (sort_seq, create_time, local_id). */
function compareIndexRowCursorAsc(
  a: Pick<MessageIndexRow, 'sort_seq' | 'create_time' | 'local_id'>,
  b: Pick<MessageIndexRow, 'sort_seq' | 'create_time' | 'local_id'>
): number {
  return Number(a.sort_seq || 0) - Number(b.sort_seq || 0)
    || Number(a.create_time || 0) - Number(b.create_time || 0)
    || Number(a.local_id || 0) - Number(b.local_id || 0)
}

/** Replaces every stop phrase in `text` with a single space (longest phrase wins). */
function stripVectorStopPhrases(text: string): string {
  return text.replace(VECTOR_STOP_PHRASE_PATTERN, ' ')
}

/**
 * Normalizes text for vectorization: search normalization, truncation to
 * MAX_VECTOR_TEXT_CHARS, stop-phrase removal and whitespace collapsing.
 */
function normalizeVectorText(value: string): string {
  const normalized = normalizeSearchText(value).slice(0, MAX_VECTOR_TEXT_CHARS)
  return stripVectorStopPhrases(normalized).replace(/\s+/g, ' ').trim()
}

/** 32-bit FNV-1a hash over the UTF-16 code units of `value`, returned as an unsigned integer. */
function hashString(value: string): number {
  let hash = 2166136261
  for (let index = 0; index < value.length; index += 1) {
    hash ^= value.charCodeAt(index)
    hash = Math.imul(hash, 16777619)
  }
  return hash >>> 0
}

/**
 * Adds one feature to the accumulator. The feature is hashed into one of
 * VECTOR_DIMENSIONS buckets and the top hash bit decides the sign of its contribution.
 */
function addVectorFeature(weights: Map<number, number>, feature: string, weight: number): void {
  if (!feature || VECTOR_STOP_WORDS.has(feature)) return

  const hash = hashString(feature)
  const dimension = hash % VECTOR_DIMENSIONS
  const signedWeight = (hash & 0x80000000) ? -weight : weight
  weights.set(dimension, (weights.get(dimension) || 0) + signedWeight)
}

/**
 * Adds features for one contiguous run of CJK characters: the whole run when it is
 * 2-12 characters long, plus every 2-, 3- and 4-character gram that is not a stop word.
 */
function addChineseVectorFeatures(weights: Map<number, number>, segment: string): void {
  if (!segment || VECTOR_STOP_WORDS.has(segment)) return

  if (segment.length >= 2 && segment.length <= 12) {
    addVectorFeature(weights, `zh:${segment}`, 1.35)
  }

  for (let size = 2; size <= 4; size += 1) {
    if (segment.length < size) continue
    const weight = size === 2 ? 0.9 : size === 3 ? 1.1 : 0.85
    for (let index = 0; index <= segment.length - size; index += 1) {
      const gram = segment.slice(index, index + size)
      if (VECTOR_STOP_WORDS.has(gram)) continue
      addVectorFeature(weights, `c${size}:${gram}`, weight)
    }
  }
}

/**
 * Builds a unit-length sparse vector from text that has already been normalized.
 * Returns [] when the text yields no features. Entries are sorted by dimension and
 * weights are rounded to six decimals.
 */
function buildVectorFromNormalizedText(normalized: string): SparseVector {
  if (!normalized) return []

  const weights = new Map<number, number>()
  const latinWords: string[] = []

  for (const match of normalized.matchAll(/[\u3400-\u9fff]+/g)) {
    addChineseVectorFeatures(weights, match[0])
  }

  for (const match of normalized.matchAll(/[a-z0-9_@.\-]{2,}/g)) {
    const word = match[0]
    if (VECTOR_STOP_WORDS.has(word)) continue
    latinWords.push(word)
    addVectorFeature(weights, `w:${word}`, 1.2)
  }

  // Adjacent word pairs capture a little word order for latin text.
  for (let index = 0; index < latinWords.length - 1; index += 1) {
    addVectorFeature(weights, `wb:${latinWords[index]} ${latinWords[index + 1]}`, 1.35)
  }

  let norm = 0
  for (const weight of weights.values()) {
    norm += weight * weight
  }

  // Empty, or every bucket cancelled out exactly.
  if (norm <= 0) return []

  const scale = Math.sqrt(norm)
  return Array.from(weights.entries())
    .map(([dimension, weight]): [number, number] => [dimension, Number((weight / scale).toFixed(6))])
    .filter(([, weight]) => Math.abs(weight) > 0.000001)
    .sort((a, b) => a[0] - b[0])
}

/** Builds the sparse vector for raw message or query text. */
function buildLocalSearchVector(value: string): SparseVector {
  return buildVectorFromNormalizedText(normalizeVectorText(value))
}

/**
 * Parses a serialized sparse vector. Malformed JSON yields []; entries that are not
 * [non-negative integer, finite number] pairs are skipped.
 */
function parseSparseVector(value: string): SparseVector {
  try {
    const parsed: unknown = JSON.parse(value)
    if (!Array.isArray(parsed)) return []

    const vector: SparseVector = []
    for (const item of parsed) {
      if (!Array.isArray(item) || item.length < 2) continue
      const dimension = Number(item[0])
      const weight = Number(item[1])
      if (!Number.isInteger(dimension) || dimension < 0 || !Number.isFinite(weight)) continue
      vector.push([dimension, weight])
    }
    return vector
  } catch {
    return []
  }
}

/** Dot product between a query weight map and a sparse vector. */
function dotSparseVector(queryWeights: Map<number, number>, vector: SparseVector): number {
  let score = 0
  for (const [dimension, weight] of vector) {
    const queryWeight = queryWeights.get(dimension)
    if (queryWeight) score += queryWeight * weight
  }
  return score
}

/** Converts a sparse vector into a dimension -> weight lookup map. */
function sparseVectorToMap(vector: SparseVector): Map<number, number> {
  return new Map(vector)
}

const localVectorProvider: LocalVectorProvider = {
  id: VECTOR_MODEL_ID,
  buildVector: buildLocalSearchVector,
  parseVector: parseSparseVector,
  toWeightMap: sparseVectorToMap,
  dot: dotSparseVector
}

/**
 * Picks an excerpt for a vector hit: around the whole normalized query if it occurs in
 * the text, otherwise around the first query token, otherwise the start of the text.
 */
function createVectorExcerpt(
  row: Pick<MessageIndexRow, 'parsed_content' | 'search_text'>,
  query: string
): string {
  const text = String(row.parsed_content || row.search_text || '')
  if (!text) return ''

  // Normalize once; both lookups below search the same normalized text.
  const normalizedText = normalizeSearchText(text)
  // Offsets found in the normalized text are clamped before being applied to the raw text.
  const lastIndex = Math.max(0, text.length - 1)

  const normalizedQuery = normalizeSearchText(query)
  const exactIndex = normalizedQuery ? normalizedText.indexOf(normalizedQuery) : -1
  if (exactIndex >= 0) {
    return createExcerpt(text, Math.min(exactIndex, lastIndex), Math.max(normalizedQuery.length, 1))
  }

  const token = buildQueryTokens(query)[0]
  if (token) {
    const tokenIndex = normalizedText.indexOf(token)
    if (tokenIndex >= 0) {
      return createExcerpt(text, Math.min(tokenIndex, lastIndex), token.length)
    }
  }

  return createExcerpt(text, 0, Math.min(text.length, 24))
}
session_vector_state ( + session_id TEXT NOT NULL, + vector_model TEXT NOT NULL, + confirmed_at INTEGER, + completed_at INTEGER, + updated_at INTEGER NOT NULL, + is_complete INTEGER NOT NULL DEFAULT 0, + last_error TEXT, + PRIMARY KEY(session_id, vector_model) + ); + CREATE INDEX IF NOT EXISTS idx_message_index_session_time ON message_index(session_id, sort_seq DESC, create_time DESC, local_id DESC); CREATE INDEX IF NOT EXISTS idx_message_index_session_sender ON message_index(session_id, sender_username); + CREATE INDEX IF NOT EXISTS idx_message_vector_session_model + ON message_vector_index(session_id, vector_model); + CREATE INDEX IF NOT EXISTS idx_session_vector_state_session + ON session_vector_state(session_id); `) db.prepare('INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)').run('schema_version', INDEX_SCHEMA_VERSION) @@ -439,6 +755,8 @@ export class ChatSearchIndexService { private resetSchema(db: Database.Database): void { db.exec(` DROP TABLE IF EXISTS message_index_fts; + DROP TABLE IF EXISTS message_vector_index; + DROP TABLE IF EXISTS session_vector_state; DROP TABLE IF EXISTS message_index; DROP TABLE IF EXISTS session_index_state; DELETE FROM meta WHERE key = 'schema_version'; @@ -465,7 +783,146 @@ export class ChatSearchIndexService { return Number(row?.count || 0) } - private upsertMessages(db: Database.Database, sessionId: string, messages: Message[]): void { + private getVectorizedCount(db: Database.Database, sessionId: string): number { + const row = db.prepare(` + SELECT COUNT(*) AS count + FROM message_index m + JOIN message_vector_index v ON v.message_id = m.id + WHERE m.session_id = ? AND v.vector_model = ? 
+ `).get(sessionId, localVectorProvider.id) as { count?: number } + return Number(row?.count || 0) + } + + private getVectorTaskKey(sessionId: string): string { + return `${sessionId}:${localVectorProvider.id}` + } + + private getVectorStateRow(db: Database.Database, sessionId: string): SessionVectorStateRow | null { + const row = db.prepare(` + SELECT * + FROM session_vector_state + WHERE session_id = ? AND vector_model = ? + `).get(sessionId, localVectorProvider.id) as SessionVectorStateRow | undefined + return row || null + } + + private isSessionVectorComplete(db: Database.Database, sessionId: string): boolean { + return Number(this.getVectorStateRow(db, sessionId)?.is_complete || 0) === 1 + } + + private setSessionVectorState(db: Database.Database, input: { + sessionId: string + confirmedAt?: number | null + completedAt?: number | null + isComplete: boolean + lastError?: string | null + }): void { + const now = Date.now() + db.prepare(` + INSERT INTO session_vector_state ( + session_id, + vector_model, + confirmed_at, + completed_at, + updated_at, + is_complete, + last_error + ) VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(session_id, vector_model) DO UPDATE SET + confirmed_at = COALESCE(excluded.confirmed_at, session_vector_state.confirmed_at), + completed_at = excluded.completed_at, + updated_at = excluded.updated_at, + is_complete = excluded.is_complete, + last_error = excluded.last_error + `).run( + input.sessionId, + localVectorProvider.id, + input.confirmedAt ?? null, + input.completedAt ?? null, + now, + input.isComplete ? 1 : 0, + input.lastError ?? 
null + ) + } + + getSessionVectorIndexState(sessionId: string): ChatVectorIndexState { + const db = this.getDb() + const indexedCount = this.getIndexedCount(db, sessionId) + const vectorizedCount = this.getVectorizedCount(db, sessionId) + const isRunning = this.vectorTasks.has(this.getVectorTaskKey(sessionId)) + const row = this.getVectorStateRow(db, sessionId) + const isComplete = Number(row?.is_complete || 0) === 1 + && vectorizedCount >= indexedCount + + return { + sessionId, + indexedCount, + vectorizedCount, + pendingCount: Math.max(0, indexedCount - vectorizedCount), + isVectorComplete: isComplete, + isVectorRunning: isRunning, + vectorModel: localVectorProvider.id + } + } + + private async yieldToEventLoop(): Promise { + await new Promise((resolve) => setTimeout(resolve, 0)) + } + + private async reportVectorProgress( + progress: Omit, + onProgress?: (progress: ChatVectorIndexProgress) => void | Promise + ): Promise { + await onProgress?.({ + ...progress, + vectorModel: localVectorProvider.id + }) + } + + private upsertVectorRows( + db: Database.Database, + rows: Array & { indexed_at?: number }> + ): void { + if (rows.length === 0) return + + const upsertVector = db.prepare(` + INSERT INTO message_vector_index ( + message_id, + session_id, + vector_model, + vector_json, + feature_count, + indexed_at + ) VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(message_id) DO UPDATE SET + session_id = excluded.session_id, + vector_model = excluded.vector_model, + vector_json = excluded.vector_json, + feature_count = excluded.feature_count, + indexed_at = excluded.indexed_at + `) + + const run = db.transaction((items: Array & { indexed_at?: number }>) => { + const now = Date.now() + for (const row of items) { + const vector = localVectorProvider.buildVector(row.search_text) + upsertVector.run( + row.id, + row.session_id, + localVectorProvider.id, + JSON.stringify(vector), + vector.length, + row.indexed_at || now + ) + } + }) + + run(rows) + } + + private upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: { + vectorize?: boolean + } = {}): void { if (messages.length === 0) return const upsert = db.prepare(` @@ -521,6 +978,22 @@ export class ChatSearchIndexService { INSERT INTO message_index_fts(rowid, session_id, cursor_key, search_text, token_text) VALUES (?, ?, ?, ?, ?) `) + const upsertVector = db.prepare(` + INSERT INTO message_vector_index ( + message_id, + session_id, + vector_model, + vector_json, + feature_count, + indexed_at + ) VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(message_id) DO UPDATE SET + session_id = excluded.session_id, + vector_model = excluded.vector_model, + vector_json = excluded.vector_json, + feature_count = excluded.feature_count, + indexed_at = excluded.indexed_at + `) const run = db.transaction((items: Message[]) => { const indexedAt = Date.now() @@ -550,6 +1023,17 @@ export class ChatSearchIndexService { deleteFts.run(row.id) insertFts.run(row.id, sessionId, cursorKey(message), searchText, tokenText) + if (options.vectorize) { + const vector = localVectorProvider.buildVector(searchText) + upsertVector.run( + row.id, + sessionId, + localVectorProvider.id, + JSON.stringify(vector), + vector.length, + indexedAt + ) + } } }) @@ -620,6 +1104,7 @@ export class ChatSearchIndexService { ): Promise { const db = this.getDb() const state = this.getSessionState(db, sessionId) + const vectorizeDuringIndexing = this.isSessionVectorComplete(db, sessionId) let newest: Message | null = state?.isComplete && state.newestSortSeq > 0 ? { localId: state.newestLocalId, @@ -664,7 +1149,7 @@ export class ChatSearchIndexService { const messages = result.messages || [] if (messages.length === 0) break - this.upsertMessages(db, sessionId, messages) + this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length newest = messages[messages.length - 1] || newest cursor = { @@ -696,6 +1181,7 @@ export class ChatSearchIndexService { } db.prepare('DELETE FROM message_index_fts WHERE session_id = ?').run(sessionId) + db.prepare('DELETE FROM message_vector_index WHERE session_id = ?').run(sessionId) db.prepare('DELETE FROM message_index WHERE session_id = ?').run(sessionId) db.prepare('DELETE FROM session_index_state WHERE session_id = ?').run(sessionId) @@ -707,7 +1193,7 @@ export class ChatSearchIndexService { let messages = firstPage.messages || [] let hasMore = Boolean(firstPage.hasMore) if (messages.length > 0) { - this.upsertMessages(db, sessionId, messages) + 
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length newest = messages[messages.length - 1] await this.report({ @@ -734,7 +1220,7 @@ export class ChatSearchIndexService { messages = result.messages || [] if (messages.length === 0) break - this.upsertMessages(db, sessionId, messages) + this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length hasMore = Boolean(result.hasMore) @@ -758,6 +1244,222 @@ export class ChatSearchIndexService { return nextState } + async prepareSessionVectorIndex( + sessionId: string, + onProgress?: (progress: ChatVectorIndexProgress) => void | Promise + ): Promise { + const key = this.getVectorTaskKey(sessionId) + const existingTask = this.vectorTasks.get(key) + if (existingTask) { + await this.reportVectorProgress({ + sessionId, + stage: 'preparing', + status: 'running', + processedCount: this.getSessionVectorIndexState(sessionId).vectorizedCount, + totalCount: this.getSessionVectorIndexState(sessionId).indexedCount, + message: '当前会话正在向量化,复用已有任务' + }, onProgress) + return existingTask.promise + } + + const task: VectorTask = { + cancelRequested: false, + promise: Promise.resolve(this.getSessionVectorIndexState(sessionId)) + } + task.promise = this.runPrepareSessionVectorIndex(sessionId, task, onProgress) + this.vectorTasks.set(key, task) + + try { + return await task.promise + } finally { + this.vectorTasks.delete(key) + } + } + + cancelSessionVectorIndex(sessionId: string): ChatVectorIndexState { + const task = this.vectorTasks.get(this.getVectorTaskKey(sessionId)) + if (task) { + task.cancelRequested = true + } + return this.getSessionVectorIndexState(sessionId) + } + + private async runPrepareSessionVectorIndex( + sessionId: string, + task: VectorTask, + onProgress?: (progress: ChatVectorIndexProgress) => void | Promise + ): Promise { + const db = this.getDb() + + await this.reportVectorProgress({ + sessionId, + stage: 
'preparing', + status: 'running', + processedCount: 0, + totalCount: 0, + message: '正在准备当前会话搜索索引' + }, onProgress) + + try { + const searchState = await this.ensureSessionIndexed(sessionId, async (progress) => { + if (task.cancelRequested) { + throw new Error(VECTOR_INDEX_CANCELLED_ERROR) + } + await this.reportVectorProgress({ + sessionId, + stage: 'indexing_messages', + status: 'running', + processedCount: progress.indexedCount ?? progress.messagesScanned ?? 0, + totalCount: progress.indexedCount ?? 0, + message: progress.message + }, onProgress) + if (task.cancelRequested) { + throw new Error(VECTOR_INDEX_CANCELLED_ERROR) + } + }) + + let currentState = this.getSessionVectorIndexState(sessionId) + if (currentState.isVectorComplete) { + await this.reportVectorProgress({ + sessionId, + stage: 'completed', + status: 'completed', + processedCount: currentState.vectorizedCount, + totalCount: currentState.indexedCount, + message: `本地向量索引已就绪,共 ${currentState.vectorizedCount} 条消息` + }, onProgress) + return currentState + } + + this.setSessionVectorState(db, { + sessionId, + confirmedAt: Date.now(), + completedAt: null, + isComplete: false, + lastError: null + }) + + if (searchState.indexedCount === 0) { + this.setSessionVectorState(db, { + sessionId, + completedAt: Date.now(), + isComplete: true, + lastError: null + }) + currentState = this.getSessionVectorIndexState(sessionId) + await this.reportVectorProgress({ + sessionId, + stage: 'completed', + status: 'completed', + processedCount: 0, + totalCount: 0, + message: '当前会话暂无可向量化消息' + }, onProgress) + return currentState + } + + while (true) { + if (task.cancelRequested) { + this.setSessionVectorState(db, { + sessionId, + completedAt: null, + isComplete: false, + lastError: 'cancelled' + }) + currentState = this.getSessionVectorIndexState(sessionId) + await this.reportVectorProgress({ + sessionId, + stage: 'vectorizing_messages', + status: 'cancelled', + processedCount: currentState.vectorizedCount, + totalCount: 
currentState.indexedCount, + message: '已取消当前会话向量化' + }, onProgress) + return currentState + } + + const rows = db.prepare(` + SELECT m.id, m.session_id, m.search_text, m.indexed_at + FROM message_index m + LEFT JOIN message_vector_index v + ON v.message_id = m.id AND v.vector_model = ? + WHERE m.session_id = ? AND v.message_id IS NULL + ORDER BY m.id ASC + LIMIT ? + `).all(localVectorProvider.id, sessionId, VECTOR_BATCH_SIZE) as Array & { indexed_at?: number }> + + if (rows.length === 0) break + + this.upsertVectorRows(db, rows) + currentState = this.getSessionVectorIndexState(sessionId) + await this.reportVectorProgress({ + sessionId, + stage: 'vectorizing_messages', + status: 'running', + processedCount: currentState.vectorizedCount, + totalCount: currentState.indexedCount, + message: `已向量化 ${currentState.vectorizedCount}/${currentState.indexedCount} 条消息` + }, onProgress) + await this.yieldToEventLoop() + } + + this.setSessionVectorState(db, { + sessionId, + completedAt: Date.now(), + isComplete: true, + lastError: null + }) + currentState = this.getSessionVectorIndexState(sessionId) + + await this.reportVectorProgress({ + sessionId, + stage: 'completed', + status: 'completed', + processedCount: currentState.vectorizedCount, + totalCount: currentState.indexedCount, + message: `本地向量索引已完成,共 ${currentState.vectorizedCount} 条消息` + }, onProgress) + + return currentState + } catch (error) { + if (error instanceof Error && error.message === VECTOR_INDEX_CANCELLED_ERROR) { + this.setSessionVectorState(db, { + sessionId, + completedAt: null, + isComplete: false, + lastError: 'cancelled' + }) + const cancelledState = this.getSessionVectorIndexState(sessionId) + await this.reportVectorProgress({ + sessionId, + stage: 'indexing_messages', + status: 'cancelled', + processedCount: cancelledState.vectorizedCount, + totalCount: cancelledState.indexedCount, + message: '已取消当前会话向量化' + }, onProgress) + return cancelledState + } + + this.setSessionVectorState(db, { + sessionId, + 
completedAt: null, + isComplete: false, + lastError: String(error) + }) + const failedState = this.getSessionVectorIndexState(sessionId) + await this.reportVectorProgress({ + sessionId, + stage: 'vectorizing_messages', + status: 'failed', + processedCount: failedState.vectorizedCount, + totalCount: failedState.indexedCount, + message: `向量化失败:${String(error)}` + }, onProgress) + throw error + } + } + async searchSession(options: ChatSearchSessionOptions): Promise { const db = this.getDb() const state = await this.ensureSessionIndexed(options.sessionId, options.onProgress) @@ -869,6 +1571,100 @@ export class ChatSearchIndexService { truncated: rows.length > options.limit } } + + async searchSessionByVector(options: ChatSearchSessionOptions): Promise { + const db = this.getDb() + const state = await this.ensureSessionIndexed(options.sessionId, options.onProgress) + const vectorState = this.getSessionVectorIndexState(options.sessionId) + const queryVector = localVectorProvider.buildVector(options.query) + const vectorizedCount = vectorState.vectorizedCount + + if (!vectorState.isVectorComplete || queryVector.length === 0) { + return { + hits: [], + indexedCount: state.indexedCount, + vectorizedCount, + truncated: false, + model: localVectorProvider.id + } + } + + await this.report({ + stage: 'searching_index', + sessionId: options.sessionId, + message: `正在进行本地向量检索:${options.query}`, + indexedCount: state.indexedCount + }, options.onProgress) + + const startTime = toTimestampSeconds(options.startTimeMs) + const endTime = toTimestampSeconds(options.endTimeMs) + const senderUsername = normalizeSearchText(options.senderUsername) + const direction = options.direction + const scanLimit = MAX_VECTOR_SCAN_ROWS + const sqlFilters: string[] = [ + 'v.session_id = @sessionId', + 'v.vector_model = @vectorModel' + ] + const params: Record = { + sessionId: options.sessionId, + vectorModel: localVectorProvider.id, + scanLimit: scanLimit + 1 + } + + if (startTime) { + 
sqlFilters.push('m.create_time >= @startTime') + params.startTime = startTime + } + if (endTime) { + sqlFilters.push('m.create_time <= @endTime') + params.endTime = endTime + } + if (direction) { + sqlFilters.push(direction === 'out' ? 'm.is_send = 1' : '(m.is_send IS NULL OR m.is_send != 1)') + } + if (senderUsername) { + sqlFilters.push('lower(COALESCE(m.sender_username, \'\')) = @senderUsername') + params.senderUsername = senderUsername + } + + const rows = db.prepare(` + SELECT m.*, v.vector_json + FROM message_vector_index v + JOIN message_index m ON m.id = v.message_id + WHERE ${sqlFilters.join(' AND ')} + AND v.feature_count > 0 + ORDER BY m.sort_seq DESC, m.create_time DESC, m.local_id DESC + LIMIT @scanLimit + `).all(params) as MessageVectorRow[] + const queryWeights = localVectorProvider.toWeightMap(queryVector) + const scored = rows + .map((row) => { + const vectorScore = localVectorProvider.dot(queryWeights, localVectorProvider.parseVector(row.vector_json)) + return { + row, + vectorScore + } + }) + .filter((item) => item.vectorScore >= VECTOR_MIN_SCORE) + .sort((a, b) => b.vectorScore - a.vectorScore || compareIndexRowCursorAsc(b.row, a.row)) + .slice(0, options.limit) + + const hits = scored.map(({ row, vectorScore }) => ({ + sessionId: options.sessionId, + message: rowToMessage(row), + excerpt: createVectorExcerpt(row, options.query), + matchedField: 'text' as const, + score: Number((VECTOR_SCORE_BASE + vectorScore * VECTOR_SCORE_SCALE).toFixed(2)) + } satisfies ChatSearchIndexHit)) + + return { + hits, + indexedCount: state.indexedCount, + vectorizedCount, + truncated: rows.length > scanLimit, + model: localVectorProvider.id + } + } } export const chatSearchIndexService = new ChatSearchIndexService() diff --git a/src/pages/AISummaryWindow.scss b/src/pages/AISummaryWindow.scss index 98423d0..2b61a92 100644 --- a/src/pages/AISummaryWindow.scss +++ b/src/pages/AISummaryWindow.scss @@ -2077,6 +2077,50 @@ border-color: var(--primary); } } + + 
.vector-index-intro { + display: flex; + align-items: flex-start; + gap: 10px; + + svg { + flex: 0 0 auto; + color: var(--primary); + margin-top: 2px; + } + } + + .vector-index-stats { + display: flex; + justify-content: space-between; + gap: 12px; + margin-top: 14px; + font-size: 12px; + color: var(--text-tertiary); + } + + .vector-index-progress { + margin-top: 12px; + + p { + margin-top: 8px; + font-size: 12px; + color: var(--text-tertiary); + } + } + + .vector-index-progress-bar { + height: 6px; + background: var(--bg-secondary); + border-radius: 999px; + overflow: hidden; + + div { + height: 100%; + background: var(--primary); + transition: width 0.2s ease; + } + } } .dialog-actions { @@ -2134,6 +2178,13 @@ transform: translateY(0); } } + + &:disabled { + cursor: not-allowed; + opacity: 0.65; + transform: none; + box-shadow: none; + } } } } diff --git a/src/pages/AISummaryWindow.tsx b/src/pages/AISummaryWindow.tsx index 368a7a6..ed08c53 100644 --- a/src/pages/AISummaryWindow.tsx +++ b/src/pages/AISummaryWindow.tsx @@ -31,6 +31,8 @@ import { type SessionQAHistoryMessage, type SessionQAProgressEvent, type SessionQAResult, + type SessionVectorIndexProgressEvent, + type SessionVectorIndexState, type SummaryEvidenceRef, type SummaryResult, type SummaryStructuredAnalysis @@ -72,6 +74,11 @@ interface EvidenceContextState { error?: string } +interface VectorIndexConfirmState { + question: string + state?: SessionVectorIndexState +} + const RESULT_TABS: Array<{ id: ResultTabId; label: string; icon: LucideIcon }> = [ { id: 'overview', label: '概览', icon: LayoutDashboard }, { id: 'decisions', label: '决策', icon: CheckCircle2 }, @@ -294,6 +301,9 @@ function AISummaryWindow() { const [qaInput, setQaInput] = useState('') const [qaMessages, setQaMessages] = useState([]) const [isAsking, setIsAsking] = useState(false) + const [isVectorIndexing, setIsVectorIndexing] = useState(false) + const [vectorConfirm, setVectorConfirm] = useState(null) + const [vectorProgress, 
setVectorProgress] = useState(null) const [qaError, setQaError] = useState('') const [copiedEvidenceKey, setCopiedEvidenceKey] = useState('') const [evidenceContext, setEvidenceContext] = useState(null) @@ -1247,11 +1257,11 @@ function AISummaryWindow() { } } - const handleAskQuestion = async () => { - const question = qaInput.trim() + const runAskQuestion = async (question: string) => { if (!sessionId || !question || isAsking) return setQaInput('') + setVectorConfirm(null) setQaError('') setIsAsking(true) @@ -1354,6 +1364,93 @@ function AISummaryWindow() { } } + const handleAskQuestion = async () => { + const question = qaInput.trim() + if (!sessionId || !question || isAsking || isVectorIndexing) return + + setQaError('') + + try { + const stateResult = await window.electronAPI.ai.getSessionVectorIndexState(sessionId) + const state = stateResult.result + if (stateResult.success && state?.isVectorComplete) { + await runAskQuestion(question) + return + } + + setVectorProgress(null) + setVectorConfirm({ + question, + state + }) + } catch { + await runAskQuestion(question) + } + } + + const handleSkipVectorIndex = async () => { + const question = vectorConfirm?.question + setVectorConfirm(null) + setVectorProgress(null) + if (question) { + await runAskQuestion(question) + } + } + + const handlePrepareVectorIndex = async () => { + const question = vectorConfirm?.question + if (!sessionId || !question || isVectorIndexing) return + + let cleanupProgress: (() => void) | undefined + setQaError('') + setIsVectorIndexing(true) + setVectorProgress({ + sessionId, + stage: 'preparing', + status: 'running', + processedCount: vectorConfirm?.state?.vectorizedCount || 0, + totalCount: vectorConfirm?.state?.indexedCount || 0, + message: '正在准备本地向量索引', + vectorModel: vectorConfirm?.state?.vectorModel || 'local-chargram-hash-v1' + }) + + try { + cleanupProgress = window.electronAPI.ai.onSessionVectorIndexProgress((event) => { + if (event.sessionId !== sessionId) return + 
setVectorProgress(event) + }) + + const result = await window.electronAPI.ai.prepareSessionVectorIndex({ sessionId }) + if (!result.success || !result.result) { + throw new Error(result.error || '向量索引准备失败') + } + + if (result.result.isVectorComplete) { + setVectorConfirm(null) + setVectorProgress(null) + setIsVectorIndexing(false) + await runAskQuestion(question) + } else { + setQaError('本地向量索引未完成,已取消本次准备') + } + } catch (e) { + setQaError(String(e)) + } finally { + cleanupProgress?.() + setIsVectorIndexing(false) + } + } + + const handleCancelVectorIndex = async () => { + if (!sessionId) return + try { + await window.electronAPI.ai.cancelSessionVectorIndex(sessionId) + } catch { + // 取消失败不影响用户继续手动跳过。 + } + setIsVectorIndexing(false) + } + // 删除历史记录 const handleDeleteHistory = (id: number, e: React.MouseEvent) => { e.stopPropagation() @@ -1442,6 +1539,76 @@ function AISummaryWindow() { ) } + const renderVectorIndexDialog = () => { + if (!vectorConfirm) return null + + const totalCount = Math.max( + vectorProgress?.totalCount || 0, + vectorConfirm.state?.indexedCount || 0 + ) + const processedCount = Math.max( + vectorProgress?.processedCount || 0, + vectorConfirm.state?.vectorizedCount || 0 + ) + const progressPercent = totalCount > 0 + ? Math.min(100, Math.round((processedCount / totalCount) * 100)) + : 0 + const pendingCount = Math.max(0, totalCount - processedCount) + + return ( +
{ + if (!isVectorIndexing) { + setVectorConfirm(null) + setVectorProgress(null) + } + }} + > +
e.stopPropagation()}> +
+

准备本地向量索引

+
+
+
+ +

当前会话尚未完成本地向量化。建立后,本次和后续问 AI 会优先使用本地相似度检索,新消息会自动增量处理。

+
+ +
+ 已向量化 {processedCount} 条 + 待处理 {pendingCount} 条 +
+ + {isVectorIndexing && ( +
+
+
+
+

{vectorProgress?.message || `正在处理 ${processedCount}/${totalCount}`}

+
+ )} +
+
+ + +
+
+
+ ) + } + const renderAskPanel = () => (
@@ -1503,7 +1670,7 @@ function AISummaryWindow() { value={qaInput} placeholder="追问当前会话..." rows={2} - disabled={isAsking} + disabled={isAsking || isVectorIndexing} onChange={(event) => setQaInput(event.target.value)} onKeyDown={(event) => { if (event.key === 'Enter' && !event.shiftKey) { @@ -1516,10 +1683,10 @@ function AISummaryWindow() { className="qa-send-btn" type="button" onClick={handleAskQuestion} - disabled={!qaInput.trim() || isAsking} + disabled={!qaInput.trim() || isAsking || isVectorIndexing} data-tooltip="发送问题" > - {isAsking ? : } + {isAsking || isVectorIndexing ? : }
@@ -1577,6 +1744,11 @@ function AISummaryWindow() {
)} + {isVectorIndexing && ( +
+ +
+ )} {workspaceMode === 'summary' && result && !isGenerating && ( <> @@ -1908,6 +2080,8 @@ function AISummaryWindow() {
)} + + {renderVectorIndexDialog()} ) } diff --git a/src/types/ai.ts b/src/types/ai.ts index 1403b3e..9909a68 100644 --- a/src/types/ai.ts +++ b/src/types/ai.ts @@ -154,6 +154,7 @@ export interface SessionQAToolCall { | 'aggregate_messages' | 'answer' | 'get_session_context' + | 'prepare_vector_index' args: Record summary: string } @@ -173,6 +174,38 @@ export interface SessionQAProgressEvent { createdAt: number } +export interface SessionVectorIndexState { + sessionId: string + indexedCount: number + vectorizedCount: number + pendingCount: number + isVectorComplete: boolean + isVectorRunning: boolean + vectorModel: string +} + +export type SessionVectorIndexProgressStage = + | 'preparing' + | 'indexing_messages' + | 'vectorizing_messages' + | 'completed' + +export type SessionVectorIndexProgressStatus = + | 'running' + | 'completed' + | 'cancelled' + | 'failed' + +export interface SessionVectorIndexProgressEvent { + sessionId: string + stage: SessionVectorIndexProgressStage + status: SessionVectorIndexProgressStatus + processedCount: number + totalCount: number + message: string + vectorModel: string +} + export interface SessionQAResult { sessionId: string question: string diff --git a/src/types/electron.d.ts b/src/types/electron.d.ts index 4573482..2d3171b 100644 --- a/src/types/electron.d.ts +++ b/src/types/electron.d.ts @@ -3,6 +3,8 @@ import type { SessionQAHistoryMessage, SessionQAProgressEvent, SessionQAResult, + SessionVectorIndexProgressEvent, + SessionVectorIndexState, SummaryResult, SummaryStructuredAnalysis } from './ai' @@ -1072,9 +1074,25 @@ export interface ElectronAPI { result?: SessionQAResult error?: string }> + getSessionVectorIndexState: (sessionId: string) => Promise<{ + success: boolean + result?: SessionVectorIndexState + error?: string + }> + prepareSessionVectorIndex: (options: { sessionId: string }) => Promise<{ + success: boolean + result?: SessionVectorIndexState + error?: string + }> + cancelSessionVectorIndex: (sessionId: string) => 
Promise<{ + success: boolean + result?: SessionVectorIndexState + error?: string + }> onSummaryChunk: (callback: (chunk: string) => void) => () => void onSessionQAChunk: (callback: (chunk: string) => void) => () => void onSessionQAProgress: (callback: (event: SessionQAProgressEvent) => void) => () => void + onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => () => void } }