diff --git a/electron/main.ts b/electron/main.ts index 34fe45a..f60e80b 100644 --- a/electron/main.ts +++ b/electron/main.ts @@ -4145,6 +4145,111 @@ function registerIpcHandlers() { return { success: false, error: String(e) } } }) + + ipcMain.handle('ai:getEmbeddingModelProfiles', async () => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + return { + success: true, + result: localEmbeddingModelService.listProfiles(), + currentProfileId: localEmbeddingModelService.getCurrentProfileId() + } + } catch (e) { + console.error('[AI] 获取语义模型列表失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:setEmbeddingModelProfile', async (_, profileId: string) => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + const result = localEmbeddingModelService.setCurrentProfileId(profileId) + return { success: true, result } + } catch (e) { + console.error('[AI] 设置语义模型失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:getEmbeddingDeviceStatus', async () => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + return { + success: true, + result: localEmbeddingModelService.getDeviceStatus() + } + } catch (e) { + console.error('[AI] 获取语义向量计算模式失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:setEmbeddingDevice', async (_, device: string) => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + const result = localEmbeddingModelService.setCurrentDevice(device) + return { + success: true, + result, + status: localEmbeddingModelService.getDeviceStatus() + } + } catch (e) { + console.error('[AI] 设置语义向量计算模式失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:getEmbeddingModelStatus', async (_, profileId?: string) => { + try { + const { 
localEmbeddingModelService } = await import('./services/search/embeddingModelService') + return { + success: true, + result: await localEmbeddingModelService.getModelStatus(profileId) + } + } catch (e) { + console.error('[AI] 获取语义模型状态失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:downloadEmbeddingModel', async (event, profileId?: string) => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + const result = await localEmbeddingModelService.downloadModel(profileId, (progress) => { + event.sender.send('ai:embeddingModelDownloadProgress', progress) + }) + return { success: true, result } + } catch (e) { + console.error('[AI] 下载语义模型失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:clearEmbeddingModel', async (_, profileId?: string) => { + try { + const { localEmbeddingModelService } = await import('./services/search/embeddingModelService') + return { + success: true, + result: await localEmbeddingModelService.clearModel(profileId) + } + } catch (e) { + console.error('[AI] 清理语义模型失败:', e) + return { success: false, error: String(e) } + } + }) + + ipcMain.handle('ai:clearSemanticVectorIndex', async (_, vectorModel?: string) => { + try { + const { chatSearchIndexService } = await import('./services/search/chatSearchIndexService') + return { + success: true, + result: chatSearchIndexService.clearSemanticVectorIndex(vectorModel) + } + } catch (e) { + console.error('[AI] 清理语义向量索引失败:', e) + return { success: false, error: String(e) } + } + }) } // 主窗口引用 diff --git a/electron/preload.ts b/electron/preload.ts index 55ae85c..67559e9 100644 --- a/electron/preload.ts +++ b/electron/preload.ts @@ -20,6 +20,17 @@ type SessionVectorIndexProgressEvent = { vectorModel: string } +type EmbeddingModelDownloadProgress = { + profileId: string + displayName: string + remoteHost?: string + file?: string + loaded?: number + total?: number + percent?: number + 
status?: string +} + function getMcpLaunchConfigSafe(): Promise<{ command: string args: string[] @@ -543,6 +554,14 @@ contextBridge.exposeInMainWorld('electronAPI', { getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId), prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options), cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId), + getEmbeddingModelProfiles: () => ipcRenderer.invoke('ai:getEmbeddingModelProfiles'), + setEmbeddingModelProfile: (profileId: string) => ipcRenderer.invoke('ai:setEmbeddingModelProfile', profileId), + getEmbeddingDeviceStatus: () => ipcRenderer.invoke('ai:getEmbeddingDeviceStatus'), + setEmbeddingDevice: (device: 'cpu' | 'dml') => ipcRenderer.invoke('ai:setEmbeddingDevice', device), + getEmbeddingModelStatus: (profileId?: string) => ipcRenderer.invoke('ai:getEmbeddingModelStatus', profileId), + downloadEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:downloadEmbeddingModel', profileId), + clearEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:clearEmbeddingModel', profileId), + clearSemanticVectorIndex: (vectorModel?: string) => ipcRenderer.invoke('ai:clearSemanticVectorIndex', vectorModel), onSummaryChunk: (callback: (chunk: string) => void) => { ipcRenderer.on('ai:summaryChunk', (_, chunk) => callback(chunk)) return () => ipcRenderer.removeAllListeners('ai:summaryChunk') @@ -558,6 +577,10 @@ contextBridge.exposeInMainWorld('electronAPI', { onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => { ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event)) return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress') + }, + onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => { + 
ipcRenderer.on('ai:embeddingModelDownloadProgress', (_, event) => callback(event)) + return () => ipcRenderer.removeAllListeners('ai:embeddingModelDownloadProgress') } } }) diff --git a/electron/services/config.ts b/electron/services/config.ts index ca199f4..8d736ab 100644 --- a/electron/services/config.ts +++ b/electron/services/config.ts @@ -110,6 +110,8 @@ interface ConfigSchema { aiEnableCache: boolean aiEnableThinking: boolean // 是否显示思考过程 aiMessageLimit: number // 摘要提取的消息条数限制 + aiEmbeddingModelProfile: string + aiEmbeddingDevice: 'cpu' | 'dml' mcpEnabled: boolean mcpExposeMediaPaths: boolean mcpProxyPort: number @@ -172,6 +174,8 @@ const defaults: ConfigSchema = { aiEnableCache: true, aiEnableThinking: true, // 默认显示思考过程 aiMessageLimit: 3000, // 默认3000条,用户可调至5000 + aiEmbeddingModelProfile: 'bge-large-zh-v1.5-int8', + aiEmbeddingDevice: 'cpu', mcpEnabled: false, mcpExposeMediaPaths: true, mcpProxyPort: 5032, diff --git a/electron/services/search/chatSearchIndexService.ts b/electron/services/search/chatSearchIndexService.ts index 1ef514a..b369ee5 100644 --- a/electron/services/search/chatSearchIndexService.ts +++ b/electron/services/search/chatSearchIndexService.ts @@ -3,6 +3,11 @@ import { existsSync, mkdirSync } from 'fs' import { join } from 'path' import { chatService, type Message } from '../chatService' import { ConfigService } from '../config' +import { + float32ArrayToBuffer, + hashEmbeddingContent, + localEmbeddingModelService +} from './embeddingModelService' export type ChatSearchIndexProgressStage = | 'preparing_index' @@ -56,6 +61,7 @@ export interface ChatSearchSessionResult { export type ChatVectorIndexProgressStage = | 'preparing' + | 'downloading_model' | 'indexing_messages' | 'vectorizing_messages' | 'completed' @@ -84,6 +90,9 @@ export interface ChatVectorIndexState { isVectorComplete: boolean isVectorRunning: boolean vectorModel: string + vectorModelName?: string + vectorProviderAvailable?: boolean + vectorProviderError?: string } export 
interface ChatVectorSearchSessionResult { @@ -112,17 +121,7 @@ type MessageIndexRow = { } type MessageVectorRow = MessageIndexRow & { - vector_json: string -} - -type SparseVector = Array<[number, number]> - -interface LocalVectorProvider { - id: string - buildVector(text: string): SparseVector - parseVector(value: string): SparseVector - toWeightMap(vector: SparseVector): Map - dot(queryWeights: Map, vector: SparseVector): number + distance: number } type SessionVectorStateRow = { @@ -141,81 +140,18 @@ type VectorTask = { } const INDEX_DB_NAME = 'chat_search_index.db' -const INDEX_SCHEMA_VERSION = '1' +const INDEX_SCHEMA_VERSION = '3' const INDEX_BATCH_SIZE = 800 const MAX_INDEX_TEXT_CHARS = 8000 const MAX_EXCERPT_RADIUS = 48 const MAX_INDEX_SEARCH_CANDIDATES = 240 -const VECTOR_MODEL_ID = 'local-chargram-hash-v1' -const VECTOR_DIMENSIONS = 2048 -const VECTOR_BATCH_SIZE = 800 -const MAX_VECTOR_TEXT_CHARS = 2400 -const MAX_VECTOR_SCAN_ROWS = 120000 -const VECTOR_MIN_SCORE = 0.055 +const VECTOR_BATCH_SIZE = 32 +const VECTOR_SEARCH_OVERFETCH = 8 +const VECTOR_MIN_SCORE = 0.45 // Vector hits are recall supplements, so keep them below high-confidence keyword hits. 
const VECTOR_SCORE_BASE = 560 const VECTOR_SCORE_SCALE = 420 const VECTOR_INDEX_CANCELLED_ERROR = 'VECTOR_INDEX_CANCELLED' -const VECTOR_STOP_PHRASES = [ - '有没有', - '是不是', - '是否', - '什么', - '哪个', - '哪些', - '什么时候', - '为什么', - '怎么', - '如何', - '帮我', - '看看', - '请问', - '问一下', - '聊天记录', - '聊天', - '消息', - '记录', - '内容', - '这个', - '那个', - '我们', - '他们', - '对方', - '最近', - '说过', - '提到', - '关于', - '一下' -] -const VECTOR_STOP_WORDS = new Set([ - 'the', - 'and', - 'for', - 'with', - 'that', - 'this', - 'what', - 'when', - 'where', - 'which', - 'why', - 'how', - 'have', - 'has', - '是否', - '什么', - '哪个', - '哪些', - '怎么', - '如何', - '我们', - '他们', - '对方', - '消息', - '聊天', - '记录', - '内容' -]) function cursorKey(message: Pick): string { return `${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}` @@ -239,6 +175,15 @@ function compareIndexRowCursorAsc( || Number(a.local_id || 0) - Number(b.local_id || 0) } +function vectorSessionKey(sessionId: string): number { + let hash = 2166136261 + for (let index = 0; index < sessionId.length; index += 1) { + hash ^= sessionId.charCodeAt(index) + hash = Math.imul(hash, 16777619) + } + return hash >>> 0 +} + function normalizeSearchText(value?: string): string { return String(value || '') .toLowerCase() @@ -338,128 +283,6 @@ function buildSearchTokens(value: string): string { return uniqueStrings(tokens).join(' ') } -function normalizeVectorText(value: string): string { - let normalized = normalizeSearchText(value).slice(0, MAX_VECTOR_TEXT_CHARS) - - for (const phrase of VECTOR_STOP_PHRASES) { - normalized = normalized.replace(new RegExp(phrase, 'gi'), ' ') - } - - return normalized.replace(/\s+/g, ' ').trim() -} - -function hashString(value: string): number { - let hash = 2166136261 - for (let index = 0; index < value.length; index += 1) { - hash ^= value.charCodeAt(index) - hash = Math.imul(hash, 16777619) - } - return hash >>> 0 -} - -function addVectorFeature(weights: Map, feature: string, weight: 
number): void { - if (!feature || VECTOR_STOP_WORDS.has(feature)) return - - const hash = hashString(feature) - const dimension = hash % VECTOR_DIMENSIONS - const signedWeight = (hash & 0x80000000) ? -weight : weight - weights.set(dimension, (weights.get(dimension) || 0) + signedWeight) -} - -function addChineseVectorFeatures(weights: Map, segment: string): void { - if (!segment || VECTOR_STOP_WORDS.has(segment)) return - - if (segment.length >= 2 && segment.length <= 12) { - addVectorFeature(weights, `zh:${segment}`, 1.35) - } - - for (let size = 2; size <= 4; size += 1) { - if (segment.length < size) continue - const weight = size === 2 ? 0.9 : size === 3 ? 1.1 : 0.85 - for (let index = 0; index <= segment.length - size; index += 1) { - const gram = segment.slice(index, index + size) - if (VECTOR_STOP_WORDS.has(gram)) continue - addVectorFeature(weights, `c${size}:${gram}`, weight) - } - } -} - -function buildLocalSearchVector(value: string): SparseVector { - const normalized = normalizeVectorText(value) - if (!normalized) return [] - - const weights = new Map() - const latinWords: string[] = [] - - for (const match of normalized.matchAll(/[\u3400-\u9fff]+/g)) { - addChineseVectorFeatures(weights, match[0]) - } - - for (const match of normalized.matchAll(/[a-z0-9_@.\-]{2,}/g)) { - const word = match[0] - if (VECTOR_STOP_WORDS.has(word)) continue - latinWords.push(word) - addVectorFeature(weights, `w:${word}`, 1.2) - } - - for (let index = 0; index < latinWords.length - 1; index += 1) { - addVectorFeature(weights, `wb:${latinWords[index]} ${latinWords[index + 1]}`, 1.35) - } - - let norm = 0 - for (const weight of weights.values()) { - norm += weight * weight - } - - if (norm <= 0) return [] - - const scale = Math.sqrt(norm) - return Array.from(weights.entries()) - .map(([dimension, weight]) => [dimension, Number((weight / scale).toFixed(6))] as [number, number]) - .filter(([, weight]) => Math.abs(weight) > 0.000001) - .sort((a, b) => a[0] - b[0]) -} - -function 
parseSparseVector(value: string): SparseVector { - try { - const parsed = JSON.parse(value) - if (!Array.isArray(parsed)) return [] - - const vector: SparseVector = [] - for (const item of parsed) { - if (!Array.isArray(item) || item.length < 2) continue - const dimension = Number(item[0]) - const weight = Number(item[1]) - if (!Number.isInteger(dimension) || dimension < 0 || !Number.isFinite(weight)) continue - vector.push([dimension, weight]) - } - return vector - } catch { - return [] - } -} - -function dotSparseVector(queryWeights: Map, vector: SparseVector): number { - let score = 0 - for (const [dimension, weight] of vector) { - const queryWeight = queryWeights.get(dimension) - if (queryWeight) score += queryWeight * weight - } - return score -} - -function sparseVectorToMap(vector: SparseVector): Map { - return new Map(vector) -} - -const localVectorProvider: LocalVectorProvider = { - id: VECTOR_MODEL_ID, - buildVector: buildLocalSearchVector, - parseVector: parseSparseVector, - toWeightMap: sparseVectorToMap, - dot: dotSparseVector -} - function createVectorExcerpt(row: Pick, query: string): string { const text = String(row.parsed_content || row.search_text || '') if (!text) return '' @@ -629,6 +452,8 @@ export class ChatSearchIndexService { private db: Database.Database | null = null private dbPath: string | null = null private vectorTasks = new Map() + private sqliteVectorAvailable = false + private sqliteVectorError = '' private getCacheBasePath(): string { const configService = new ConfigService() @@ -662,10 +487,24 @@ export class ChatSearchIndexService { const db = new Database(nextDbPath) this.db = db this.dbPath = nextDbPath + this.loadSqliteVectorExtension(db) this.ensureSchema(db) return db } + private loadSqliteVectorExtension(db: Database.Database): void { + try { + const sqliteVec = require('sqlite-vec') as { load: (db: Database.Database) => void } + sqliteVec.load(db) + this.sqliteVectorAvailable = true + this.sqliteVectorError = '' + } catch 
(error) { + this.sqliteVectorAvailable = false + this.sqliteVectorError = String(error) + console.warn('[ChatSearchIndex] sqlite-vec 加载失败,语义向量检索将降级为关键词检索:', error) + } + } + private ensureSchema(db: Database.Database): void { db.pragma('journal_mode = WAL') db.pragma('synchronous = NORMAL') @@ -720,12 +559,15 @@ export class ChatSearchIndexService { ); CREATE TABLE IF NOT EXISTS message_vector_index ( - message_id INTEGER PRIMARY KEY, + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id INTEGER NOT NULL, session_id TEXT NOT NULL, vector_model TEXT NOT NULL, - vector_json TEXT NOT NULL, - feature_count INTEGER NOT NULL DEFAULT 0, - indexed_at INTEGER NOT NULL + embedding_blob BLOB NOT NULL, + dim INTEGER NOT NULL, + content_hash TEXT NOT NULL, + indexed_at INTEGER NOT NULL, + UNIQUE(message_id, vector_model) ); CREATE TABLE IF NOT EXISTS session_vector_state ( @@ -745,16 +587,32 @@ export class ChatSearchIndexService { ON message_index(session_id, sender_username); CREATE INDEX IF NOT EXISTS idx_message_vector_session_model ON message_vector_index(session_id, vector_model); + CREATE INDEX IF NOT EXISTS idx_message_vector_message_model + ON message_vector_index(message_id, vector_model); CREATE INDEX IF NOT EXISTS idx_session_vector_state_session ON session_vector_state(session_id); `) + if (this.sqliteVectorAvailable) { + const dim = localEmbeddingModelService.getProfile().dim + db.exec(` + CREATE VIRTUAL TABLE IF NOT EXISTS message_embedding_vec USING vec0( + vector_id INTEGER PRIMARY KEY, + session_key INTEGER PARTITION KEY, + session_id TEXT, + vector_model TEXT, + embedding FLOAT[${dim}] + ); + `) + } + db.prepare('INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)').run('schema_version', INDEX_SCHEMA_VERSION) } private resetSchema(db: Database.Database): void { db.exec(` DROP TABLE IF EXISTS message_index_fts; + DROP TABLE IF EXISTS message_embedding_vec; DROP TABLE IF EXISTS message_vector_index; DROP TABLE IF EXISTS session_vector_state; DROP TABLE IF 
EXISTS message_index; @@ -783,26 +641,36 @@ export class ChatSearchIndexService { return Number(row?.count || 0) } + private getCurrentVectorProfile() { + return localEmbeddingModelService.getProfile() + } + + private getCurrentVectorModelId(): string { + return this.getCurrentVectorProfile().id + } + private getVectorizedCount(db: Database.Database, sessionId: string): number { + const vectorModel = this.getCurrentVectorModelId() const row = db.prepare(` SELECT COUNT(*) AS count FROM message_index m JOIN message_vector_index v ON v.message_id = m.id WHERE m.session_id = ? AND v.vector_model = ? - `).get(sessionId, localVectorProvider.id) as { count?: number } + `).get(sessionId, vectorModel) as { count?: number } return Number(row?.count || 0) } private getVectorTaskKey(sessionId: string): string { - return `${sessionId}:${localVectorProvider.id}` + return `${sessionId}:${this.getCurrentVectorModelId()}` } private getVectorStateRow(db: Database.Database, sessionId: string): SessionVectorStateRow | null { + const vectorModel = this.getCurrentVectorModelId() const row = db.prepare(` SELECT * FROM session_vector_state WHERE session_id = ? AND vector_model = ? - `).get(sessionId, localVectorProvider.id) as SessionVectorStateRow | undefined + `).get(sessionId, vectorModel) as SessionVectorStateRow | undefined return row || null } @@ -836,7 +704,7 @@ export class ChatSearchIndexService { last_error = excluded.last_error `).run( input.sessionId, - localVectorProvider.id, + this.getCurrentVectorModelId(), input.confirmedAt ?? null, input.completedAt ?? 
null, now, @@ -847,11 +715,13 @@ export class ChatSearchIndexService { getSessionVectorIndexState(sessionId: string): ChatVectorIndexState { const db = this.getDb() + const profile = this.getCurrentVectorProfile() const indexedCount = this.getIndexedCount(db, sessionId) const vectorizedCount = this.getVectorizedCount(db, sessionId) const isRunning = this.vectorTasks.has(this.getVectorTaskKey(sessionId)) const row = this.getVectorStateRow(db, sessionId) - const isComplete = Number(row?.is_complete || 0) === 1 + const isComplete = this.sqliteVectorAvailable + && Number(row?.is_complete || 0) === 1 && vectorizedCount >= indexedCount return { @@ -861,7 +731,10 @@ export class ChatSearchIndexService { pendingCount: Math.max(0, indexedCount - vectorizedCount), isVectorComplete: isComplete, isVectorRunning: isRunning, - vectorModel: localVectorProvider.id + vectorModel: profile.id, + vectorModelName: profile.displayName, + vectorProviderAvailable: this.sqliteVectorAvailable, + vectorProviderError: this.sqliteVectorError } } @@ -873,56 +746,78 @@ export class ChatSearchIndexService { progress: Omit, onProgress?: (progress: ChatVectorIndexProgress) => void | Promise ): Promise { + const profile = this.getCurrentVectorProfile() await onProgress?.({ ...progress, - vectorModel: localVectorProvider.id + vectorModel: profile.id }) } - private upsertVectorRows( + private async upsertVectorRows( db: Database.Database, rows: Array & { indexed_at?: number }> - ): void { + ): Promise { if (rows.length === 0) return + if (!this.sqliteVectorAvailable) { + throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`) + } + + const profile = this.getCurrentVectorProfile() + const embeddings = await localEmbeddingModelService.embedTexts(rows.map((row) => row.search_text), profile.id) const upsertVector = db.prepare(` INSERT INTO message_vector_index ( message_id, session_id, vector_model, - vector_json, - feature_count, + embedding_blob, + dim, + content_hash, indexed_at - ) 
VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(message_id) DO UPDATE SET + ) VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(message_id, vector_model) DO UPDATE SET session_id = excluded.session_id, - vector_model = excluded.vector_model, - vector_json = excluded.vector_json, - feature_count = excluded.feature_count, + embedding_blob = excluded.embedding_blob, + dim = excluded.dim, + content_hash = excluded.content_hash, indexed_at = excluded.indexed_at `) + const selectVectorId = db.prepare(` + SELECT id FROM message_vector_index + WHERE message_id = ? AND vector_model = ? + `) + const upsertVec = db.prepare(` + INSERT OR REPLACE INTO message_embedding_vec(vector_id, session_key, session_id, vector_model, embedding) + VALUES (?, ?, ?, ?, ?) + `) const run = db.transaction((items: Array & { indexed_at?: number }>) => { const now = Date.now() - for (const row of items) { - const vector = localVectorProvider.buildVector(row.search_text) + for (let index = 0; index < items.length; index += 1) { + const row = items[index] + const vector = embeddings[index] upsertVector.run( row.id, row.session_id, - localVectorProvider.id, - JSON.stringify(vector), + profile.id, + float32ArrayToBuffer(vector), vector.length, + hashEmbeddingContent(row.search_text), row.indexed_at || now ) + const vectorRow = selectVectorId.get(row.id, profile.id) as { id?: number } | undefined + if (vectorRow?.id) { + upsertVec.run(vectorRow.id, vectorSessionKey(row.session_id), row.session_id, profile.id, float32ArrayToBuffer(vector)) + } } }) run(rows) } - private upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: { + private async upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: { vectorize?: boolean - } = {}): void { + } = {}): Promise { if (messages.length === 0) return const upsert = db.prepare(` @@ -978,22 +873,7 @@ export class ChatSearchIndexService { INSERT INTO message_index_fts(rowid, session_id, cursor_key, search_text, 
token_text) VALUES (?, ?, ?, ?, ?) `) - const upsertVector = db.prepare(` - INSERT INTO message_vector_index ( - message_id, - session_id, - vector_model, - vector_json, - feature_count, - indexed_at - ) VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(message_id) DO UPDATE SET - session_id = excluded.session_id, - vector_model = excluded.vector_model, - vector_json = excluded.vector_json, - feature_count = excluded.feature_count, - indexed_at = excluded.indexed_at - `) + const vectorRows: Array & { indexed_at?: number }> = [] const run = db.transaction((items: Message[]) => { const indexedAt = Date.now() @@ -1024,20 +904,29 @@ export class ChatSearchIndexService { deleteFts.run(row.id) insertFts.run(row.id, sessionId, cursorKey(message), searchText, tokenText) if (options.vectorize) { - const vector = localVectorProvider.buildVector(searchText) - upsertVector.run( - row.id, - sessionId, - localVectorProvider.id, - JSON.stringify(vector), - vector.length, - indexedAt - ) + vectorRows.push({ + id: row.id, + session_id: sessionId, + search_text: searchText, + indexed_at: indexedAt + }) } } }) run(messages) + if (vectorRows.length > 0) { + try { + await this.upsertVectorRows(db, vectorRows) + } catch (error) { + this.setSessionVectorState(db, { + sessionId, + completedAt: null, + isComplete: false, + lastError: String(error) + }) + } + } } private updateSessionState(db: Database.Database, sessionId: string, newest: Message | null, isComplete: boolean): ChatSearchIndexState { @@ -1149,7 +1038,7 @@ export class ChatSearchIndexService { const messages = result.messages || [] if (messages.length === 0) break - this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) + await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length newest = messages[messages.length - 1] || newest cursor = { @@ -1181,6 +1070,12 @@ export class ChatSearchIndexService { } db.prepare('DELETE FROM message_index_fts WHERE 
session_id = ?').run(sessionId) + if (this.sqliteVectorAvailable) { + db.prepare(` + DELETE FROM message_embedding_vec + WHERE vector_id IN (SELECT id FROM message_vector_index WHERE session_id = ?) + `).run(sessionId) + } db.prepare('DELETE FROM message_vector_index WHERE session_id = ?').run(sessionId) db.prepare('DELETE FROM message_index WHERE session_id = ?').run(sessionId) db.prepare('DELETE FROM session_index_state WHERE session_id = ?').run(sessionId) @@ -1193,7 +1088,7 @@ export class ChatSearchIndexService { let messages = firstPage.messages || [] let hasMore = Boolean(firstPage.hasMore) if (messages.length > 0) { - this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) + await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length newest = messages[messages.length - 1] await this.report({ @@ -1220,7 +1115,7 @@ export class ChatSearchIndexService { messages = result.messages || [] if (messages.length === 0) break - this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) + await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing }) scanned += messages.length hasMore = Boolean(result.hasMore) @@ -1284,6 +1179,21 @@ export class ChatSearchIndexService { return this.getSessionVectorIndexState(sessionId) } + clearSemanticVectorIndex(vectorModel = this.getCurrentVectorModelId()): { success: boolean; deletedCount: number; vectorModel: string } { + const db = this.getDb() + const row = db.prepare('SELECT COUNT(*) AS count FROM message_vector_index WHERE vector_model = ?').get(vectorModel) as { count?: number } + if (this.sqliteVectorAvailable) { + db.prepare('DELETE FROM message_embedding_vec WHERE vector_model = ?').run(vectorModel) + } + db.prepare('DELETE FROM message_vector_index WHERE vector_model = ?').run(vectorModel) + db.prepare('DELETE FROM session_vector_state WHERE vector_model = ?').run(vectorModel) + 
return { + success: true, + deletedCount: Number(row?.count || 0), + vectorModel + } + } + private async runPrepareSessionVectorIndex( sessionId: string, task: VectorTask, @@ -1301,6 +1211,35 @@ export class ChatSearchIndexService { }, onProgress) try { + if (!this.sqliteVectorAvailable) { + throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`) + } + + const profile = this.getCurrentVectorProfile() + const modelStatus = await localEmbeddingModelService.getModelStatus(profile.id) + if (!modelStatus.exists) { + await this.reportVectorProgress({ + sessionId, + stage: 'downloading_model', + status: 'running', + processedCount: 0, + totalCount: 0, + message: `正在下载本地语义模型:${profile.displayName}` + }, onProgress) + await localEmbeddingModelService.downloadModel(profile.id, async (progress) => { + await this.reportVectorProgress({ + sessionId, + stage: 'downloading_model', + status: 'running', + processedCount: progress.loaded || 0, + totalCount: progress.total || 0, + message: progress.percent !== undefined + ? `正在下载 ${profile.displayName}:${progress.percent}%` + : `正在下载 ${profile.displayName}` + }, onProgress) + }) + } + const searchState = await this.ensureSessionIndexed(sessionId, async (progress) => { if (task.cancelRequested) { throw new Error(VECTOR_INDEX_CANCELLED_ERROR) @@ -1326,7 +1265,7 @@ export class ChatSearchIndexService { status: 'completed', processedCount: currentState.vectorizedCount, totalCount: currentState.indexedCount, - message: `本地向量索引已就绪,共 ${currentState.vectorizedCount} 条消息` + message: `本地语义向量索引已就绪,共 ${currentState.vectorizedCount} 条消息` }, onProgress) return currentState } @@ -1386,11 +1325,11 @@ export class ChatSearchIndexService { WHERE m.session_id = ? AND v.message_id IS NULL ORDER BY m.id ASC LIMIT ? 
- `).all(localVectorProvider.id, sessionId, VECTOR_BATCH_SIZE) as Array & { indexed_at?: number }> + `).all(this.getCurrentVectorModelId(), sessionId, VECTOR_BATCH_SIZE) as Array & { indexed_at?: number }> if (rows.length === 0) break - this.upsertVectorRows(db, rows) + await this.upsertVectorRows(db, rows) currentState = this.getSessionVectorIndexState(sessionId) await this.reportVectorProgress({ sessionId, @@ -1417,7 +1356,7 @@ export class ChatSearchIndexService { status: 'completed', processedCount: currentState.vectorizedCount, totalCount: currentState.indexedCount, - message: `本地向量索引已完成,共 ${currentState.vectorizedCount} 条消息` + message: `本地语义向量索引已完成,共 ${currentState.vectorizedCount} 条消息` }, onProgress) return currentState @@ -1576,23 +1515,26 @@ export class ChatSearchIndexService { const db = this.getDb() const state = await this.ensureSessionIndexed(options.sessionId, options.onProgress) const vectorState = this.getSessionVectorIndexState(options.sessionId) - const queryVector = localVectorProvider.buildVector(options.query) + const profile = this.getCurrentVectorProfile() const vectorizedCount = vectorState.vectorizedCount - if (!vectorState.isVectorComplete || queryVector.length === 0) { + if (!this.sqliteVectorAvailable || !vectorState.isVectorComplete || !normalizeSearchText(options.query)) { return { hits: [], indexedCount: state.indexedCount, vectorizedCount, truncated: false, - model: localVectorProvider.id + model: profile.id } } + const queryVector = await localEmbeddingModelService.embedText(options.query, profile.id) + const queryEmbedding = float32ArrayToBuffer(queryVector) + await this.report({ stage: 'searching_index', sessionId: options.sessionId, - message: `正在进行本地向量检索:${options.query}`, + message: `正在进行本地语义检索:${options.query}`, indexedCount: state.indexedCount }, options.onProgress) @@ -1600,14 +1542,19 @@ export class ChatSearchIndexService { const endTime = toTimestampSeconds(options.endTimeMs) const senderUsername = 
normalizeSearchText(options.senderUsername) const direction = options.direction - const scanLimit = MAX_VECTOR_SCAN_ROWS + const scanLimit = Math.max(options.limit * VECTOR_SEARCH_OVERFETCH, options.limit + 20) const sqlFilters: string[] = [ - 'v.session_id = @sessionId', - 'v.vector_model = @vectorModel' + 'vec.embedding MATCH @queryEmbedding', + 'vec.session_key = @sessionKey', + 'vec.session_id = @sessionId', + 'vec.vector_model = @vectorModel', + 'k = @scanLimit' ] const params: Record = { sessionId: options.sessionId, - vectorModel: localVectorProvider.id, + sessionKey: vectorSessionKey(options.sessionId), + vectorModel: profile.id, + queryEmbedding, scanLimit: scanLimit + 1 } @@ -1628,18 +1575,16 @@ export class ChatSearchIndexService { } const rows = db.prepare(` - SELECT m.*, v.vector_json - FROM message_vector_index v + SELECT m.*, vec.distance + FROM message_embedding_vec vec + JOIN message_vector_index v ON v.id = vec.vector_id JOIN message_index m ON m.id = v.message_id WHERE ${sqlFilters.join(' AND ')} - AND v.feature_count > 0 - ORDER BY m.sort_seq DESC, m.create_time DESC, m.local_id DESC - LIMIT @scanLimit + ORDER BY vec.distance ASC `).all(params) as MessageVectorRow[] - const queryWeights = localVectorProvider.toWeightMap(queryVector) const scored = rows .map((row) => { - const vectorScore = localVectorProvider.dot(queryWeights, localVectorProvider.parseVector(row.vector_json)) + const vectorScore = Math.max(0, Math.min(1, 1 - Number(row.distance || 0))) return { row, vectorScore @@ -1662,7 +1607,7 @@ export class ChatSearchIndexService { indexedCount: state.indexedCount, vectorizedCount, truncated: rows.length > scanLimit, - model: localVectorProvider.id + model: profile.id } } } diff --git a/electron/services/search/embeddingModelService.ts b/electron/services/search/embeddingModelService.ts new file mode 100644 index 0000000..39db932 --- /dev/null +++ b/electron/services/search/embeddingModelService.ts @@ -0,0 +1,596 @@ +import { createHash } 
from 'crypto' +import { existsSync, mkdirSync, readdirSync, rmSync, statSync } from 'fs' +import { dirname, join } from 'path' +import { ConfigService } from '../config' + +export type EmbeddingModelProfileId = + | 'bge-large-zh-v1.5-int8' + | 'bge-large-zh-v1.5-fp32' + | 'bge-m3' + +export type EmbeddingDevice = 'cpu' | 'dml' + +export type EmbeddingDeviceStatus = { + currentDevice: EmbeddingDevice + effectiveDevice: EmbeddingDevice + gpuAvailable: boolean + provider: 'CPU' | 'DirectML' + info: string +} + +export type EmbeddingModelStatus = { + profileId: string + displayName: string + modelId: string + dim: number + dtype: string + sizeLabel: string + enabled: boolean + exists: boolean + modelDir: string + sizeBytes: number +} + +export type EmbeddingDownloadProgress = { + profileId: string + displayName: string + remoteHost?: string + file?: string + loaded?: number + total?: number + percent?: number + status?: string +} + +export type EmbeddingModelProfile = { + id: EmbeddingModelProfileId + displayName: string + description: string + modelId: string + remoteHosts: string[] + remotePathTemplate: string + revision: string + dim: number + maxTokens: number + maxTextChars: number + dtype: 'q8' | 'fp32' + sizeLabel: string + enabled: boolean +} + +const MODELSCOPE_HOST = 'https://www.modelscope.cn/' +const MODELSCOPE_PATH_TEMPLATE = 'models/{model}/resolve/{revision}/' +const MODELSCOPE_REVISION = 'master' + +export const DEFAULT_EMBEDDING_MODEL_PROFILE: EmbeddingModelProfileId = 'bge-large-zh-v1.5-int8' + +const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfile[] = [ + { + id: 'bge-large-zh-v1.5-int8', + displayName: 'BGE Large 中文 · 推荐', + description: '默认档位,1024 维中文语义向量,优先兼顾召回质量和本地 CPU 性能。', + modelId: 'Xenova/bge-large-zh-v1.5', + remoteHosts: [MODELSCOPE_HOST], + remotePathTemplate: MODELSCOPE_PATH_TEMPLATE, + revision: MODELSCOPE_REVISION, + dim: 1024, + maxTokens: 512, + maxTextChars: 480, + dtype: 'q8', + sizeLabel: '约 330 MB', + enabled: true + }, + { + id: 
'bge-large-zh-v1.5-fp32', + displayName: 'BGE Large 中文 · 高质量', + description: '同模型 FP32 推理,精度更完整,下载和内存占用更高。', + modelId: 'Xenova/bge-large-zh-v1.5', + remoteHosts: [MODELSCOPE_HOST], + remotePathTemplate: MODELSCOPE_PATH_TEMPLATE, + revision: MODELSCOPE_REVISION, + dim: 1024, + maxTokens: 512, + maxTextChars: 480, + dtype: 'fp32', + sizeLabel: '约 1.2 GB', + enabled: true + }, + { + id: 'bge-m3', + displayName: 'BGE-M3 · 多语言', + description: '更强的多语言和长文本语义召回,资源占用更高。', + modelId: 'Xenova/bge-m3', + remoteHosts: [MODELSCOPE_HOST], + remotePathTemplate: MODELSCOPE_PATH_TEMPLATE, + revision: MODELSCOPE_REVISION, + dim: 1024, + maxTokens: 8192, + maxTextChars: 2400, + dtype: 'q8', + sizeLabel: '约 600 MB', + enabled: true + } +] + +function safeProfileId(value: unknown): EmbeddingModelProfileId { + const id = String(value || '').trim() as EmbeddingModelProfileId + const profile = EMBEDDING_MODEL_PROFILES.find((item) => item.id === id && item.enabled) + return profile?.id || DEFAULT_EMBEDDING_MODEL_PROFILE +} + +function safeEmbeddingDevice(value: unknown): EmbeddingDevice { + return String(value || '').trim() === 'dml' ? 
'dml' : 'cpu' +} + +function directorySize(dir: string): number { + if (!existsSync(dir)) return 0 + + let total = 0 + for (const entry of readdirSync(dir, { withFileTypes: true })) { + const path = join(dir, entry.name) + if (entry.isDirectory()) { + total += directorySize(path) + } else if (entry.isFile()) { + total += statSync(path).size + } + } + return total +} + +function hasModelFiles(dir: string): boolean { + if (!existsSync(dir)) return false + + let hasOnnx = false + let hasTokenizer = false + const visit = (current: string) => { + for (const entry of readdirSync(current, { withFileTypes: true })) { + const path = join(current, entry.name) + if (entry.isDirectory()) { + visit(path) + continue + } + + if (entry.name.endsWith('.onnx')) hasOnnx = true + if (entry.name === 'tokenizer.json' || entry.name === 'tokenizer_config.json') hasTokenizer = true + } + } + + visit(dir) + return hasOnnx && hasTokenizer +} + +function getElectronAppSafe(): any | null { + try { + const electronModule = require('electron') + const electronApp = electronModule && typeof electronModule === 'object' ? electronModule.app : null + return electronApp?.getPath ? electronApp : null + } catch { + return null + } +} + +function getEffectiveCachePathFromConfig(): string { + const configService = new ConfigService() + try { + const configured = String(configService.get('cachePath' as any) || '').trim() + if (configured) return configured + } finally { + configService.close() + } + + const electronApp = getElectronAppSafe() + if (electronApp?.getPath) { + const documentsPath = electronApp.getPath('documents') + if (process.env.VITE_DEV_SERVER_URL) { + return join(documentsPath, 'CipherTalkData') + } + + const installDir = dirname(electronApp.getPath('exe')) + const isOnCDrive = /^[cC]:/i.test(installDir) || installDir.startsWith('\\\\') + return isOnCDrive ? 
join(documentsPath, 'CipherTalkData') : join(installDir, 'CipherTalkData') + } + + return join(process.cwd(), 'CipherTalkData') +} + +function getDirectMLDllPath(): string | null { + if (process.platform !== 'win32') return null + + try { + const ortEntry = require.resolve('onnxruntime-node') + const arch = process.arch === 'arm64' ? 'arm64' : 'x64' + return join(dirname(ortEntry), '..', 'bin', 'napi-v6', 'win32', arch, 'DirectML.dll') + } catch { + return null + } +} + +function limitEmbeddingText(text: string, maxChars: number): string { + const value = String(text || '') + const limit = Number.isFinite(maxChars) && maxChars > 0 ? Math.floor(maxChars) : 480 + if (value.length <= limit) return value + + const headLength = Math.max(1, Math.floor(limit * 0.75)) + const tailLength = Math.max(1, limit - headLength) + return `${value.slice(0, headLength)}\n${value.slice(-tailLength)}` +} + +function tensorToVectors(output: any, expectedCount: number): Float32Array[] { + const data = output?.data + const dims = Array.isArray(output?.dims) ? output.dims.map((item: unknown) => Number(item)) : [] + if (!data || typeof data.length !== 'number' || dims.length === 0) { + throw new Error('Embedding 模型输出为空') + } + + const dim = Number(dims[dims.length - 1] || 0) + const batch = dims.length >= 2 ? 
Number(dims[0] || expectedCount) : expectedCount + if (!Number.isInteger(dim) || dim <= 0) { + throw new Error('Embedding 模型输出维度无效') + } + + const vectors: Float32Array[] = [] + for (let index = 0; index < batch; index += 1) { + const start = index * dim + const end = start + dim + if (end > data.length) break + vectors.push(Float32Array.from(data.slice(start, end))) + } + + if (vectors.length !== expectedCount) { + throw new Error(`Embedding 输出数量不匹配:${vectors.length}/${expectedCount}`) + } + + return vectors +} + +export function hashEmbeddingContent(value: string): string { + return createHash('sha256').update(value || '').digest('hex') +} + +export function float32ArrayToBuffer(vector: Float32Array): Buffer { + return Buffer.from(vector.buffer.slice(vector.byteOffset, vector.byteOffset + vector.byteLength)) +} + +function meanPoolNormalize(output: any, attentionMask: any, expectedCount: number): Float32Array[] { + const hidden = output?.last_hidden_state || output?.token_embeddings || output?.logits + const hiddenData = hidden?.data + const hiddenDims = Array.isArray(hidden?.dims) ? 
hidden.dims.map((item: unknown) => Number(item)) : [] + const maskData = attentionMask?.data + if (!hiddenData || hiddenDims.length !== 3 || !maskData) { + throw new Error('Embedding 模型输出为空') + } + + const [batchSize, seqLength, dim] = hiddenDims + if (batchSize !== expectedCount || !Number.isInteger(seqLength) || !Number.isInteger(dim) || dim <= 0) { + throw new Error(`Embedding 输出维度无效:${hiddenDims.join('x')}`) + } + + const vectors: Float32Array[] = [] + for (let batch = 0; batch < batchSize; batch += 1) { + const vector = new Float32Array(dim) + let tokenCount = 0 + for (let token = 0; token < seqLength; token += 1) { + const mask = Number(maskData[batch * seqLength + token] || 0) + if (mask <= 0) continue + tokenCount += mask + const offset = (batch * seqLength + token) * dim + for (let index = 0; index < dim; index += 1) { + vector[index] += Number(hiddenData[offset + index]) * mask + } + } + + const divisor = tokenCount > 0 ? tokenCount : 1 + let norm = 0 + for (let index = 0; index < dim; index += 1) { + vector[index] /= divisor + norm += vector[index] * vector[index] + } + norm = Math.sqrt(norm) || 1 + for (let index = 0; index < dim; index += 1) { + vector[index] /= norm + } + vectors.push(vector) + } + + return vectors +} + +export class LocalEmbeddingModelService { + private pipelines = new Map>() + private downloadTasks = new Map>() + private dmlFailureReason: string | null = null + + listProfiles(): EmbeddingModelProfile[] { + return EMBEDDING_MODEL_PROFILES.map((profile) => ({ ...profile })) + } + + getProfile(profileId?: string): EmbeddingModelProfile { + const id = safeProfileId(profileId || this.getCurrentProfileId()) + return EMBEDDING_MODEL_PROFILES.find((profile) => profile.id === id)! 
+ } + + getCurrentProfileId(): EmbeddingModelProfileId { + const configService = new ConfigService() + try { + return safeProfileId(configService.get('aiEmbeddingModelProfile' as any)) + } finally { + configService.close() + } + } + + setCurrentProfileId(profileId: string): EmbeddingModelProfileId { + const id = safeProfileId(profileId) + const configService = new ConfigService() + try { + configService.set('aiEmbeddingModelProfile' as any, id) + return id + } finally { + configService.close() + } + } + + getCurrentDevice(): EmbeddingDevice { + const configService = new ConfigService() + try { + return safeEmbeddingDevice(configService.get('aiEmbeddingDevice' as any)) + } finally { + configService.close() + } + } + + setCurrentDevice(device: string): EmbeddingDevice { + const nextDevice = safeEmbeddingDevice(device) + const configService = new ConfigService() + try { + configService.set('aiEmbeddingDevice' as any, nextDevice) + this.dmlFailureReason = null + this.clearPipelines() + return nextDevice + } finally { + configService.close() + } + } + + getDeviceStatus(): EmbeddingDeviceStatus { + const currentDevice = this.getCurrentDevice() + const directMLDll = getDirectMLDllPath() + const directMLAvailable = process.platform === 'win32' && !!directMLDll && existsSync(directMLDll) + + if (currentDevice === 'dml' && this.dmlFailureReason) { + return { + currentDevice, + effectiveDevice: 'cpu', + gpuAvailable: directMLAvailable, + provider: 'CPU', + info: `DirectML 本次运行失败,已自动回退 CPU:${this.dmlFailureReason}` + } + } + + if (currentDevice === 'dml' && directMLAvailable) { + return { + currentDevice, + effectiveDevice: 'dml', + gpuAvailable: true, + provider: 'DirectML', + info: 'DirectML 组件已就绪,将优先使用 GPU;推理失败时自动回退 CPU' + } + } + + if (currentDevice === 'dml') { + return { + currentDevice, + effectiveDevice: 'cpu', + gpuAvailable: false, + provider: 'CPU', + info: process.platform === 'win32' + ? 
'缺少 DirectML 组件,将使用 CPU' + : '当前系统不支持 DirectML,将使用 CPU' + } + } + + return { + currentDevice, + effectiveDevice: 'cpu', + gpuAvailable: directMLAvailable, + provider: 'CPU', + info: directMLAvailable ? '当前使用 CPU,可切换到 DirectML GPU 实验模式' : '当前使用 CPU' + } + } + + getModelsRoot(): string { + return join(getEffectiveCachePathFromConfig(), 'models', 'embeddings') + } + + getProfileDir(profileId?: string): string { + return join(this.getModelsRoot(), this.getProfile(profileId).id) + } + + async getModelStatus(profileId?: string): Promise { + const profile = this.getProfile(profileId) + const modelDir = this.getProfileDir(profile.id) + const exists = hasModelFiles(modelDir) + return { + profileId: profile.id, + displayName: profile.displayName, + modelId: profile.modelId, + dim: profile.dim, + dtype: profile.dtype, + sizeLabel: profile.sizeLabel, + enabled: profile.enabled, + exists, + modelDir, + sizeBytes: directorySize(modelDir) + } + } + + async downloadModel( + profileId?: string, + onProgress?: (progress: EmbeddingDownloadProgress) => void + ): Promise { + const profile = this.getProfile(profileId) + const existing = this.downloadTasks.get(profile.id) + if (existing) return existing + + const task = (async () => { + mkdirSync(this.getProfileDir(profile.id), { recursive: true }) + await this.downloadPipelineWithFallback(profile, onProgress) + return this.getModelStatus(profile.id) + })() + + this.downloadTasks.set(profile.id, task) + try { + return await task + } finally { + this.downloadTasks.delete(profile.id) + } + } + + async clearModel(profileId?: string): Promise { + const profile = this.getProfile(profileId) + this.clearPipelines(profile.id) + rmSync(this.getProfileDir(profile.id), { recursive: true, force: true }) + return this.getModelStatus(profile.id) + } + + async ensureModelReady(profileId?: string): Promise { + const status = await this.getModelStatus(profileId) + if (!status.exists) { + throw new Error(`本地语义模型未下载:${status.displayName}`) + } + return 
status + } + + async embedTexts(texts: string[], profileId?: string): Promise { + const profile = this.getProfile(profileId) + const cleaned = texts.map((text) => limitEmbeddingText(String(text || ''), profile.maxTextChars)) + await this.ensureModelReady(profile.id) + const deviceStatus = this.getDeviceStatus() + + if (deviceStatus.effectiveDevice === 'dml') { + try { + return await this.runEmbedding(profile, cleaned, 'dml') + } catch (error) { + console.warn('[Embedding] DirectML 推理失败,回退 CPU:', error) + this.dmlFailureReason = String(error instanceof Error ? error.message : error) + this.clearPipelines(profile.id, 'dml') + } + } + + return this.runEmbedding(profile, cleaned, 'cpu') + } + + async embedText(text: string, profileId?: string): Promise { + const [vector] = await this.embedTexts([text], profileId) + return vector + } + + private async runEmbedding( + profile: EmbeddingModelProfile, + texts: string[], + device: EmbeddingDevice + ): Promise { + const runtime = await this.getPipeline(profile, true, device) + const modelInputs = runtime.tokenizer(texts, { + padding: true, + truncation: true, + max_length: profile.maxTokens + }) + const output = await runtime.model(modelInputs) + return meanPoolNormalize(output, modelInputs.attention_mask, texts.length) + } + + private async getPipeline( + profile: EmbeddingModelProfile, + localOnly: boolean, + device: EmbeddingDevice = 'cpu', + remoteHost?: string, + progressCallback?: (event: any) => void + ): Promise<{ tokenizer: any; model: any }> { + const key = `${profile.id}:${device}:${localOnly ? 
'local' : remoteHost || 'remote'}` + const existing = this.pipelines.get(key) + if (existing) return existing + + const promise = (async () => { + const transformers = await import('@huggingface/transformers') + transformers.env.allowLocalModels = true + transformers.env.allowRemoteModels = !localOnly + transformers.env.cacheDir = this.getProfileDir(profile.id) + if (remoteHost) { + transformers.env.remoteHost = remoteHost + transformers.env.remotePathTemplate = profile.remotePathTemplate + } + + const commonOptions = { + cache_dir: this.getProfileDir(profile.id), + local_files_only: localOnly, + revision: profile.revision, + progress_callback: progressCallback + } + const tokenizer = await transformers.AutoTokenizer.from_pretrained(profile.modelId, commonOptions as any) + const model = await transformers.AutoModel.from_pretrained(profile.modelId, { + ...commonOptions, + device, + dtype: profile.dtype + } as any) + return { tokenizer, model } + })() + + this.pipelines.set(key, promise) + try { + return await promise + } catch (error) { + this.pipelines.delete(key) + throw error + } + } + + private clearPipelines(profileId?: string, device?: EmbeddingDevice): void { + for (const key of Array.from(this.pipelines.keys())) { + const matchesProfile = !profileId || key.startsWith(`${profileId}:`) + const matchesDevice = !device || key.includes(`:${device}:`) + if (matchesProfile && matchesDevice) { + this.pipelines.delete(key) + } + } + } + + private async downloadPipelineWithFallback( + profile: EmbeddingModelProfile, + onProgress?: (progress: EmbeddingDownloadProgress) => void + ): Promise { + const errors: string[] = [] + + for (const remoteHost of profile.remoteHosts) { + try { + onProgress?.({ + profileId: profile.id, + displayName: profile.displayName, + remoteHost, + status: 'initiate' + }) + + await this.getPipeline(profile, false, 'cpu', remoteHost, (event) => { + const loaded = Number(event?.loaded || 0) + const total = Number(event?.total || 0) + 
onProgress?.({ + profileId: profile.id, + displayName: profile.displayName, + remoteHost, + file: String(event?.file || event?.name || ''), + loaded: Number.isFinite(loaded) && loaded > 0 ? loaded : undefined, + total: Number.isFinite(total) && total > 0 ? total : undefined, + percent: total > 0 ? Math.min(100, Math.round((loaded / total) * 100)) : undefined, + status: String(event?.status || '') + }) + }) + return + } catch (error) { + errors.push(`${remoteHost}: ${String(error)}`) + } + } + + throw new Error(`语义模型下载失败。已尝试 ModelScope/魔塔社区:${profile.remoteHosts.join('、')}。请检查网络/代理或稍后重试。${errors.length ? ` 原始错误:${errors.join(' | ')}` : ''}`) + } +} + +export const localEmbeddingModelService = new LocalEmbeddingModelService() diff --git a/package-lock.json b/package-lock.json index ebc8fe2..72f1b00 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "dependencies": { "@emotion/react": "^11.14.0", "@emotion/styled": "^11.14.1", + "@huggingface/transformers": "^4.2.0", "@lobehub/fluent-emoji": "^4.1.0", "@lobehub/icons": "^5.3.0", "@lobehub/ui": "^5.6.5", @@ -50,6 +51,7 @@ "react-window": "^2.2.5", "sherpa-onnx-node": "^1.12.23", "silk-wasm": "^3.7.1", + "sqlite-vec": "^0.1.9", "wechat-emojis": "^1.0.2", "xlsx": "^0.18.5", "zod": "^4.1.12", @@ -1024,7 +1026,6 @@ "version": "1.8.1", "resolved": "https://registry.npmmirror.com/@emnapi/runtime/-/runtime-1.8.1.tgz", "integrity": "sha512-mehfKSMWjjNol8659Z8KxEMrdSJDDot5SXMq00dM8BN4o+CLNXQ0xH2V7EchNHV4RmbZLmmPdEaXZc5H2FXmDg==", - "dev": true, "license": "MIT", "optional": true, "dependencies": { @@ -1035,7 +1036,6 @@ "version": "2.8.1", "resolved": "https://registry.npmmirror.com/tslib/-/tslib-2.8.1.tgz", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", - "dev": true, "license": "0BSD", "optional": true }, @@ -1755,6 +1755,34 @@ "hono": "^4" } }, + "node_modules/@huggingface/jinja": { + "version": "0.5.8", + "resolved": 
"https://registry.npmmirror.com/@huggingface/jinja/-/jinja-0.5.8.tgz", + "integrity": "sha512-ZdElB7DPS7QQS8ZnFc5RPPtkg+eN11z8AmIZWAyes6pSbwXqiFB/POVevvm01begdSX1ho9Gxln/F6qlQMsuaA==", + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/@huggingface/tokenizers": { + "version": "0.1.3", + "resolved": "https://registry.npmmirror.com/@huggingface/tokenizers/-/tokenizers-0.1.3.tgz", + "integrity": "sha512-8rF/RRT10u+kn7YuUbUg0OF30K8rjTc78aHpxT+qJ1uWSqxT1MHi8+9ltwYfkFYJzT/oS+qw3JVfHtNMGAdqyA==", + "license": "Apache-2.0" + }, + "node_modules/@huggingface/transformers": { + "version": "4.2.0", + "resolved": "https://registry.npmmirror.com/@huggingface/transformers/-/transformers-4.2.0.tgz", + "integrity": "sha512-8BRCoBMH0XsWaEIamuR0LrJGAfftgHAfb2Vrffy0VKlSAE/MnUJ5/h/zTfEP3fDIft+nk7TqB8xXEyABGitBjQ==", + "license": "Apache-2.0", + "dependencies": { + "@huggingface/jinja": "^0.5.6", + "@huggingface/tokenizers": "^0.1.3", + "onnxruntime-node": "1.24.3", + "onnxruntime-web": "1.26.0-dev.20260416-b7804b056c", + "sharp": "^0.34.5" + } + }, "node_modules/@iconify/types": { "version": "2.0.0", "resolved": "https://registry.npmmirror.com/@iconify/types/-/types-2.0.0.tgz", @@ -1776,7 +1804,6 @@ "version": "1.0.0", "resolved": "https://registry.npmmirror.com/@img/colour/-/colour-1.0.0.tgz", "integrity": "sha512-A5P/LfWGFSl6nsckYtjw9da+19jB8hkJ6ACTGcDfEJ0aE+l2n2El7dsVM7UVHZQ9s2lmYMWlrS21YLy2IR1LUw==", - "dev": true, "license": "MIT", "engines": { "node": ">=18" @@ -1789,7 +1816,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -1812,7 +1838,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -1835,7 +1860,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1852,7 +1876,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1869,7 +1892,6 @@ "cpu": [ "arm" ], - "dev": true, 
"license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1886,7 +1908,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1903,7 +1924,6 @@ "cpu": [ "ppc64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1920,7 +1940,6 @@ "cpu": [ "riscv64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1937,7 +1956,6 @@ "cpu": [ "s390x" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1954,7 +1972,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1971,7 +1988,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -1988,7 +2004,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "LGPL-3.0-or-later", "optional": true, "os": [ @@ -2005,7 +2020,6 @@ "cpu": [ "arm" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2028,7 +2042,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2051,7 +2064,6 @@ "cpu": [ "ppc64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2074,7 +2086,6 @@ "cpu": [ "riscv64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2097,7 +2108,6 @@ "cpu": [ "s390x" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2120,7 +2130,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2143,7 +2152,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2166,7 +2174,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "Apache-2.0", "optional": true, "os": [ @@ -2189,7 +2196,6 @@ "cpu": [ "wasm32" ], - "dev": true, "license": "Apache-2.0 AND LGPL-3.0-or-later AND MIT", "optional": true, "dependencies": { @@ -2209,7 +2215,6 @@ "cpu": [ "arm64" ], - "dev": true, "license": "Apache-2.0 AND LGPL-3.0-or-later", "optional": true, "os": [ 
@@ -2229,7 +2234,6 @@ "cpu": [ "ia32" ], - "dev": true, "license": "Apache-2.0 AND LGPL-3.0-or-later", "optional": true, "os": [ @@ -2249,7 +2253,6 @@ "cpu": [ "x64" ], - "dev": true, "license": "Apache-2.0 AND LGPL-3.0-or-later", "optional": true, "os": [ @@ -3431,6 +3434,70 @@ "object-assign": "^4.1.1" } }, + "node_modules/@protobufjs/aspromise": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", + "integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@protobufjs/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/codegen": { + "version": "2.0.4", + "resolved": "https://registry.npmmirror.com/@protobufjs/codegen/-/codegen-2.0.4.tgz", + "integrity": "sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/eventemitter": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz", + "integrity": "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/fetch": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/@protobufjs/fetch/-/fetch-1.1.0.tgz", + "integrity": "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.1", + "@protobufjs/inquire": "^1.1.0" + } + }, + "node_modules/@protobufjs/float": { + "version": "1.0.2", + "resolved": 
"https://registry.npmmirror.com/@protobufjs/float/-/float-1.0.2.tgz", + "integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/inquire": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/@protobufjs/inquire/-/inquire-1.1.0.tgz", + "integrity": "sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/path": { + "version": "1.1.2", + "resolved": "https://registry.npmmirror.com/@protobufjs/path/-/path-1.1.2.tgz", + "integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/pool": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/@protobufjs/pool/-/pool-1.1.0.tgz", + "integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/utf8": { + "version": "1.1.0", + "resolved": "https://registry.npmmirror.com/@protobufjs/utf8/-/utf8-1.1.0.tgz", + "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", + "license": "BSD-3-Clause" + }, "node_modules/@radix-ui/primitive": { "version": "1.1.3", "resolved": "https://registry.npmmirror.com/@radix-ui/primitive/-/primitive-1.1.3.tgz", @@ -10417,6 +10484,12 @@ "integrity": "sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==", "license": "MIT" }, + "node_modules/flatbuffers": { + "version": "25.9.23", + "resolved": "https://registry.npmmirror.com/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==", + "license": "Apache-2.0" + }, "node_modules/for-in": { 
"version": "1.0.2", "resolved": "https://registry.npmmirror.com/for-in/-/for-in-1.0.2.tgz", @@ -10882,6 +10955,12 @@ "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", "license": "ISC" }, + "node_modules/guid-typescript": { + "version": "1.0.9", + "resolved": "https://registry.npmmirror.com/guid-typescript/-/guid-typescript-1.0.9.tgz", + "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==", + "license": "ISC" + }, "node_modules/hachure-fill": { "version": "0.5.2", "resolved": "https://registry.npmmirror.com/hachure-fill/-/hachure-fill-0.5.2.tgz", @@ -12208,6 +12287,12 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/long": { + "version": "5.3.2", + "resolved": "https://registry.npmmirror.com/long/-/long-5.3.2.tgz", + "integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==", + "license": "Apache-2.0" + }, "node_modules/longest-streak": { "version": "3.1.0", "resolved": "https://registry.npmmirror.com/longest-streak/-/longest-streak-3.1.0.tgz", @@ -14207,15 +14292,15 @@ } }, "node_modules/onnxruntime-common": { - "version": "1.24.1", - "resolved": "https://registry.npmmirror.com/onnxruntime-common/-/onnxruntime-common-1.24.1.tgz", - "integrity": "sha512-UnV15u4p4XxoIV+jFP4hXPsW93s3QrwLSpi20HUDYHoTfI4z4sjzex3L4XDOxGGZJ/M/catrwAG2go958UQq0w==", + "version": "1.24.3", + "resolved": "https://registry.npmmirror.com/onnxruntime-common/-/onnxruntime-common-1.24.3.tgz", + "integrity": "sha512-GeuPZO6U/LBJXvwdaqHbuUmoXiEdeCjWi/EG7Y1HNnDwJYuk6WUbNXpF6luSUY8yASul3cmUlLGrCCL1ZgVXqA==", "license": "MIT" }, "node_modules/onnxruntime-node": { - "version": "1.24.1", - "resolved": "https://registry.npmmirror.com/onnxruntime-node/-/onnxruntime-node-1.24.1.tgz", - "integrity": "sha512-Ex/oUXKdhDoxvlNxBT3oYtW0MH88yYpPlXQeVQUXpcJQmN24usd/8RCoPLN5kCHwDsiZ+nqsnjciyFRl423dQw==", + "version": 
"1.24.3", + "resolved": "https://registry.npmmirror.com/onnxruntime-node/-/onnxruntime-node-1.24.3.tgz", + "integrity": "sha512-JH7+czbc8ALA819vlTgcV+Q214/+VjGeBHDjX81+ZCD0PCVCIFGFNtT0V4sXG/1JXypKPgScQcB3ij/hk3YnTg==", "hasInstallScript": true, "license": "MIT", "os": [ @@ -14226,9 +14311,29 @@ "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^3.0.0", - "onnxruntime-common": "1.24.1" + "onnxruntime-common": "1.24.3" } }, + "node_modules/onnxruntime-web": { + "version": "1.26.0-dev.20260416-b7804b056c", + "resolved": "https://registry.npmmirror.com/onnxruntime-web/-/onnxruntime-web-1.26.0-dev.20260416-b7804b056c.tgz", + "integrity": "sha512-MD6Ss4GSpQBo6zqoJzyT9LRbKYs7x/JVN23FT24EcEvlqF4VuzPOeH6X38orZPKHQDbprn7K+SBpu0/mj2CQiw==", + "license": "MIT", + "dependencies": { + "flatbuffers": "^25.1.24", + "guid-typescript": "^1.0.9", + "long": "^5.2.3", + "onnxruntime-common": "1.24.0-dev.20251116-b39e144322", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" + } + }, + "node_modules/onnxruntime-web/node_modules/onnxruntime-common": { + "version": "1.24.0-dev.20251116-b39e144322", + "resolved": "https://registry.npmmirror.com/onnxruntime-common/-/onnxruntime-common-1.24.0-dev.20251116-b39e144322.tgz", + "integrity": "sha512-BOoomdHYmNRL5r4iQ4bMvsl2t0/hzVQ3OM3PHD0gxeXu1PmggqBv3puZicEUVOA3AtHHYmqZtjMj9FOfGrATTw==", + "license": "MIT" + }, "node_modules/openai": { "version": "4.104.0", "resolved": "https://registry.npmmirror.com/openai/-/openai-4.104.0.tgz", @@ -14578,6 +14683,12 @@ "pathe": "^2.0.1" } }, + "node_modules/platform": { + "version": "1.3.6", + "resolved": "https://registry.npmmirror.com/platform/-/platform-1.3.6.tgz", + "integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==", + "license": "MIT" + }, "node_modules/plist": { "version": "3.1.0", "resolved": "https://registry.npmmirror.com/plist/-/plist-3.1.0.tgz", @@ -14771,6 +14882,30 @@ "url": "https://github.com/sponsors/wooorm" } }, + 
"node_modules/protobufjs": { + "version": "7.5.5", + "resolved": "https://registry.npmmirror.com/protobufjs/-/protobufjs-7.5.5.tgz", + "integrity": "sha512-3wY1AxV+VBNW8Yypfd1yQY9pXnqTAN+KwQxL8iYm3/BjKYMNg4i0owhEe26PWDOMaIrzeeF98Lqd5NGz4omiIg==", + "hasInstallScript": true, + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.2", + "@protobufjs/base64": "^1.1.2", + "@protobufjs/codegen": "^2.0.4", + "@protobufjs/eventemitter": "^1.1.0", + "@protobufjs/fetch": "^1.1.0", + "@protobufjs/float": "^1.0.2", + "@protobufjs/inquire": "^1.1.0", + "@protobufjs/path": "^1.1.2", + "@protobufjs/pool": "^1.1.0", + "@protobufjs/utf8": "^1.1.0", + "@types/node": ">=13.7.0", + "long": "^5.0.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmmirror.com/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -16292,7 +16427,6 @@ "version": "0.34.5", "resolved": "https://registry.npmmirror.com/sharp/-/sharp-0.34.5.tgz", "integrity": "sha512-Ou9I5Ft9WNcCbXrU9cMgPBcCK8LiwLqcbywW3t4oDV37n1pzpuNLsYiAV8eODnjbtQlSDwZ2cUEeQz4E54Hltg==", - "dev": true, "hasInstallScript": true, "license": "Apache-2.0", "dependencies": { @@ -16808,6 +16942,84 @@ "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==", "license": "BSD-3-Clause" }, + "node_modules/sqlite-vec": { + "version": "0.1.9", + "resolved": "https://registry.npmmirror.com/sqlite-vec/-/sqlite-vec-0.1.9.tgz", + "integrity": "sha512-L7XJWRIBNvR9O5+vh1FQ+IGkh/3D2AzVksW5gdtk28m78Hy8skFD0pqReKH1Yp0/BUKRGcffgKvyO/EON5JXpA==", + "license": "MIT OR Apache", + "optionalDependencies": { + "sqlite-vec-darwin-arm64": "0.1.9", + "sqlite-vec-darwin-x64": "0.1.9", + "sqlite-vec-linux-arm64": "0.1.9", + "sqlite-vec-linux-x64": "0.1.9", + "sqlite-vec-windows-x64": "0.1.9" + } + }, + "node_modules/sqlite-vec-darwin-arm64": { + "version": "0.1.9", + "resolved": 
"https://registry.npmmirror.com/sqlite-vec-darwin-arm64/-/sqlite-vec-darwin-arm64-0.1.9.tgz", + "integrity": "sha512-jSsZpE42OfBkGL/ItyJTVCUwl6o6Ka3U5rc4j+UBDIQzC1ulSSKMEhQLthsOnF/MdAf1MuAkYhkdKmmcjaIZQg==", + "cpu": [ + "arm64" + ], + "license": "MIT OR Apache", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/sqlite-vec-darwin-x64": { + "version": "0.1.9", + "resolved": "https://registry.npmmirror.com/sqlite-vec-darwin-x64/-/sqlite-vec-darwin-x64-0.1.9.tgz", + "integrity": "sha512-KDlVyqQT7pnOhU1ymB9gs7dMbSoVmKHitT+k1/xkjarcX8bBqPxWrGlK/R+C5WmWkfvWwyq5FfXfiBYCBs6PlA==", + "cpu": [ + "x64" + ], + "license": "MIT OR Apache", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/sqlite-vec-linux-arm64": { + "version": "0.1.9", + "resolved": "https://registry.npmmirror.com/sqlite-vec-linux-arm64/-/sqlite-vec-linux-arm64-0.1.9.tgz", + "integrity": "sha512-5wXVJ9c9kR4CHm/wVqXb/R+XUHTdpZ4nWbPHlS+gc9qQFVHs92Km4bPnCKX4rtcPMzvNis+SIzMJR1SCEwpuUw==", + "cpu": [ + "arm64" + ], + "license": "MIT OR Apache", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/sqlite-vec-linux-x64": { + "version": "0.1.9", + "resolved": "https://registry.npmmirror.com/sqlite-vec-linux-x64/-/sqlite-vec-linux-x64-0.1.9.tgz", + "integrity": "sha512-w3tCH8xK2finW8fQJ/m8uqKodXUZ9KAuAar2UIhz4BHILfpE0WM/MTGCRfa7RjYbrYim5Luk3guvMOGI7T7JQA==", + "cpu": [ + "x64" + ], + "license": "MIT OR Apache", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/sqlite-vec-windows-x64": { + "version": "0.1.9", + "resolved": "https://registry.npmmirror.com/sqlite-vec-windows-x64/-/sqlite-vec-windows-x64-0.1.9.tgz", + "integrity": "sha512-y3gEIyy/17bq2QFPQOWLE68TYWcRZkBQVA2XLrTPHNTOp55xJi/BBBmOm40tVMDMjtP+Elpk6UBUXdaq+46b0Q==", + "cpu": [ + "x64" + ], + "license": "MIT OR Apache", + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/ssf": { "version": "0.11.2", "resolved": "https://registry.npmmirror.com/ssf/-/ssf-0.11.2.tgz", diff --git a/package.json 
b/package.json index 56e191d..421e03c 100644 --- a/package.json +++ b/package.json @@ -38,6 +38,7 @@ "dependencies": { "@emotion/react": "^11.14.0", "@emotion/styled": "^11.14.1", + "@huggingface/transformers": "^4.2.0", "@lobehub/fluent-emoji": "^4.1.0", "@lobehub/icons": "^5.3.0", "@lobehub/ui": "^5.6.5", @@ -76,6 +77,7 @@ "react-window": "^2.2.5", "sherpa-onnx-node": "^1.12.23", "silk-wasm": "^3.7.1", + "sqlite-vec": "^0.1.9", "wechat-emojis": "^1.0.2", "xlsx": "^0.18.5", "zod": "^4.1.12", @@ -211,12 +213,14 @@ "!node_modules/**/*.lib", "!node_modules/**/build/!(Release)/**/*", "!node_modules/**/deps/**/*", + "node_modules/sqlite-vec*/**/*", "node_modules/koffi/build/**/*" ], "asarUnpack": [ "node_modules/ffmpeg-static/**/*", "node_modules/silk-wasm/**/*", "node_modules/sherpa-onnx-node/**/*", + "node_modules/sqlite-vec*/**/*", "node_modules/koffi/**/*", "dist-electron/workers/**/*", "resources/wedecrypt/*.node", diff --git a/scripts/electron-builder.config.cjs b/scripts/electron-builder.config.cjs index 313f6ee..9862f73 100644 --- a/scripts/electron-builder.config.cjs +++ b/scripts/electron-builder.config.cjs @@ -91,6 +91,7 @@ function getFiles(buildTarget) { [ ...commonFiles, '!node_modules/onnxruntime-node/bin/**/win32/arm64/**/*', + 'node_modules/sqlite-vec*/**/*', 'node_modules/koffi/build/koffi/win32_x64/**/*' ] ) @@ -108,6 +109,9 @@ function getFiles(buildTarget) { '!node_modules/onnxruntime-node/bin/**/linux/**/*', '!node_modules/onnxruntime-node/bin/**/win32/**/*', 'node_modules/onnxruntime-node/bin/**/darwin/**/*', + '!node_modules/sqlite-vec-windows-*/**/*', + '!node_modules/sqlite-vec-linux-*/**/*', + 'node_modules/sqlite-vec-darwin-*/**/*', '!node_modules/sherpa-onnx-win-*/**/*', '!node_modules/sherpa-onnx-linux-*/**/*', 'node_modules/sherpa-onnx-darwin-*/**/*', @@ -128,14 +132,14 @@ function getAsarUnpack(buildTarget) { if (buildTarget === 'win') { return appendUnique( withoutItems(baseAsarUnpack, ['node_modules/koffi/**/*']), - 
['node_modules/koffi/build/koffi/win32_x64/**/*'] + ['node_modules/sqlite-vec*/**/*', 'node_modules/koffi/build/koffi/win32_x64/**/*'] ) } if (buildTarget === 'mac') { return appendUnique( withoutItems(baseAsarUnpack, ['node_modules/koffi/**/*']), - ['node_modules/koffi/build/koffi/darwin_*/**/*'] + ['node_modules/sqlite-vec*/**/*', 'node_modules/koffi/build/koffi/darwin_*/**/*'] ) } diff --git a/src/components/ai/AISummarySettings.scss b/src/components/ai/AISummarySettings.scss index 8bcb67c..016a9de 100644 --- a/src/components/ai/AISummarySettings.scss +++ b/src/components/ai/AISummarySettings.scss @@ -484,6 +484,75 @@ } } + .semantic-vector-subtitle { + margin-top: 1rem; + margin-bottom: 0.5rem; + font-size: 0.95rem; + font-weight: 500; + } + + .semantic-model-grid { + margin-bottom: 16px; + } + + .semantic-device-grid { + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + margin-bottom: 12px; + } + + .semantic-device-status { + padding: 14px 16px; + margin-bottom: 16px; + background: var(--bg-tertiary); + border-radius: 12px; + + .status-indicator { + display: flex; + align-items: center; + gap: 8px; + font-size: 14px; + font-weight: 500; + + &.ready { + color: #22c55e; + } + + &.missing { + color: #f59e0b; + } + } + + p { + margin: 8px 0 0; + color: var(--text-secondary); + font-size: 13px; + line-height: 1.5; + } + } + + .semantic-vector-model-status { + .model-size span { + color: var(--text-tertiary); + } + + .model-path { + margin: 0; + color: var(--text-tertiary); + font-size: 12px; + line-height: 1.5; + word-break: break-all; + } + } + + .semantic-download-progress { + margin-top: 0; + } + + .semantic-vector-actions { + margin-top: 1rem; + margin-bottom: 24px; + } + // 3. 
/**
 * Formats a byte count as a short, human-readable size string.
 *
 * The value is scaled by 1024 through B, KB, MB and GB. GB is the largest
 * unit, so bigger values are still expressed in GB. Plain bytes are printed
 * without decimals; every other unit is printed with one decimal place.
 *
 * @param bytes - Size in bytes. Missing, non-finite, zero or negative input yields `'0 B'`.
 * @returns The formatted size, e.g. `'512 B'`, `'1.5 KB'`, `'320.0 MB'`.
 */
function formatBytes(bytes?: number): string {
  const value = Number(bytes ?? 0)
  // NaN is not `<= 0`, so without the finite check it would be rendered as "NaN B".
  if (!Number.isFinite(value) || value <= 0) return '0 B'

  const units = ['B', 'KB', 'MB', 'GB']
  const lastIndex = units.length - 1
  let size = value
  let unitIndex = 0
  while (size >= 1024 && unitIndex < lastIndex) {
    size /= 1024
    unitIndex += 1
  }

  let text = size.toFixed(unitIndex === 0 ? 0 : 1)
  // Rounding can push a value just below the boundary up to "1024.0";
  // roll over to the next unit so the output reads "1.0 MB", not "1024.0 KB".
  if (Number(text) >= 1024 && unitIndex < lastIndex) {
    unitIndex += 1
    text = (size / 1024).toFixed(1)
  }
  return `${text} ${units[unitIndex]}`
}
/**
 * Switches the compute device used for local embedding generation.
 *
 * The selection is applied optimistically so the selector responds at once,
 * then reconciled with the main process: on success the returned status (or a
 * fresh status query when none is returned) becomes the source of truth; on
 * any failure the status is re-read so the UI falls back to the device that
 * is really in effect.
 *
 * @param device - Requested device, `'cpu'` or `'dml'`.
 */
const handleEmbeddingDeviceChange = async (device: EmbeddingDevice) => {
  setEmbeddingDevice(device)
  try {
    const result = await window.electronAPI.ai.setEmbeddingDevice(device)
    if (!result.success) {
      showMessage(result.error || '语义向量计算模式设置失败', false)
      await loadEmbeddingDeviceStatus()
      return
    }
    if (result.status) {
      setEmbeddingDeviceStatus(result.status)
      setEmbeddingDevice(result.status.currentDevice)
    } else {
      await loadEmbeddingDeviceStatus()
    }
    showMessage(device === 'dml' ? '已启用 DirectML GPU 实验模式' : '已切换到 CPU 计算模式', true)
  } catch (e) {
    // The invoke itself rejected, so no result object exists. Report the error
    // the same way the download/clear handlers do and undo the optimistic
    // selection by re-reading the real device status.
    showMessage(String(e), false)
    await loadEmbeddingDeviceStatus()
  }
}
${result.result.deletedCount} 条语义索引`, true) + } catch (e) { + showMessage(String(e), false) + } + } + const handleStartNewPreset = () => { setEditingPresetId(null) setNewPresetStep('provider') @@ -489,6 +643,9 @@ function AISummarySettings({ const currentProvider = providers.find(p => p.id === provider) || providers[0] const modelOptions = currentProvider?.models.map(m => ({ value: m, label: m })) || [] + const embeddingProgressPercent = embeddingProgress?.percent + ?? (embeddingProgress?.total ? Math.round(((embeddingProgress.loaded || 0) / embeddingProgress.total) * 100) : 0) + const embeddingProfile = embeddingProfiles.find(item => item.id === embeddingProfileId) const timeRangeOptions = [ { value: 1, label: '最近 1 天' }, { value: 3, label: '最近 3 天' }, @@ -704,6 +861,173 @@ function AISummarySettings({ +

本地语义检索

+

+ 使用 ModelScope/魔搭社区的 BGE 模型在本地生成真实语义向量,用于问 AI 检索聊天记录。模型下载到缓存目录,仅需下载一次。 +

+ +

语义模型版本

+
+ {embeddingProfiles.map(profile => ( + + ))} +
+ +

向量计算模式

+
+ + + +
+ + {embeddingDeviceStatus && ( +
+
+ {embeddingDeviceStatus.effectiveDevice === 'dml' ? : } + 当前计算: {embeddingDeviceStatus.provider} +
+

{embeddingDeviceStatus.info}

+
+ )} + +
+ {embeddingStatus ? ( +
+
+ {embeddingStatus.exists ? ( + <> + + 语义模型已就绪 + + ) : ( + <> + + 语义模型未下载 + + )} +
+

+ 模型大小: {formatBytes(embeddingStatus.sizeBytes)} + · {embeddingStatus.dim || embeddingProfile?.dim || 1024} 维 +

+

模型目录: {embeddingStatus.modelDir}

+
+ ) : ( +

正在检查模型状态...

+ )} +
+ + {isDownloadingEmbedding && ( +
+
+
+
+ + {embeddingProgress?.percent !== undefined + ? `${embeddingProgress.percent.toFixed(1)}%` + : (embeddingProgress?.remoteHost ? `连接 ${embeddingProgress.remoteHost}` : '准备中')} + +
+ )} + +
+ {!embeddingStatus?.exists && ( + + )} + {embeddingStatus?.exists && ( + + )} + + +
+ {/* 3. 摘要偏好 */}

摘要详细程度

diff --git a/src/pages/AISummaryWindow.tsx b/src/pages/AISummaryWindow.tsx index ed08c53..109e077 100644 --- a/src/pages/AISummaryWindow.tsx +++ b/src/pages/AISummaryWindow.tsx @@ -1410,8 +1410,8 @@ function AISummaryWindow() { status: 'running', processedCount: vectorConfirm?.state?.vectorizedCount || 0, totalCount: vectorConfirm?.state?.indexedCount || 0, - message: '正在准备本地向量索引', - vectorModel: vectorConfirm?.state?.vectorModel || 'local-chargram-hash-v1' + message: '正在准备本地语义向量索引', + vectorModel: vectorConfirm?.state?.vectorModel || 'bge-large-zh-v1.5-int8' }) try { @@ -1422,7 +1422,7 @@ function AISummaryWindow() { const result = await window.electronAPI.ai.prepareSessionVectorIndex({ sessionId }) if (!result.success || !result.result) { - throw new Error(result.error || '向量索引准备失败') + throw new Error(result.error || '语义向量索引准备失败') } if (result.result.isVectorComplete) { @@ -1431,7 +1431,7 @@ function AISummaryWindow() { setIsVectorIndexing(false) await runAskQuestion(question) } else { - setQaError('本地向量索引未完成,已取消本次准备') + setQaError('本地语义向量索引未完成,已取消本次准备') } } catch (e) { setQaError(String(e)) @@ -1554,6 +1554,7 @@ function AISummaryWindow() { ? Math.min(100, Math.round((processedCount / totalCount) * 100)) : 0 const pendingCount = Math.max(0, totalCount - processedCount) + const isDownloadingModel = vectorProgress?.stage === 'downloading_model' return (
e.stopPropagation()}>
-

准备本地向量索引

+

准备本地语义向量索引

-

当前会话尚未完成本地向量化。建立后,本次和后续问 AI 会优先使用本地相似度检索,新消息会自动增量处理。

+

当前会话尚未完成本地语义向量化。建立后,本次和后续问 AI 会优先使用真实语义检索,新消息会自动增量处理。

- 已向量化 {processedCount} 条 - 待处理 {pendingCount} 条 + {vectorConfirm.state?.vectorModelName || vectorProgress?.vectorModel || 'BGE Small 中文'} + {isDownloadingModel ? ( + 正在下载语义模型 + ) : ( + <> + 已语义向量化 {processedCount} 条 + 待处理 {pendingCount} 条 + + )}
{isVectorIndexing && ( @@ -1601,7 +1609,7 @@ function AISummaryWindow() { onClick={handlePrepareVectorIndex} disabled={isVectorIndexing} > - {isVectorIndexing ? '处理中' : '开始向量化'} + {isVectorIndexing ? '处理中' : '开始语义向量化'}
@@ -1745,7 +1753,7 @@ function AISummaryWindow() {
)} {isVectorIndexing && ( -
+
/**
 * Reads the persisted embedding model profile id.
 *
 * The stored value is validated rather than type-asserted, so a missing,
 * empty or non-string config entry falls back to the default profile.
 *
 * @returns The stored profile id, or `'bge-large-zh-v1.5-int8'` when none is usable.
 */
export async function getAiEmbeddingModelProfile(): Promise<string> {
  const value: unknown = await config.get(CONFIG_KEYS.AI_EMBEDDING_MODEL_PROFILE)
  // Only a non-empty string is a usable profile id.
  return typeof value === 'string' && value ? value : 'bge-large-zh-v1.5-int8'
}
'dml' : 'cpu') +} + // --- MCP 配置 --- export async function getMcpEnabled(): Promise { diff --git a/src/types/ai.ts b/src/types/ai.ts index 9909a68..192a304 100644 --- a/src/types/ai.ts +++ b/src/types/ai.ts @@ -182,10 +182,14 @@ export interface SessionVectorIndexState { isVectorComplete: boolean isVectorRunning: boolean vectorModel: string + vectorModelName?: string + vectorProviderAvailable?: boolean + vectorProviderError?: string } export type SessionVectorIndexProgressStage = | 'preparing' + | 'downloading_model' | 'indexing_messages' | 'vectorizing_messages' | 'completed' @@ -206,6 +210,56 @@ export interface SessionVectorIndexProgressEvent { vectorModel: string } +export interface EmbeddingModelProfile { + id: string + displayName: string + description: string + modelId: string + remoteHosts: string[] + remotePathTemplate: string + revision: string + dim: number + maxTokens: number + maxTextChars: number + dtype: string + sizeLabel: string + enabled: boolean +} + +export interface EmbeddingModelStatus { + profileId: string + displayName: string + modelId: string + dim: number + dtype: string + sizeLabel: string + enabled: boolean + exists: boolean + modelDir: string + sizeBytes: number +} + +export type EmbeddingDevice = 'cpu' | 'dml' + +export interface EmbeddingDeviceStatus { + currentDevice: EmbeddingDevice + effectiveDevice: EmbeddingDevice + gpuAvailable: boolean + provider: 'CPU' | 'DirectML' + info: string +} + +export interface EmbeddingModelDownloadProgress { + profileId: string + displayName: string + remoteHost?: string + file?: string + loaded?: number + total?: number + percent?: number + status?: string +} + export interface SessionQAResult { sessionId: string question: string diff --git a/src/types/electron.d.ts b/src/types/electron.d.ts index 2d3171b..bb45e01 100644 --- a/src/types/electron.d.ts +++ b/src/types/electron.d.ts @@ -1,5 +1,10 @@ import type { ChatSession, Message, Contact, ContactInfo } from './models' import type { + 
EmbeddingDevice, + EmbeddingDeviceStatus, + EmbeddingModelDownloadProgress, + EmbeddingModelProfile, + EmbeddingModelStatus, SessionQAHistoryMessage, SessionQAProgressEvent, SessionQAResult, @@ -1089,10 +1094,53 @@ export interface ElectronAPI { result?: SessionVectorIndexState error?: string }> + getEmbeddingModelProfiles: () => Promise<{ + success: boolean + result?: EmbeddingModelProfile[] + currentProfileId?: string + error?: string + }> + setEmbeddingModelProfile: (profileId: string) => Promise<{ + success: boolean + result?: string + error?: string + }> + getEmbeddingDeviceStatus: () => Promise<{ + success: boolean + result?: EmbeddingDeviceStatus + error?: string + }> + setEmbeddingDevice: (device: EmbeddingDevice) => Promise<{ + success: boolean + result?: EmbeddingDevice + status?: EmbeddingDeviceStatus + error?: string + }> + getEmbeddingModelStatus: (profileId?: string) => Promise<{ + success: boolean + result?: EmbeddingModelStatus + error?: string + }> + downloadEmbeddingModel: (profileId?: string) => Promise<{ + success: boolean + result?: EmbeddingModelStatus + error?: string + }> + clearEmbeddingModel: (profileId?: string) => Promise<{ + success: boolean + result?: EmbeddingModelStatus + error?: string + }> + clearSemanticVectorIndex: (vectorModel?: string) => Promise<{ + success: boolean + result?: { success: boolean; deletedCount: number; vectorModel: string } + error?: string + }> onSummaryChunk: (callback: (chunk: string) => void) => () => void onSessionQAChunk: (callback: (chunk: string) => void) => () => void onSessionQAProgress: (callback: (event: SessionQAProgressEvent) => void) => () => void onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => () => void + onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => () => void } }