mirror of
https://github.com/ILoveBingLu/CipherTalk.git
synced 2026-05-20 19:40:32 +08:00
feat: integrate embedding model functionality and UI enhancements
- Added @huggingface/transformers and sqlite-vec dependencies to package.json.
- Updated electron-builder configuration to include sqlite-vec files.
- Enhanced AISummarySettings component with new styles and layout for embedding model selection and device configuration.
- Implemented embedding model loading, status checking, and download functionality in AISummarySettings.
- Added new types for embedding models and device status in ai.ts and electron.d.ts.
- Updated config service to manage embedding model profile and device settings.
- Modified AISummaryWindow to reflect changes in vector indexing messages and statuses.
This commit is contained in:
@@ -4145,6 +4145,111 @@ function registerIpcHandlers() {
|
||||
return { success: false, error: String(e) }
|
||||
}
|
||||
})
|
||||
|
||||
// IPC: list the available local embedding model profiles together with the id of the
// currently selected one. Replies { success: true, result, currentProfileId } on success
// and { success: false, error } when anything throws.
ipcMain.handle('ai:getEmbeddingModelProfiles', async () => {
  try {
    // Dynamic import: the service module is resolved when the handler runs.
    const embeddingModule = await import('./services/search/embeddingModelService')
    const service = embeddingModule.localEmbeddingModelService
    const result = service.listProfiles()
    const currentProfileId = service.getCurrentProfileId()
    return { success: true, result, currentProfileId }
  } catch (error) {
    console.error('[AI] 获取语义模型列表失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: select the embedding model profile identified by `profileId`.
// Replies { success: true, result } with whatever the service returns, or
// { success: false, error } when the service (or its import) throws.
ipcMain.handle('ai:setEmbeddingModelProfile', async (_, profileId: string) => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    return {
      success: true,
      result: embeddingModule.localEmbeddingModelService.setCurrentProfileId(profileId)
    }
  } catch (error) {
    console.error('[AI] 设置语义模型失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: report the embedding compute-device status as provided by the embedding service.
// Replies { success: true, result } or { success: false, error }.
ipcMain.handle('ai:getEmbeddingDeviceStatus', async () => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    const result = embeddingModule.localEmbeddingModelService.getDeviceStatus()
    return { success: true, result }
  } catch (error) {
    console.error('[AI] 获取语义向量计算模式失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: change the embedding compute device, then echo back the refreshed device status.
// The preload bridge types `device` as 'cpu' | 'dml'; the value is forwarded to the service
// as-is (NOTE(review): confirm the service validates unexpected strings).
// Replies { success: true, result, status } or { success: false, error }.
ipcMain.handle('ai:setEmbeddingDevice', async (_, device: string) => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    const service = embeddingModule.localEmbeddingModelService
    // Apply the change first so the status read below reflects it.
    const result = service.setCurrentDevice(device)
    const status = service.getDeviceStatus()
    return { success: true, result, status }
  } catch (error) {
    console.error('[AI] 设置语义向量计算模式失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: query the status of an embedding model. `profileId` is optional and is passed
// through unchanged, so the service decides what an omitted id means.
// Replies { success: true, result } or { success: false, error }.
ipcMain.handle('ai:getEmbeddingModelStatus', async (_, profileId?: string) => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    const result = await embeddingModule.localEmbeddingModelService.getModelStatus(profileId)
    return { success: true, result }
  } catch (error) {
    console.error('[AI] 获取语义模型状态失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: download an embedding model and stream progress events back to the renderer that
// asked for it on the 'ai:embeddingModelDownloadProgress' channel.
// Replies { success: true, result } once the download resolves, or
// { success: false, error } when it throws.
ipcMain.handle('ai:downloadEmbeddingModel', async (event, profileId?: string) => {
  try {
    const { localEmbeddingModelService } = await import('./services/search/embeddingModelService')
    const result = await localEmbeddingModelService.downloadModel(profileId, (progress) => {
      // The requesting window can be closed while the download is still running.
      // Calling send() on a destroyed webContents throws, which would surface inside
      // the download's progress callback; skip the notification instead.
      if (event.sender.isDestroyed()) return
      event.sender.send('ai:embeddingModelDownloadProgress', progress)
    })
    return { success: true, result }
  } catch (e) {
    console.error('[AI] 下载语义模型失败:', e)
    return { success: false, error: String(e) }
  }
})
|
||||
|
||||
// IPC: ask the embedding service to clear a model. `profileId` is optional and is
// forwarded unchanged.
// Replies { success: true, result } or { success: false, error }.
ipcMain.handle('ai:clearEmbeddingModel', async (_, profileId?: string) => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    const result = await embeddingModule.localEmbeddingModelService.clearModel(profileId)
    return { success: true, result }
  } catch (error) {
    console.error('[AI] 清理语义模型失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
|
||||
// IPC: clear the semantic vector index via the chat search index service. `vectorModel`
// is optional and forwarded unchanged.
// Replies { success: true, result } or { success: false, error }.
ipcMain.handle('ai:clearSemanticVectorIndex', async (_, vectorModel?: string) => {
  try {
    const searchIndexModule = await import('./services/search/chatSearchIndexService')
    const result = searchIndexModule.chatSearchIndexService.clearSemanticVectorIndex(vectorModel)
    return { success: true, result }
  } catch (error) {
    console.error('[AI] 清理语义向量索引失败:', error)
    return { success: false, error: String(error) }
  }
})
|
||||
}
|
||||
|
||||
// 主窗口引用
|
||||
|
||||
@@ -20,6 +20,17 @@ type SessionVectorIndexProgressEvent = {
|
||||
vectorModel: string
|
||||
}
|
||||
|
||||
/**
 * Payload delivered on the 'ai:embeddingModelDownloadProgress' channel while an
 * embedding model is being downloaded. Only the profile identity is mandatory;
 * every progress detail is optional.
 */
type EmbeddingModelDownloadProgress = {
  /** Id of the embedding model profile being downloaded. */
  profileId: string
  /** Human-readable name of that profile. */
  displayName: string
  /** Host the download is served from (presumably a mirror/hub host — confirm with the service). */
  remoteHost?: string
  /** File the progress refers to, when reported. */
  file?: string
  /** Amount transferred so far (units not visible here — presumably bytes). */
  loaded?: number
  /** Expected total, in the same unit as `loaded`. */
  total?: number
  /** Completion percentage, when the downloader reports one. */
  percent?: number
  /** Free-form status label from the downloader. */
  status?: string
}
|
||||
|
||||
function getMcpLaunchConfigSafe(): Promise<{
|
||||
command: string
|
||||
args: string[]
|
||||
@@ -543,6 +554,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||
getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId),
|
||||
prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options),
|
||||
cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId),
|
||||
getEmbeddingModelProfiles: () => ipcRenderer.invoke('ai:getEmbeddingModelProfiles'),
|
||||
setEmbeddingModelProfile: (profileId: string) => ipcRenderer.invoke('ai:setEmbeddingModelProfile', profileId),
|
||||
getEmbeddingDeviceStatus: () => ipcRenderer.invoke('ai:getEmbeddingDeviceStatus'),
|
||||
setEmbeddingDevice: (device: 'cpu' | 'dml') => ipcRenderer.invoke('ai:setEmbeddingDevice', device),
|
||||
getEmbeddingModelStatus: (profileId?: string) => ipcRenderer.invoke('ai:getEmbeddingModelStatus', profileId),
|
||||
downloadEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:downloadEmbeddingModel', profileId),
|
||||
clearEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:clearEmbeddingModel', profileId),
|
||||
clearSemanticVectorIndex: (vectorModel?: string) => ipcRenderer.invoke('ai:clearSemanticVectorIndex', vectorModel),
|
||||
onSummaryChunk: (callback: (chunk: string) => void) => {
|
||||
ipcRenderer.on('ai:summaryChunk', (_, chunk) => callback(chunk))
|
||||
return () => ipcRenderer.removeAllListeners('ai:summaryChunk')
|
||||
@@ -558,6 +577,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
|
||||
onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => {
|
||||
ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event))
|
||||
return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress')
|
||||
},
|
||||
// Subscribe to embedding-model download progress events. Returns an unsubscribe function.
// The unsubscribe removes only the listener registered by this call: using
// removeAllListeners on the channel would also detach every other subscriber
// (for example a second window or component watching the same download).
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => {
  const listener = (_: unknown, event: EmbeddingModelDownloadProgress) => callback(event)
  ipcRenderer.on('ai:embeddingModelDownloadProgress', listener)
  return () => {
    ipcRenderer.removeListener('ai:embeddingModelDownloadProgress', listener)
  }
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -110,6 +110,8 @@ interface ConfigSchema {
|
||||
aiEnableCache: boolean
|
||||
aiEnableThinking: boolean // 是否显示思考过程
|
||||
aiMessageLimit: number // 摘要提取的消息条数限制
|
||||
aiEmbeddingModelProfile: string
|
||||
aiEmbeddingDevice: 'cpu' | 'dml'
|
||||
mcpEnabled: boolean
|
||||
mcpExposeMediaPaths: boolean
|
||||
mcpProxyPort: number
|
||||
@@ -172,6 +174,8 @@ const defaults: ConfigSchema = {
|
||||
aiEnableCache: true,
|
||||
aiEnableThinking: true, // 默认显示思考过程
|
||||
aiMessageLimit: 3000, // 默认3000条,用户可调至5000
|
||||
aiEmbeddingModelProfile: 'bge-large-zh-v1.5-int8',
|
||||
aiEmbeddingDevice: 'cpu',
|
||||
mcpEnabled: false,
|
||||
mcpExposeMediaPaths: true,
|
||||
mcpProxyPort: 5032,
|
||||
|
||||
@@ -3,6 +3,11 @@ import { existsSync, mkdirSync } from 'fs'
|
||||
import { join } from 'path'
|
||||
import { chatService, type Message } from '../chatService'
|
||||
import { ConfigService } from '../config'
|
||||
import {
|
||||
float32ArrayToBuffer,
|
||||
hashEmbeddingContent,
|
||||
localEmbeddingModelService
|
||||
} from './embeddingModelService'
|
||||
|
||||
export type ChatSearchIndexProgressStage =
|
||||
| 'preparing_index'
|
||||
@@ -56,6 +61,7 @@ export interface ChatSearchSessionResult {
|
||||
|
||||
export type ChatVectorIndexProgressStage =
|
||||
| 'preparing'
|
||||
| 'downloading_model'
|
||||
| 'indexing_messages'
|
||||
| 'vectorizing_messages'
|
||||
| 'completed'
|
||||
@@ -84,6 +90,9 @@ export interface ChatVectorIndexState {
|
||||
isVectorComplete: boolean
|
||||
isVectorRunning: boolean
|
||||
vectorModel: string
|
||||
vectorModelName?: string
|
||||
vectorProviderAvailable?: boolean
|
||||
vectorProviderError?: string
|
||||
}
|
||||
|
||||
export interface ChatVectorSearchSessionResult {
|
||||
@@ -112,17 +121,7 @@ type MessageIndexRow = {
|
||||
}
|
||||
|
||||
type MessageVectorRow = MessageIndexRow & {
|
||||
vector_json: string
|
||||
}
|
||||
|
||||
type SparseVector = Array<[number, number]>
|
||||
|
||||
interface LocalVectorProvider {
|
||||
id: string
|
||||
buildVector(text: string): SparseVector
|
||||
parseVector(value: string): SparseVector
|
||||
toWeightMap(vector: SparseVector): Map<number, number>
|
||||
dot(queryWeights: Map<number, number>, vector: SparseVector): number
|
||||
distance: number
|
||||
}
|
||||
|
||||
type SessionVectorStateRow = {
|
||||
@@ -141,81 +140,18 @@ type VectorTask = {
|
||||
}
|
||||
|
||||
const INDEX_DB_NAME = 'chat_search_index.db'
|
||||
const INDEX_SCHEMA_VERSION = '1'
|
||||
const INDEX_SCHEMA_VERSION = '3'
|
||||
const INDEX_BATCH_SIZE = 800
|
||||
const MAX_INDEX_TEXT_CHARS = 8000
|
||||
const MAX_EXCERPT_RADIUS = 48
|
||||
const MAX_INDEX_SEARCH_CANDIDATES = 240
|
||||
const VECTOR_MODEL_ID = 'local-chargram-hash-v1'
|
||||
const VECTOR_DIMENSIONS = 2048
|
||||
const VECTOR_BATCH_SIZE = 800
|
||||
const MAX_VECTOR_TEXT_CHARS = 2400
|
||||
const MAX_VECTOR_SCAN_ROWS = 120000
|
||||
const VECTOR_MIN_SCORE = 0.055
|
||||
const VECTOR_BATCH_SIZE = 32
|
||||
const VECTOR_SEARCH_OVERFETCH = 8
|
||||
const VECTOR_MIN_SCORE = 0.45
|
||||
// Vector hits are recall supplements, so keep them below high-confidence keyword hits.
|
||||
const VECTOR_SCORE_BASE = 560
|
||||
const VECTOR_SCORE_SCALE = 420
|
||||
const VECTOR_INDEX_CANCELLED_ERROR = 'VECTOR_INDEX_CANCELLED'
|
||||
const VECTOR_STOP_PHRASES = [
|
||||
'有没有',
|
||||
'是不是',
|
||||
'是否',
|
||||
'什么',
|
||||
'哪个',
|
||||
'哪些',
|
||||
'什么时候',
|
||||
'为什么',
|
||||
'怎么',
|
||||
'如何',
|
||||
'帮我',
|
||||
'看看',
|
||||
'请问',
|
||||
'问一下',
|
||||
'聊天记录',
|
||||
'聊天',
|
||||
'消息',
|
||||
'记录',
|
||||
'内容',
|
||||
'这个',
|
||||
'那个',
|
||||
'我们',
|
||||
'他们',
|
||||
'对方',
|
||||
'最近',
|
||||
'说过',
|
||||
'提到',
|
||||
'关于',
|
||||
'一下'
|
||||
]
|
||||
const VECTOR_STOP_WORDS = new Set([
|
||||
'the',
|
||||
'and',
|
||||
'for',
|
||||
'with',
|
||||
'that',
|
||||
'this',
|
||||
'what',
|
||||
'when',
|
||||
'where',
|
||||
'which',
|
||||
'why',
|
||||
'how',
|
||||
'have',
|
||||
'has',
|
||||
'是否',
|
||||
'什么',
|
||||
'哪个',
|
||||
'哪些',
|
||||
'怎么',
|
||||
'如何',
|
||||
'我们',
|
||||
'他们',
|
||||
'对方',
|
||||
'消息',
|
||||
'聊天',
|
||||
'记录',
|
||||
'内容'
|
||||
])
|
||||
|
||||
function cursorKey(message: Pick<Message, 'localId' | 'createTime' | 'sortSeq'>): string {
|
||||
return `${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
|
||||
@@ -239,6 +175,15 @@ function compareIndexRowCursorAsc(
|
||||
|| Number(a.local_id || 0) - Number(b.local_id || 0)
|
||||
}
|
||||
|
||||
/**
 * Maps a session id to an unsigned 32-bit integer.
 *
 * The mixing step is XOR-then-multiply with the 32-bit FNV-1a offset basis and prime,
 * applied to the string's UTF-16 code units. The value is deterministic for a given id;
 * distinct ids can collide, so it is a bucketing key rather than a unique identifier.
 *
 * @param sessionId Session identifier to hash.
 * @returns Integer in the range [0, 2^32 - 1].
 */
function vectorSessionKey(sessionId: string): number {
  const OFFSET_BASIS = 2166136261
  const PRIME = 16777619

  let state = OFFSET_BASIS
  let position = 0
  while (position < sessionId.length) {
    // Math.imul keeps the multiplication in 32-bit integer space.
    state = Math.imul(state ^ sessionId.charCodeAt(position), PRIME)
    position += 1
  }

  // Reinterpret the signed 32-bit result as unsigned.
  return state >>> 0
}
|
||||
|
||||
function normalizeSearchText(value?: string): string {
|
||||
return String(value || '')
|
||||
.toLowerCase()
|
||||
@@ -338,128 +283,6 @@ function buildSearchTokens(value: string): string {
|
||||
return uniqueStrings(tokens).join(' ')
|
||||
}
|
||||
|
||||
function normalizeVectorText(value: string): string {
|
||||
let normalized = normalizeSearchText(value).slice(0, MAX_VECTOR_TEXT_CHARS)
|
||||
|
||||
for (const phrase of VECTOR_STOP_PHRASES) {
|
||||
normalized = normalized.replace(new RegExp(phrase, 'gi'), ' ')
|
||||
}
|
||||
|
||||
return normalized.replace(/\s+/g, ' ').trim()
|
||||
}
|
||||
|
||||
function hashString(value: string): number {
|
||||
let hash = 2166136261
|
||||
for (let index = 0; index < value.length; index += 1) {
|
||||
hash ^= value.charCodeAt(index)
|
||||
hash = Math.imul(hash, 16777619)
|
||||
}
|
||||
return hash >>> 0
|
||||
}
|
||||
|
||||
function addVectorFeature(weights: Map<number, number>, feature: string, weight: number): void {
|
||||
if (!feature || VECTOR_STOP_WORDS.has(feature)) return
|
||||
|
||||
const hash = hashString(feature)
|
||||
const dimension = hash % VECTOR_DIMENSIONS
|
||||
const signedWeight = (hash & 0x80000000) ? -weight : weight
|
||||
weights.set(dimension, (weights.get(dimension) || 0) + signedWeight)
|
||||
}
|
||||
|
||||
function addChineseVectorFeatures(weights: Map<number, number>, segment: string): void {
|
||||
if (!segment || VECTOR_STOP_WORDS.has(segment)) return
|
||||
|
||||
if (segment.length >= 2 && segment.length <= 12) {
|
||||
addVectorFeature(weights, `zh:${segment}`, 1.35)
|
||||
}
|
||||
|
||||
for (let size = 2; size <= 4; size += 1) {
|
||||
if (segment.length < size) continue
|
||||
const weight = size === 2 ? 0.9 : size === 3 ? 1.1 : 0.85
|
||||
for (let index = 0; index <= segment.length - size; index += 1) {
|
||||
const gram = segment.slice(index, index + size)
|
||||
if (VECTOR_STOP_WORDS.has(gram)) continue
|
||||
addVectorFeature(weights, `c${size}:${gram}`, weight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function buildLocalSearchVector(value: string): SparseVector {
|
||||
const normalized = normalizeVectorText(value)
|
||||
if (!normalized) return []
|
||||
|
||||
const weights = new Map<number, number>()
|
||||
const latinWords: string[] = []
|
||||
|
||||
for (const match of normalized.matchAll(/[\u3400-\u9fff]+/g)) {
|
||||
addChineseVectorFeatures(weights, match[0])
|
||||
}
|
||||
|
||||
for (const match of normalized.matchAll(/[a-z0-9_@.\-]{2,}/g)) {
|
||||
const word = match[0]
|
||||
if (VECTOR_STOP_WORDS.has(word)) continue
|
||||
latinWords.push(word)
|
||||
addVectorFeature(weights, `w:${word}`, 1.2)
|
||||
}
|
||||
|
||||
for (let index = 0; index < latinWords.length - 1; index += 1) {
|
||||
addVectorFeature(weights, `wb:${latinWords[index]} ${latinWords[index + 1]}`, 1.35)
|
||||
}
|
||||
|
||||
let norm = 0
|
||||
for (const weight of weights.values()) {
|
||||
norm += weight * weight
|
||||
}
|
||||
|
||||
if (norm <= 0) return []
|
||||
|
||||
const scale = Math.sqrt(norm)
|
||||
return Array.from(weights.entries())
|
||||
.map(([dimension, weight]) => [dimension, Number((weight / scale).toFixed(6))] as [number, number])
|
||||
.filter(([, weight]) => Math.abs(weight) > 0.000001)
|
||||
.sort((a, b) => a[0] - b[0])
|
||||
}
|
||||
|
||||
function parseSparseVector(value: string): SparseVector {
|
||||
try {
|
||||
const parsed = JSON.parse(value)
|
||||
if (!Array.isArray(parsed)) return []
|
||||
|
||||
const vector: SparseVector = []
|
||||
for (const item of parsed) {
|
||||
if (!Array.isArray(item) || item.length < 2) continue
|
||||
const dimension = Number(item[0])
|
||||
const weight = Number(item[1])
|
||||
if (!Number.isInteger(dimension) || dimension < 0 || !Number.isFinite(weight)) continue
|
||||
vector.push([dimension, weight])
|
||||
}
|
||||
return vector
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
function dotSparseVector(queryWeights: Map<number, number>, vector: SparseVector): number {
|
||||
let score = 0
|
||||
for (const [dimension, weight] of vector) {
|
||||
const queryWeight = queryWeights.get(dimension)
|
||||
if (queryWeight) score += queryWeight * weight
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
function sparseVectorToMap(vector: SparseVector): Map<number, number> {
|
||||
return new Map(vector)
|
||||
}
|
||||
|
||||
const localVectorProvider: LocalVectorProvider = {
|
||||
id: VECTOR_MODEL_ID,
|
||||
buildVector: buildLocalSearchVector,
|
||||
parseVector: parseSparseVector,
|
||||
toWeightMap: sparseVectorToMap,
|
||||
dot: dotSparseVector
|
||||
}
|
||||
|
||||
function createVectorExcerpt(row: Pick<MessageIndexRow, 'parsed_content' | 'search_text'>, query: string): string {
|
||||
const text = String(row.parsed_content || row.search_text || '')
|
||||
if (!text) return ''
|
||||
@@ -629,6 +452,8 @@ export class ChatSearchIndexService {
|
||||
private db: Database.Database | null = null
|
||||
private dbPath: string | null = null
|
||||
private vectorTasks = new Map<string, VectorTask>()
|
||||
private sqliteVectorAvailable = false
|
||||
private sqliteVectorError = ''
|
||||
|
||||
private getCacheBasePath(): string {
|
||||
const configService = new ConfigService()
|
||||
@@ -662,10 +487,24 @@ export class ChatSearchIndexService {
|
||||
const db = new Database(nextDbPath)
|
||||
this.db = db
|
||||
this.dbPath = nextDbPath
|
||||
this.loadSqliteVectorExtension(db)
|
||||
this.ensureSchema(db)
|
||||
return db
|
||||
}
|
||||
|
||||
/**
 * Attempts to load the sqlite-vec extension into `db` and records the outcome in
 * `sqliteVectorAvailable` / `sqliteVectorError`. A failure is logged as a warning and
 * never rethrown, so the database stays usable without vector support.
 *
 * @param db Open database connection to load the extension into.
 */
private loadSqliteVectorExtension(db: Database.Database): void {
  let loaded = false
  let failureText = ''

  try {
    // Resolved at call time so a missing or broken package is caught below.
    const extension = require('sqlite-vec') as { load: (db: Database.Database) => void }
    extension.load(db)
    loaded = true
  } catch (error) {
    failureText = String(error)
    console.warn('[ChatSearchIndex] sqlite-vec 加载失败,语义向量检索将降级为关键词检索:', error)
  }

  this.sqliteVectorAvailable = loaded
  this.sqliteVectorError = failureText
}
|
||||
|
||||
private ensureSchema(db: Database.Database): void {
|
||||
db.pragma('journal_mode = WAL')
|
||||
db.pragma('synchronous = NORMAL')
|
||||
@@ -720,12 +559,15 @@ export class ChatSearchIndexService {
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS message_vector_index (
|
||||
message_id INTEGER PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
vector_model TEXT NOT NULL,
|
||||
vector_json TEXT NOT NULL,
|
||||
feature_count INTEGER NOT NULL DEFAULT 0,
|
||||
indexed_at INTEGER NOT NULL
|
||||
embedding_blob BLOB NOT NULL,
|
||||
dim INTEGER NOT NULL,
|
||||
content_hash TEXT NOT NULL,
|
||||
indexed_at INTEGER NOT NULL,
|
||||
UNIQUE(message_id, vector_model)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS session_vector_state (
|
||||
@@ -745,16 +587,32 @@ export class ChatSearchIndexService {
|
||||
ON message_index(session_id, sender_username);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_vector_session_model
|
||||
ON message_vector_index(session_id, vector_model);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_vector_message_model
|
||||
ON message_vector_index(message_id, vector_model);
|
||||
CREATE INDEX IF NOT EXISTS idx_session_vector_state_session
|
||||
ON session_vector_state(session_id);
|
||||
`)
|
||||
|
||||
if (this.sqliteVectorAvailable) {
|
||||
const dim = localEmbeddingModelService.getProfile().dim
|
||||
db.exec(`
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS message_embedding_vec USING vec0(
|
||||
vector_id INTEGER PRIMARY KEY,
|
||||
session_key INTEGER PARTITION KEY,
|
||||
session_id TEXT,
|
||||
vector_model TEXT,
|
||||
embedding FLOAT[${dim}]
|
||||
);
|
||||
`)
|
||||
}
|
||||
|
||||
db.prepare('INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)').run('schema_version', INDEX_SCHEMA_VERSION)
|
||||
}
|
||||
|
||||
private resetSchema(db: Database.Database): void {
|
||||
db.exec(`
|
||||
DROP TABLE IF EXISTS message_index_fts;
|
||||
DROP TABLE IF EXISTS message_embedding_vec;
|
||||
DROP TABLE IF EXISTS message_vector_index;
|
||||
DROP TABLE IF EXISTS session_vector_state;
|
||||
DROP TABLE IF EXISTS message_index;
|
||||
@@ -783,26 +641,36 @@ export class ChatSearchIndexService {
|
||||
return Number(row?.count || 0)
|
||||
}
|
||||
|
||||
/** Returns the embedding profile reported by `localEmbeddingModelService.getProfile()`. */
private getCurrentVectorProfile() {
  const activeProfile = localEmbeddingModelService.getProfile()
  return activeProfile
}
|
||||
|
||||
/** Returns the `id` of the current embedding profile; used as the `vector_model` value in queries. */
private getCurrentVectorModelId(): string {
  const { id } = this.getCurrentVectorProfile()
  return id
}
|
||||
|
||||
private getVectorizedCount(db: Database.Database, sessionId: string): number {
|
||||
const vectorModel = this.getCurrentVectorModelId()
|
||||
const row = db.prepare(`
|
||||
SELECT COUNT(*) AS count
|
||||
FROM message_index m
|
||||
JOIN message_vector_index v ON v.message_id = m.id
|
||||
WHERE m.session_id = ? AND v.vector_model = ?
|
||||
`).get(sessionId, localVectorProvider.id) as { count?: number }
|
||||
`).get(sessionId, vectorModel) as { count?: number }
|
||||
return Number(row?.count || 0)
|
||||
}
|
||||
|
||||
private getVectorTaskKey(sessionId: string): string {
|
||||
return `${sessionId}:${localVectorProvider.id}`
|
||||
return `${sessionId}:${this.getCurrentVectorModelId()}`
|
||||
}
|
||||
|
||||
private getVectorStateRow(db: Database.Database, sessionId: string): SessionVectorStateRow | null {
|
||||
const vectorModel = this.getCurrentVectorModelId()
|
||||
const row = db.prepare(`
|
||||
SELECT *
|
||||
FROM session_vector_state
|
||||
WHERE session_id = ? AND vector_model = ?
|
||||
`).get(sessionId, localVectorProvider.id) as SessionVectorStateRow | undefined
|
||||
`).get(sessionId, vectorModel) as SessionVectorStateRow | undefined
|
||||
return row || null
|
||||
}
|
||||
|
||||
@@ -836,7 +704,7 @@ export class ChatSearchIndexService {
|
||||
last_error = excluded.last_error
|
||||
`).run(
|
||||
input.sessionId,
|
||||
localVectorProvider.id,
|
||||
this.getCurrentVectorModelId(),
|
||||
input.confirmedAt ?? null,
|
||||
input.completedAt ?? null,
|
||||
now,
|
||||
@@ -847,11 +715,13 @@ export class ChatSearchIndexService {
|
||||
|
||||
getSessionVectorIndexState(sessionId: string): ChatVectorIndexState {
|
||||
const db = this.getDb()
|
||||
const profile = this.getCurrentVectorProfile()
|
||||
const indexedCount = this.getIndexedCount(db, sessionId)
|
||||
const vectorizedCount = this.getVectorizedCount(db, sessionId)
|
||||
const isRunning = this.vectorTasks.has(this.getVectorTaskKey(sessionId))
|
||||
const row = this.getVectorStateRow(db, sessionId)
|
||||
const isComplete = Number(row?.is_complete || 0) === 1
|
||||
const isComplete = this.sqliteVectorAvailable
|
||||
&& Number(row?.is_complete || 0) === 1
|
||||
&& vectorizedCount >= indexedCount
|
||||
|
||||
return {
|
||||
@@ -861,7 +731,10 @@ export class ChatSearchIndexService {
|
||||
pendingCount: Math.max(0, indexedCount - vectorizedCount),
|
||||
isVectorComplete: isComplete,
|
||||
isVectorRunning: isRunning,
|
||||
vectorModel: localVectorProvider.id
|
||||
vectorModel: profile.id,
|
||||
vectorModelName: profile.displayName,
|
||||
vectorProviderAvailable: this.sqliteVectorAvailable,
|
||||
vectorProviderError: this.sqliteVectorError
|
||||
}
|
||||
}
|
||||
|
||||
@@ -873,56 +746,78 @@ export class ChatSearchIndexService {
|
||||
progress: Omit<ChatVectorIndexProgress, 'vectorModel'>,
|
||||
onProgress?: (progress: ChatVectorIndexProgress) => void | Promise<void>
|
||||
): Promise<void> {
|
||||
const profile = this.getCurrentVectorProfile()
|
||||
await onProgress?.({
|
||||
...progress,
|
||||
vectorModel: localVectorProvider.id
|
||||
vectorModel: profile.id
|
||||
})
|
||||
}
|
||||
|
||||
private upsertVectorRows(
|
||||
private async upsertVectorRows(
|
||||
db: Database.Database,
|
||||
rows: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
|
||||
): void {
|
||||
): Promise<void> {
|
||||
if (rows.length === 0) return
|
||||
if (!this.sqliteVectorAvailable) {
|
||||
throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`)
|
||||
}
|
||||
|
||||
const profile = this.getCurrentVectorProfile()
|
||||
const embeddings = await localEmbeddingModelService.embedTexts(rows.map((row) => row.search_text), profile.id)
|
||||
|
||||
const upsertVector = db.prepare(`
|
||||
INSERT INTO message_vector_index (
|
||||
message_id,
|
||||
session_id,
|
||||
vector_model,
|
||||
vector_json,
|
||||
feature_count,
|
||||
embedding_blob,
|
||||
dim,
|
||||
content_hash,
|
||||
indexed_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(message_id) DO UPDATE SET
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(message_id, vector_model) DO UPDATE SET
|
||||
session_id = excluded.session_id,
|
||||
vector_model = excluded.vector_model,
|
||||
vector_json = excluded.vector_json,
|
||||
feature_count = excluded.feature_count,
|
||||
embedding_blob = excluded.embedding_blob,
|
||||
dim = excluded.dim,
|
||||
content_hash = excluded.content_hash,
|
||||
indexed_at = excluded.indexed_at
|
||||
`)
|
||||
const selectVectorId = db.prepare(`
|
||||
SELECT id FROM message_vector_index
|
||||
WHERE message_id = ? AND vector_model = ?
|
||||
`)
|
||||
const upsertVec = db.prepare(`
|
||||
INSERT OR REPLACE INTO message_embedding_vec(vector_id, session_key, session_id, vector_model, embedding)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`)
|
||||
|
||||
const run = db.transaction((items: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>) => {
|
||||
const now = Date.now()
|
||||
for (const row of items) {
|
||||
const vector = localVectorProvider.buildVector(row.search_text)
|
||||
for (let index = 0; index < items.length; index += 1) {
|
||||
const row = items[index]
|
||||
const vector = embeddings[index]
|
||||
upsertVector.run(
|
||||
row.id,
|
||||
row.session_id,
|
||||
localVectorProvider.id,
|
||||
JSON.stringify(vector),
|
||||
profile.id,
|
||||
float32ArrayToBuffer(vector),
|
||||
vector.length,
|
||||
hashEmbeddingContent(row.search_text),
|
||||
row.indexed_at || now
|
||||
)
|
||||
const vectorRow = selectVectorId.get(row.id, profile.id) as { id?: number } | undefined
|
||||
if (vectorRow?.id) {
|
||||
upsertVec.run(vectorRow.id, vectorSessionKey(row.session_id), row.session_id, profile.id, float32ArrayToBuffer(vector))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
run(rows)
|
||||
}
|
||||
|
||||
private upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: {
|
||||
private async upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: {
|
||||
vectorize?: boolean
|
||||
} = {}): void {
|
||||
} = {}): Promise<void> {
|
||||
if (messages.length === 0) return
|
||||
|
||||
const upsert = db.prepare(`
|
||||
@@ -978,22 +873,7 @@ export class ChatSearchIndexService {
|
||||
INSERT INTO message_index_fts(rowid, session_id, cursor_key, search_text, token_text)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`)
|
||||
const upsertVector = db.prepare(`
|
||||
INSERT INTO message_vector_index (
|
||||
message_id,
|
||||
session_id,
|
||||
vector_model,
|
||||
vector_json,
|
||||
feature_count,
|
||||
indexed_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(message_id) DO UPDATE SET
|
||||
session_id = excluded.session_id,
|
||||
vector_model = excluded.vector_model,
|
||||
vector_json = excluded.vector_json,
|
||||
feature_count = excluded.feature_count,
|
||||
indexed_at = excluded.indexed_at
|
||||
`)
|
||||
const vectorRows: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }> = []
|
||||
|
||||
const run = db.transaction((items: Message[]) => {
|
||||
const indexedAt = Date.now()
|
||||
@@ -1024,20 +904,29 @@ export class ChatSearchIndexService {
|
||||
deleteFts.run(row.id)
|
||||
insertFts.run(row.id, sessionId, cursorKey(message), searchText, tokenText)
|
||||
if (options.vectorize) {
|
||||
const vector = localVectorProvider.buildVector(searchText)
|
||||
upsertVector.run(
|
||||
row.id,
|
||||
sessionId,
|
||||
localVectorProvider.id,
|
||||
JSON.stringify(vector),
|
||||
vector.length,
|
||||
indexedAt
|
||||
)
|
||||
vectorRows.push({
|
||||
id: row.id,
|
||||
session_id: sessionId,
|
||||
search_text: searchText,
|
||||
indexed_at: indexedAt
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
run(messages)
|
||||
if (vectorRows.length > 0) {
|
||||
try {
|
||||
await this.upsertVectorRows(db, vectorRows)
|
||||
} catch (error) {
|
||||
this.setSessionVectorState(db, {
|
||||
sessionId,
|
||||
completedAt: null,
|
||||
isComplete: false,
|
||||
lastError: String(error)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private updateSessionState(db: Database.Database, sessionId: string, newest: Message | null, isComplete: boolean): ChatSearchIndexState {
|
||||
@@ -1149,7 +1038,7 @@ export class ChatSearchIndexService {
|
||||
|
||||
const messages = result.messages || []
|
||||
if (messages.length === 0) break
|
||||
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
scanned += messages.length
|
||||
newest = messages[messages.length - 1] || newest
|
||||
cursor = {
|
||||
@@ -1181,6 +1070,12 @@ export class ChatSearchIndexService {
|
||||
}
|
||||
|
||||
db.prepare('DELETE FROM message_index_fts WHERE session_id = ?').run(sessionId)
|
||||
if (this.sqliteVectorAvailable) {
|
||||
db.prepare(`
|
||||
DELETE FROM message_embedding_vec
|
||||
WHERE vector_id IN (SELECT id FROM message_vector_index WHERE session_id = ?)
|
||||
`).run(sessionId)
|
||||
}
|
||||
db.prepare('DELETE FROM message_vector_index WHERE session_id = ?').run(sessionId)
|
||||
db.prepare('DELETE FROM message_index WHERE session_id = ?').run(sessionId)
|
||||
db.prepare('DELETE FROM session_index_state WHERE session_id = ?').run(sessionId)
|
||||
@@ -1193,7 +1088,7 @@ export class ChatSearchIndexService {
|
||||
let messages = firstPage.messages || []
|
||||
let hasMore = Boolean(firstPage.hasMore)
|
||||
if (messages.length > 0) {
|
||||
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
scanned += messages.length
|
||||
newest = messages[messages.length - 1]
|
||||
await this.report({
|
||||
@@ -1220,7 +1115,7 @@ export class ChatSearchIndexService {
|
||||
|
||||
messages = result.messages || []
|
||||
if (messages.length === 0) break
|
||||
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
|
||||
scanned += messages.length
|
||||
hasMore = Boolean(result.hasMore)
|
||||
|
||||
@@ -1284,6 +1179,21 @@ export class ChatSearchIndexService {
|
||||
return this.getSessionVectorIndexState(sessionId)
|
||||
}
|
||||
|
||||
/**
 * Deletes every stored vector and every per-session vector state row for one vector model.
 *
 * The count and all deletes run inside a single transaction, so a failure part-way
 * rolls everything back instead of leaving the vec table, `message_vector_index` and
 * `session_vector_state` out of sync with each other.
 *
 * @param vectorModel Model id to purge; defaults to the currently selected model.
 * @returns `deletedCount` is the number of `message_vector_index` rows that existed for
 *          the model before deletion; `vectorModel` echoes the id that was purged.
 * @throws Whatever the database layer throws; nothing is deleted in that case.
 */
clearSemanticVectorIndex(vectorModel = this.getCurrentVectorModelId()): { success: boolean; deletedCount: number; vectorModel: string } {
  const db = this.getDb()

  const purge = db.transaction((model: string): number => {
    const row = db
      .prepare('SELECT COUNT(*) AS count FROM message_vector_index WHERE vector_model = ?')
      .get(model) as { count?: number } | undefined

    // The vec0 virtual table only exists when sqlite-vec was loaded successfully.
    if (this.sqliteVectorAvailable) {
      db.prepare('DELETE FROM message_embedding_vec WHERE vector_model = ?').run(model)
    }
    db.prepare('DELETE FROM message_vector_index WHERE vector_model = ?').run(model)
    db.prepare('DELETE FROM session_vector_state WHERE vector_model = ?').run(model)

    return Number(row?.count || 0)
  })

  const deletedCount = purge(vectorModel)
  return {
    success: true,
    deletedCount,
    vectorModel
  }
}
|
||||
|
||||
private async runPrepareSessionVectorIndex(
|
||||
sessionId: string,
|
||||
task: VectorTask,
|
||||
@@ -1301,6 +1211,35 @@ export class ChatSearchIndexService {
|
||||
}, onProgress)
|
||||
|
||||
try {
|
||||
if (!this.sqliteVectorAvailable) {
|
||||
throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`)
|
||||
}
|
||||
|
||||
const profile = this.getCurrentVectorProfile()
|
||||
const modelStatus = await localEmbeddingModelService.getModelStatus(profile.id)
|
||||
if (!modelStatus.exists) {
|
||||
await this.reportVectorProgress({
|
||||
sessionId,
|
||||
stage: 'downloading_model',
|
||||
status: 'running',
|
||||
processedCount: 0,
|
||||
totalCount: 0,
|
||||
message: `正在下载本地语义模型:${profile.displayName}`
|
||||
}, onProgress)
|
||||
await localEmbeddingModelService.downloadModel(profile.id, async (progress) => {
|
||||
await this.reportVectorProgress({
|
||||
sessionId,
|
||||
stage: 'downloading_model',
|
||||
status: 'running',
|
||||
processedCount: progress.loaded || 0,
|
||||
totalCount: progress.total || 0,
|
||||
message: progress.percent !== undefined
|
||||
? `正在下载 ${profile.displayName}:${progress.percent}%`
|
||||
: `正在下载 ${profile.displayName}`
|
||||
}, onProgress)
|
||||
})
|
||||
}
|
||||
|
||||
const searchState = await this.ensureSessionIndexed(sessionId, async (progress) => {
|
||||
if (task.cancelRequested) {
|
||||
throw new Error(VECTOR_INDEX_CANCELLED_ERROR)
|
||||
@@ -1326,7 +1265,7 @@ export class ChatSearchIndexService {
|
||||
status: 'completed',
|
||||
processedCount: currentState.vectorizedCount,
|
||||
totalCount: currentState.indexedCount,
|
||||
message: `本地向量索引已就绪,共 ${currentState.vectorizedCount} 条消息`
|
||||
message: `本地语义向量索引已就绪,共 ${currentState.vectorizedCount} 条消息`
|
||||
}, onProgress)
|
||||
return currentState
|
||||
}
|
||||
@@ -1386,11 +1325,11 @@ export class ChatSearchIndexService {
|
||||
WHERE m.session_id = ? AND v.message_id IS NULL
|
||||
ORDER BY m.id ASC
|
||||
LIMIT ?
|
||||
`).all(localVectorProvider.id, sessionId, VECTOR_BATCH_SIZE) as Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
|
||||
`).all(this.getCurrentVectorModelId(), sessionId, VECTOR_BATCH_SIZE) as Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
|
||||
|
||||
if (rows.length === 0) break
|
||||
|
||||
this.upsertVectorRows(db, rows)
|
||||
await this.upsertVectorRows(db, rows)
|
||||
currentState = this.getSessionVectorIndexState(sessionId)
|
||||
await this.reportVectorProgress({
|
||||
sessionId,
|
||||
@@ -1417,7 +1356,7 @@ export class ChatSearchIndexService {
|
||||
status: 'completed',
|
||||
processedCount: currentState.vectorizedCount,
|
||||
totalCount: currentState.indexedCount,
|
||||
message: `本地向量索引已完成,共 ${currentState.vectorizedCount} 条消息`
|
||||
message: `本地语义向量索引已完成,共 ${currentState.vectorizedCount} 条消息`
|
||||
}, onProgress)
|
||||
|
||||
return currentState
|
||||
@@ -1576,23 +1515,26 @@ export class ChatSearchIndexService {
|
||||
const db = this.getDb()
|
||||
const state = await this.ensureSessionIndexed(options.sessionId, options.onProgress)
|
||||
const vectorState = this.getSessionVectorIndexState(options.sessionId)
|
||||
const queryVector = localVectorProvider.buildVector(options.query)
|
||||
const profile = this.getCurrentVectorProfile()
|
||||
const vectorizedCount = vectorState.vectorizedCount
|
||||
|
||||
if (!vectorState.isVectorComplete || queryVector.length === 0) {
|
||||
if (!this.sqliteVectorAvailable || !vectorState.isVectorComplete || !normalizeSearchText(options.query)) {
|
||||
return {
|
||||
hits: [],
|
||||
indexedCount: state.indexedCount,
|
||||
vectorizedCount,
|
||||
truncated: false,
|
||||
model: localVectorProvider.id
|
||||
model: profile.id
|
||||
}
|
||||
}
|
||||
|
||||
const queryVector = await localEmbeddingModelService.embedText(options.query, profile.id)
|
||||
const queryEmbedding = float32ArrayToBuffer(queryVector)
|
||||
|
||||
await this.report({
|
||||
stage: 'searching_index',
|
||||
sessionId: options.sessionId,
|
||||
message: `正在进行本地向量检索:${options.query}`,
|
||||
message: `正在进行本地语义检索:${options.query}`,
|
||||
indexedCount: state.indexedCount
|
||||
}, options.onProgress)
|
||||
|
||||
@@ -1600,14 +1542,19 @@ export class ChatSearchIndexService {
|
||||
const endTime = toTimestampSeconds(options.endTimeMs)
|
||||
const senderUsername = normalizeSearchText(options.senderUsername)
|
||||
const direction = options.direction
|
||||
const scanLimit = MAX_VECTOR_SCAN_ROWS
|
||||
const scanLimit = Math.max(options.limit * VECTOR_SEARCH_OVERFETCH, options.limit + 20)
|
||||
const sqlFilters: string[] = [
|
||||
'v.session_id = @sessionId',
|
||||
'v.vector_model = @vectorModel'
|
||||
'vec.embedding MATCH @queryEmbedding',
|
||||
'vec.session_key = @sessionKey',
|
||||
'vec.session_id = @sessionId',
|
||||
'vec.vector_model = @vectorModel',
|
||||
'k = @scanLimit'
|
||||
]
|
||||
const params: Record<string, unknown> = {
|
||||
sessionId: options.sessionId,
|
||||
vectorModel: localVectorProvider.id,
|
||||
sessionKey: vectorSessionKey(options.sessionId),
|
||||
vectorModel: profile.id,
|
||||
queryEmbedding,
|
||||
scanLimit: scanLimit + 1
|
||||
}
|
||||
|
||||
@@ -1628,18 +1575,16 @@ export class ChatSearchIndexService {
|
||||
}
|
||||
|
||||
const rows = db.prepare(`
|
||||
SELECT m.*, v.vector_json
|
||||
FROM message_vector_index v
|
||||
SELECT m.*, vec.distance
|
||||
FROM message_embedding_vec vec
|
||||
JOIN message_vector_index v ON v.id = vec.vector_id
|
||||
JOIN message_index m ON m.id = v.message_id
|
||||
WHERE ${sqlFilters.join(' AND ')}
|
||||
AND v.feature_count > 0
|
||||
ORDER BY m.sort_seq DESC, m.create_time DESC, m.local_id DESC
|
||||
LIMIT @scanLimit
|
||||
ORDER BY vec.distance ASC
|
||||
`).all(params) as MessageVectorRow[]
|
||||
const queryWeights = localVectorProvider.toWeightMap(queryVector)
|
||||
const scored = rows
|
||||
.map((row) => {
|
||||
const vectorScore = localVectorProvider.dot(queryWeights, localVectorProvider.parseVector(row.vector_json))
|
||||
const vectorScore = Math.max(0, Math.min(1, 1 - Number(row.distance || 0)))
|
||||
return {
|
||||
row,
|
||||
vectorScore
|
||||
@@ -1662,7 +1607,7 @@ export class ChatSearchIndexService {
|
||||
indexedCount: state.indexedCount,
|
||||
vectorizedCount,
|
||||
truncated: rows.length > scanLimit,
|
||||
model: localVectorProvider.id
|
||||
model: profile.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,596 @@
|
||||
import { createHash } from 'crypto'
|
||||
import { existsSync, mkdirSync, readdirSync, rmSync, statSync } from 'fs'
|
||||
import { dirname, join } from 'path'
|
||||
import { ConfigService } from '../config'
|
||||
|
||||
/** Identifier of one of the built-in embedding model profiles. */
export type EmbeddingModelProfileId =
  | 'bge-large-zh-v1.5-int8'
  | 'bge-large-zh-v1.5-fp32'
  | 'bge-m3'

/** Inference device setting: plain CPU, or `dml` for DirectML. */
export type EmbeddingDevice = 'cpu' | 'dml'

/** Snapshot describing which device is configured and which one will really be used. */
export type EmbeddingDeviceStatus = {
  /** Device read from the configuration. */
  currentDevice: EmbeddingDevice
  /** Device inference will run on after availability and failure checks. */
  effectiveDevice: EmbeddingDevice
  /** Whether a usable GPU backend is reported for this machine. */
  gpuAvailable: boolean
  /** Display label of the effective execution provider. */
  provider: 'CPU' | 'DirectML'
  /** Human-readable explanation of the current state. */
  info: string
}

/** On-disk state of one model profile together with its static metadata. */
export type EmbeddingModelStatus = {
  profileId: string
  displayName: string
  modelId: string
  /** Embedding dimension declared by the profile. */
  dim: number
  dtype: string
  sizeLabel: string
  enabled: boolean
  /** True when the model directory contains the files required to load the model. */
  exists: boolean
  /** Directory that holds this profile's files. */
  modelDir: string
  /** Total size in bytes of the files under `modelDir`. */
  sizeBytes: number
}

/** Progress event emitted while a model is being downloaded. */
export type EmbeddingDownloadProgress = {
  profileId: string
  displayName: string
  /** Host the download is currently using. */
  remoteHost?: string
  /** File the event refers to, when known. */
  file?: string
  /** Bytes received so far, when known. */
  loaded?: number
  /** Expected byte total, when known. */
  total?: number
  /** Completion percentage (0-100), present only when the total is known. */
  percent?: number
  /** Raw status string of the underlying event. */
  status?: string
}

/** Static description of an embedding model that can be downloaded and run locally. */
export type EmbeddingModelProfile = {
  id: EmbeddingModelProfileId
  displayName: string
  description: string
  /** Repository id handed to the model loader. */
  modelId: string
  /** Download hosts, tried in array order. */
  remoteHosts: string[]
  /** URL path template applied together with a remote host. */
  remotePathTemplate: string
  revision: string
  dim: number
  /** Maximum token length passed to the tokenizer. */
  maxTokens: number
  /** Character budget applied to a text before tokenization. */
  maxTextChars: number
  /** Weight precision requested when loading the model. */
  dtype: 'q8' | 'fp32'
  sizeLabel: string
  enabled: boolean
}
|
||||
|
||||
/** Base URL of the ModelScope mirror used for model downloads. */
const MODELSCOPE_HOST = 'https://www.modelscope.cn/'
/** Path template appended to the host; `{model}` and `{revision}` are placeholders. */
const MODELSCOPE_PATH_TEMPLATE = 'models/{model}/resolve/{revision}/'
/** Repository revision requested for every profile. */
const MODELSCOPE_REVISION = 'master'

/** Profile used whenever no valid, enabled profile is configured. */
export const DEFAULT_EMBEDDING_MODEL_PROFILE: EmbeddingModelProfileId = 'bge-large-zh-v1.5-int8'

/** Builds the download-source fields shared by all profiles (fresh array per profile). */
const modelScopeSource = () => ({
  remoteHosts: [MODELSCOPE_HOST],
  remotePathTemplate: MODELSCOPE_PATH_TEMPLATE,
  revision: MODELSCOPE_REVISION
})

/** Built-in embedding model profiles, in display order. */
const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfile[] = [
  {
    id: 'bge-large-zh-v1.5-int8',
    displayName: 'BGE Large 中文 · 推荐',
    description: '默认档位,1024 维中文语义向量,优先兼顾召回质量和本地 CPU 性能。',
    modelId: 'Xenova/bge-large-zh-v1.5',
    ...modelScopeSource(),
    dim: 1024,
    maxTokens: 512,
    maxTextChars: 480,
    dtype: 'q8',
    sizeLabel: '约 330 MB',
    enabled: true
  },
  {
    id: 'bge-large-zh-v1.5-fp32',
    displayName: 'BGE Large 中文 · 高质量',
    description: '同模型 FP32 推理,精度更完整,下载和内存占用更高。',
    modelId: 'Xenova/bge-large-zh-v1.5',
    ...modelScopeSource(),
    dim: 1024,
    maxTokens: 512,
    maxTextChars: 480,
    dtype: 'fp32',
    sizeLabel: '约 1.2 GB',
    enabled: true
  },
  {
    id: 'bge-m3',
    displayName: 'BGE-M3 · 多语言',
    description: '更强的多语言和长文本语义召回,资源占用更高。',
    modelId: 'Xenova/bge-m3',
    ...modelScopeSource(),
    dim: 1024,
    maxTokens: 8192,
    maxTextChars: 2400,
    dtype: 'q8',
    sizeLabel: '约 600 MB',
    enabled: true
  }
]
|
||||
|
||||
/**
 * Normalizes an arbitrary value to a known, enabled profile id.
 *
 * @param value Candidate id; it is stringified and trimmed before matching.
 * @returns The matching enabled profile id, otherwise the default profile id.
 */
function safeProfileId(value: unknown): EmbeddingModelProfileId {
  const requested = String(value || '').trim() as EmbeddingModelProfileId
  for (const candidate of EMBEDDING_MODEL_PROFILES) {
    if (candidate.enabled && candidate.id === requested) {
      return candidate.id
    }
  }
  return DEFAULT_EMBEDDING_MODEL_PROFILE
}
|
||||
|
||||
/**
 * Normalizes an arbitrary value to a supported device setting.
 *
 * @param value Candidate device; it is stringified and trimmed before matching.
 * @returns `'dml'` only for an exact `'dml'` match, `'cpu'` for everything else.
 */
function safeEmbeddingDevice(value: unknown): EmbeddingDevice {
  const normalized = String(value || '').trim()
  if (normalized === 'dml') {
    return 'dml'
  }
  return 'cpu'
}
|
||||
|
||||
/**
 * Sums the sizes of all regular files below a directory, recursing into
 * sub-directories.
 *
 * @param dir Directory to measure.
 * @returns Total size in bytes, or 0 when `dir` does not exist. Entries that
 *   are neither directories nor regular files are not counted.
 */
function directorySize(dir: string): number {
  if (!existsSync(dir)) return 0

  return readdirSync(dir, { withFileTypes: true }).reduce((sum, entry) => {
    const entryPath = join(dir, entry.name)
    if (entry.isDirectory()) {
      return sum + directorySize(entryPath)
    }
    return entry.isFile() ? sum + statSync(entryPath).size : sum
  }, 0)
}
|
||||
|
||||
/**
 * Checks whether a directory tree looks like a usable model download.
 *
 * The whole tree is scanned; the check passes when at least one entry name ends
 * in `.onnx` and at least one entry is named `tokenizer.json` or
 * `tokenizer_config.json`.
 *
 * @param dir Root directory of the model files.
 * @returns True when both kinds of file were found; false when `dir` is missing.
 */
function hasModelFiles(dir: string): boolean {
  if (!existsSync(dir)) return false

  let foundOnnx = false
  let foundTokenizer = false
  // Iterative walk with an explicit stack of directories still to be listed.
  const pending: string[] = [dir]
  while (pending.length > 0) {
    const current = pending.pop() as string
    for (const entry of readdirSync(current, { withFileTypes: true })) {
      if (entry.isDirectory()) {
        pending.push(join(current, entry.name))
        continue
      }
      const name = entry.name
      foundOnnx = foundOnnx || name.endsWith('.onnx')
      foundTokenizer = foundTokenizer || name === 'tokenizer.json' || name === 'tokenizer_config.json'
    }
  }

  return foundOnnx && foundTokenizer
}
|
||||
|
||||
/**
 * Tries to obtain Electron's `app` object without failing when Electron is not
 * available.
 *
 * @returns The `app` export when `require('electron')` yields an object whose
 *   `app` has a `getPath` member; `null` in every other case, including when
 *   the require itself throws.
 */
function getElectronAppSafe(): any | null {
  try {
    const loaded = require('electron')
    // Anything that is not an object cannot carry an `app` export.
    if (!loaded || typeof loaded !== 'object') {
      return null
    }
    const candidate = loaded.app
    if (candidate?.getPath) {
      return candidate
    }
    return null
  } catch {
    return null
  }
}
|
||||
|
||||
/**
 * Resolves the directory used as the application's cache root.
 *
 * Resolution order:
 * 1. the non-empty `cachePath` value from the configuration;
 * 2. with Electron available: `<documents>/CipherTalkData` when
 *    `VITE_DEV_SERVER_URL` is set or when the install directory starts with
 *    `C:` or `\\`, otherwise `<install dir>/CipherTalkData`;
 * 3. without Electron: `<cwd>/CipherTalkData`.
 *
 * @returns The resolved cache root path.
 */
function getEffectiveCachePathFromConfig(): string {
  const dataDirName = 'CipherTalkData'

  const config = new ConfigService()
  let configuredPath = ''
  try {
    configuredPath = String(config.get('cachePath' as any) || '').trim()
  } finally {
    // Always release the config handle, whether or not the read succeeded.
    config.close()
  }
  if (configuredPath) return configuredPath

  const app = getElectronAppSafe()
  if (!app?.getPath) {
    return join(process.cwd(), dataDirName)
  }

  const documentsData = join(app.getPath('documents'), dataDirName)
  if (process.env.VITE_DEV_SERVER_URL) {
    return documentsData
  }

  const installDir = dirname(app.getPath('exe'))
  const prefersDocuments = /^[cC]:/i.test(installDir) || installDir.startsWith('\\\\')
  return prefersDocuments ? documentsData : join(installDir, dataDirName)
}
|
||||
|
||||
/**
 * Computes where `DirectML.dll` is expected inside the `onnxruntime-node`
 * package.
 *
 * @returns The expected DLL path on Windows (the file's existence is not
 *   checked here), or `null` on other platforms or when `onnxruntime-node`
 *   cannot be resolved.
 */
function getDirectMLDllPath(): string | null {
  if (process.platform === 'win32') {
    try {
      const packageEntry = require.resolve('onnxruntime-node')
      const archDir = process.arch === 'arm64' ? 'arm64' : 'x64'
      const relativeSegments = ['..', 'bin', 'napi-v6', 'win32', archDir, 'DirectML.dll']
      return join(dirname(packageEntry), ...relativeSegments)
    } catch {
      return null
    }
  }
  return null
}
|
||||
|
||||
/**
 * Shortens a text to roughly `maxChars` UTF-16 code units by keeping its
 * beginning and its end.
 *
 * Texts within the limit are returned unchanged. Longer texts become
 * `head + "\n" + tail`, where the head takes about 75% of the budget and the
 * tail the remainder, so the result can be one code unit longer than the limit
 * because of the separator. Neither cut splits a surrogate pair: a pair that
 * straddles a cut is dropped from that side instead of leaving a lone
 * surrogate (for example half of an emoji) in the output.
 *
 * @param text Text to shorten; falsy values are treated as an empty string.
 * @param maxChars Character budget; non-finite or non-positive values fall back to 480.
 * @returns The original text, or the shortened head/tail combination.
 */
function limitEmbeddingText(text: string, maxChars: number): string {
  const value = String(text || '')
  const limit = Number.isFinite(maxChars) && maxChars > 0 ? Math.floor(maxChars) : 480
  if (value.length <= limit) return value

  const isHighSurrogate = (code: number): boolean => code >= 0xd800 && code <= 0xdbff
  const isLowSurrogate = (code: number): boolean => code >= 0xdc00 && code <= 0xdfff

  const headLength = Math.max(1, Math.floor(limit * 0.75))
  const tailLength = Math.max(1, limit - headLength)

  // Exclusive end of the head; step back when the cut falls inside a pair.
  let headEnd = headLength
  if (
    headEnd < value.length &&
    isHighSurrogate(value.charCodeAt(headEnd - 1)) &&
    isLowSurrogate(value.charCodeAt(headEnd))
  ) {
    headEnd -= 1
  }

  // Inclusive start of the tail; step forward when the cut falls inside a pair.
  let tailStart = Math.max(0, value.length - tailLength)
  if (
    tailStart > 0 &&
    isLowSurrogate(value.charCodeAt(tailStart)) &&
    isHighSurrogate(value.charCodeAt(tailStart - 1))
  ) {
    tailStart += 1
  }

  return `${value.slice(0, headEnd)}\n${value.slice(tailStart)}`
}
|
||||
|
||||
/**
 * Splits a flat tensor-like object (`data` plus `dims`) into one Float32Array
 * per row.
 *
 * The last entry of `dims` is the vector width. With two or more dims the first
 * entry is the row count (falling back to `expectedCount` when it is falsy);
 * with a single dim the row count is `expectedCount`.
 *
 * @param output Object exposing `data` (array-like with `slice`) and `dims`.
 * @param expectedCount Number of vectors the caller expects.
 * @returns Exactly `expectedCount` vectors.
 * @throws Error when data or dims are missing, the width is not a positive
 *   integer, or the number of extracted vectors differs from `expectedCount`.
 */
function tensorToVectors(output: any, expectedCount: number): Float32Array[] {
  const flat = output?.data
  const shape: number[] = Array.isArray(output?.dims) ? output.dims.map((item: unknown) => Number(item)) : []
  const usable = flat && typeof flat.length === 'number' && shape.length > 0
  if (!usable) {
    throw new Error('Embedding 模型输出为空')
  }

  const width = Number(shape[shape.length - 1] || 0)
  const rowCount = shape.length >= 2 ? Number(shape[0] || expectedCount) : expectedCount
  if (!Number.isInteger(width) || width <= 0) {
    throw new Error('Embedding 模型输出维度无效')
  }

  const vectors: Float32Array[] = []
  // Stop early when the data runs out before all rows are read.
  for (let row = 0, start = 0; row < rowCount && start + width <= flat.length; row += 1, start += width) {
    vectors.push(Float32Array.from(flat.slice(start, start + width)))
  }

  if (vectors.length !== expectedCount) {
    throw new Error(`Embedding 输出数量不匹配:${vectors.length}/${expectedCount}`)
  }
  return vectors
}
|
||||
|
||||
/**
 * Computes the SHA-256 digest of a text.
 *
 * @param value Text to hash; falsy values are hashed as the empty string.
 * @returns The digest as a lowercase hexadecimal string.
 */
export function hashEmbeddingContent(value: string): string {
  const hasher = createHash('sha256')
  hasher.update(value || '')
  return hasher.digest('hex')
}
|
||||
|
||||
/**
 * Copies the bytes covered by a Float32Array into a Buffer.
 *
 * Only the window described by the array's byte offset and byte length is
 * copied, so views into a larger backing buffer are handled, and later writes
 * to `vector` do not affect the result.
 *
 * @param vector Vector to serialize.
 * @returns Buffer holding `vector.byteLength` bytes in the platform's float layout.
 */
export function float32ArrayToBuffer(vector: Float32Array): Buffer {
  const { buffer, byteOffset, byteLength } = vector
  const copiedBytes = buffer.slice(byteOffset, byteOffset + byteLength)
  return Buffer.from(copiedBytes)
}
|
||||
|
||||
/**
 * Turns per-token hidden states into one L2-normalized sentence vector per
 * input by attention-mask-weighted mean pooling.
 *
 * The hidden states are read from `output.last_hidden_state`, falling back to
 * `output.token_embeddings` and then `output.logits`; they must carry
 * three dims `[batch, sequence, dim]`. Tokens whose mask value is not positive
 * are ignored. A row with no active tokens yields an all-zero vector.
 *
 * @param output Model output object holding the hidden-state tensor.
 * @param attentionMask Tensor-like object whose `data` holds one mask value per
 *   token (numbers or bigints).
 * @param expectedCount Number of input texts; must equal the batch dimension.
 * @returns One normalized Float32Array of length `dim` per input.
 * @throws Error when the tensors are missing, the dims are invalid, or the
 *   data/mask buffers are shorter than the dims require.
 */
function meanPoolNormalize(output: any, attentionMask: any, expectedCount: number): Float32Array[] {
  const hidden = output?.last_hidden_state || output?.token_embeddings || output?.logits
  const hiddenData = hidden?.data
  const hiddenDims: number[] = Array.isArray(hidden?.dims) ? hidden.dims.map((item: unknown) => Number(item)) : []
  const maskData = attentionMask?.data
  if (!hiddenData || hiddenDims.length !== 3 || !maskData) {
    throw new Error('Embedding 模型输出为空')
  }

  const [batchSize, seqLength, dim] = hiddenDims
  if (batchSize !== expectedCount || !Number.isInteger(seqLength) || !Number.isInteger(dim) || dim <= 0) {
    throw new Error(`Embedding 输出维度无效:${hiddenDims.join('x')}`)
  }

  // Reading past the end of either buffer would yield `undefined`, which turns
  // into NaN components (hidden states) or silently dropped tokens (mask).
  const tokenTotal = batchSize * seqLength
  if (Number(hiddenData.length) < tokenTotal * dim || Number(maskData.length) < tokenTotal) {
    throw new Error(`Embedding 输出数据长度与维度不匹配:${hiddenDims.join('x')}`)
  }

  const vectors: Float32Array[] = []
  for (let batch = 0; batch < batchSize; batch += 1) {
    const vector = new Float32Array(dim)
    let tokenCount = 0
    for (let token = 0; token < seqLength; token += 1) {
      // `|| 0` also maps a bigint `0n` mask entry to the number 0.
      const mask = Number(maskData[batch * seqLength + token] || 0)
      if (mask <= 0) continue
      tokenCount += mask
      const offset = (batch * seqLength + token) * dim
      for (let index = 0; index < dim; index += 1) {
        vector[index] += Number(hiddenData[offset + index]) * mask
      }
    }

    // Mean over active tokens; guard against division by zero for empty rows.
    const divisor = tokenCount > 0 ? tokenCount : 1
    let norm = 0
    for (let index = 0; index < dim; index += 1) {
      vector[index] /= divisor
      norm += vector[index] * vector[index]
    }

    // L2 normalization; a zero vector stays zero.
    norm = Math.sqrt(norm) || 1
    for (let index = 0; index < dim; index += 1) {
      vector[index] /= norm
    }
    vectors.push(vector)
  }

  return vectors
}
|
||||
|
||||
/**
 * Manages the locally stored embedding models: profile and device selection
 * persisted through ConfigService, on-disk model status, download, removal,
 * and text embedding with an optional DirectML attempt that falls back to CPU.
 */
export class LocalEmbeddingModelService {
  /** Loaded (or loading) tokenizer/model pairs, keyed by profile, device and source. */
  private pipelines = new Map<string, Promise<{ tokenizer: any; model: any }>>()
  /** In-flight downloads per profile id so concurrent requests share one task. */
  private downloadTasks = new Map<string, Promise<EmbeddingModelStatus>>()
  /** Message of the last DirectML inference failure; reset by setCurrentDevice. */
  private dmlFailureReason: string | null = null

  /** Returns shallow copies of all built-in profiles. */
  listProfiles(): EmbeddingModelProfile[] {
    return EMBEDDING_MODEL_PROFILES.map((profile) => ({ ...profile }))
  }

  /**
   * Resolves a profile by id.
   *
   * @param profileId Requested id; when omitted the configured id is used.
   *   Unknown or disabled ids resolve to the default profile.
   */
  getProfile(profileId?: string): EmbeddingModelProfile {
    const id = safeProfileId(profileId || this.getCurrentProfileId())
    return EMBEDDING_MODEL_PROFILES.find((profile) => profile.id === id)!
  }

  /** Reads the configured profile id, normalized to a known enabled profile. */
  getCurrentProfileId(): EmbeddingModelProfileId {
    const configService = new ConfigService()
    try {
      return safeProfileId(configService.get('aiEmbeddingModelProfile' as any))
    } finally {
      configService.close()
    }
  }

  /**
   * Persists the profile selection.
   *
   * @returns The id that was actually stored after normalization.
   */
  setCurrentProfileId(profileId: string): EmbeddingModelProfileId {
    const id = safeProfileId(profileId)
    const configService = new ConfigService()
    try {
      configService.set('aiEmbeddingModelProfile' as any, id)
      return id
    } finally {
      configService.close()
    }
  }

  /** Reads the configured device, normalized to `cpu` or `dml`. */
  getCurrentDevice(): EmbeddingDevice {
    const configService = new ConfigService()
    try {
      return safeEmbeddingDevice(configService.get('aiEmbeddingDevice' as any))
    } finally {
      configService.close()
    }
  }

  /**
   * Persists the device selection, forgets any recorded DirectML failure and
   * drops all cached pipelines so the next inference loads for the new device.
   *
   * @returns The device that was actually stored after normalization.
   */
  setCurrentDevice(device: string): EmbeddingDevice {
    const nextDevice = safeEmbeddingDevice(device)
    const configService = new ConfigService()
    try {
      configService.set('aiEmbeddingDevice' as any, nextDevice)
      this.dmlFailureReason = null
      this.clearPipelines()
      return nextDevice
    } finally {
      configService.close()
    }
  }

  /**
   * Reports the configured device and the device that will effectively be used.
   *
   * DirectML counts as available when running on win32 and the expected
   * `DirectML.dll` exists. With `dml` configured the effective device is still
   * `cpu` when a DirectML failure was recorded or the DLL is missing.
   */
  getDeviceStatus(): EmbeddingDeviceStatus {
    const currentDevice = this.getCurrentDevice()
    const directMLDll = getDirectMLDllPath()
    const directMLAvailable = process.platform === 'win32' && !!directMLDll && existsSync(directMLDll)

    if (currentDevice === 'dml' && this.dmlFailureReason) {
      return {
        currentDevice,
        effectiveDevice: 'cpu',
        gpuAvailable: directMLAvailable,
        provider: 'CPU',
        info: `DirectML 本次运行失败,已自动回退 CPU:${this.dmlFailureReason}`
      }
    }

    if (currentDevice === 'dml' && directMLAvailable) {
      return {
        currentDevice,
        effectiveDevice: 'dml',
        gpuAvailable: true,
        provider: 'DirectML',
        info: 'DirectML 组件已就绪,将优先使用 GPU;推理失败时自动回退 CPU'
      }
    }

    if (currentDevice === 'dml') {
      return {
        currentDevice,
        effectiveDevice: 'cpu',
        gpuAvailable: false,
        provider: 'CPU',
        info: process.platform === 'win32'
          ? '缺少 DirectML 组件,将使用 CPU'
          : '当前系统不支持 DirectML,将使用 CPU'
      }
    }

    return {
      currentDevice,
      effectiveDevice: 'cpu',
      gpuAvailable: directMLAvailable,
      provider: 'CPU',
      info: directMLAvailable ? '当前使用 CPU,可切换到 DirectML GPU 实验模式' : '当前使用 CPU'
    }
  }

  /** Directory below the cache root that holds all embedding models. */
  getModelsRoot(): string {
    return join(getEffectiveCachePathFromConfig(), 'models', 'embeddings')
  }

  /** Directory that holds the files of one profile. */
  getProfileDir(profileId?: string): string {
    return join(this.getModelsRoot(), this.getProfile(profileId).id)
  }

  /** Describes a profile together with its on-disk presence and size. */
  async getModelStatus(profileId?: string): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    const modelDir = this.getProfileDir(profile.id)
    const exists = hasModelFiles(modelDir)
    return {
      profileId: profile.id,
      displayName: profile.displayName,
      modelId: profile.modelId,
      dim: profile.dim,
      dtype: profile.dtype,
      sizeLabel: profile.sizeLabel,
      enabled: profile.enabled,
      exists,
      modelDir,
      sizeBytes: directorySize(modelDir)
    }
  }

  /**
   * Downloads the files of a profile.
   *
   * When a download for the same profile is already running, that task is
   * returned and the `onProgress` passed to this call is not invoked.
   *
   * @param profileId Profile to download; defaults to the configured profile.
   * @param onProgress Receives progress events of the download started here.
   * @returns The model status after the download finished.
   * @throws Error when every remote host failed.
   */
  async downloadModel(
    profileId?: string,
    onProgress?: (progress: EmbeddingDownloadProgress) => void
  ): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    const existing = this.downloadTasks.get(profile.id)
    if (existing) return existing

    const task = (async () => {
      mkdirSync(this.getProfileDir(profile.id), { recursive: true })
      await this.downloadPipelineWithFallback(profile, onProgress)
      return this.getModelStatus(profile.id)
    })()

    this.downloadTasks.set(profile.id, task)
    try {
      return await task
    } finally {
      this.downloadTasks.delete(profile.id)
    }
  }

  /** Drops the cached pipelines of a profile and deletes its directory. */
  async clearModel(profileId?: string): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    this.clearPipelines(profile.id)
    rmSync(this.getProfileDir(profile.id), { recursive: true, force: true })
    return this.getModelStatus(profile.id)
  }

  /**
   * Verifies that the model files of a profile are present.
   *
   * @throws Error when the model has not been downloaded.
   */
  async ensureModelReady(profileId?: string): Promise<EmbeddingModelStatus> {
    const status = await this.getModelStatus(profileId)
    if (!status.exists) {
      throw new Error(`本地语义模型未下载:${status.displayName}`)
    }
    return status
  }

  /**
   * Embeds a batch of texts.
   *
   * Texts are shortened to the profile's character budget first. When the
   * effective device is DirectML, inference is attempted there; on failure the
   * reason is recorded, the DirectML pipelines of the profile are dropped and
   * the batch is embedded on CPU instead.
   *
   * @param texts Texts to embed; an empty array yields an empty result without loading a model.
   * @param profileId Profile to use; defaults to the configured profile.
   * @returns One vector per input text, in input order.
   * @throws Error when the model has not been downloaded or CPU inference fails.
   */
  async embedTexts(texts: string[], profileId?: string): Promise<Float32Array[]> {
    // Nothing to embed: avoid running the tokenizer and model on an empty batch.
    if (texts.length === 0) return []

    const profile = this.getProfile(profileId)
    const cleaned = texts.map((text) => limitEmbeddingText(String(text || ''), profile.maxTextChars))
    await this.ensureModelReady(profile.id)
    const deviceStatus = this.getDeviceStatus()

    if (deviceStatus.effectiveDevice === 'dml') {
      try {
        return await this.runEmbedding(profile, cleaned, 'dml')
      } catch (error) {
        console.warn('[Embedding] DirectML 推理失败,回退 CPU:', error)
        this.dmlFailureReason = String(error instanceof Error ? error.message : error)
        this.clearPipelines(profile.id, 'dml')
      }
    }

    return this.runEmbedding(profile, cleaned, 'cpu')
  }

  /** Embeds a single text; see {@link embedTexts}. */
  async embedText(text: string, profileId?: string): Promise<Float32Array> {
    const [vector] = await this.embedTexts([text], profileId)
    return vector
  }

  /**
   * Tokenizes the texts, runs the model and pools the output into vectors.
   *
   * NOTE(review): the output is mean-pooled; the BGE model cards appear to
   * recommend CLS pooling — confirm mean pooling is intended.
   */
  private async runEmbedding(
    profile: EmbeddingModelProfile,
    texts: string[],
    device: EmbeddingDevice
  ): Promise<Float32Array[]> {
    const runtime = await this.getPipeline(profile, true, device)
    const modelInputs = runtime.tokenizer(texts, {
      padding: true,
      truncation: true,
      max_length: profile.maxTokens
    })
    const output = await runtime.model(modelInputs)
    return meanPoolNormalize(output, modelInputs.attention_mask, texts.length)
  }

  /** Builds the cache key of a pipeline: `profile:device:source`. */
  private pipelineKey(
    profileId: string,
    device: EmbeddingDevice,
    localOnly: boolean,
    remoteHost?: string
  ): string {
    return `${profileId}:${device}:${localOnly ? 'local' : remoteHost || 'remote'}`
  }

  /**
   * Loads (or returns the cached) tokenizer/model pair for a profile.
   *
   * The load configures the shared `transformers.env` object (cache directory,
   * remote access, and remote host/path template when a host is given). A
   * failed load is removed from the cache so a later call can retry.
   *
   * @param localOnly When true, only files already on disk may be used.
   * @param device Device passed to the model loader.
   * @param remoteHost Host to download from when remote access is allowed.
   * @param progressCallback Forwarded to the loader as `progress_callback`.
   */
  private async getPipeline(
    profile: EmbeddingModelProfile,
    localOnly: boolean,
    device: EmbeddingDevice = 'cpu',
    remoteHost?: string,
    progressCallback?: (event: any) => void
  ): Promise<{ tokenizer: any; model: any }> {
    const key = this.pipelineKey(profile.id, device, localOnly, remoteHost)
    const existing = this.pipelines.get(key)
    if (existing) return existing

    const promise = (async () => {
      const transformers = await import('@huggingface/transformers')
      transformers.env.allowLocalModels = true
      transformers.env.allowRemoteModels = !localOnly
      transformers.env.cacheDir = this.getProfileDir(profile.id)
      if (remoteHost) {
        transformers.env.remoteHost = remoteHost
        transformers.env.remotePathTemplate = profile.remotePathTemplate
      }

      const commonOptions = {
        cache_dir: this.getProfileDir(profile.id),
        local_files_only: localOnly,
        revision: profile.revision,
        progress_callback: progressCallback
      }
      const tokenizer = await transformers.AutoTokenizer.from_pretrained(profile.modelId, commonOptions as any)
      const model = await transformers.AutoModel.from_pretrained(profile.modelId, {
        ...commonOptions,
        device,
        dtype: profile.dtype
      } as any)
      return { tokenizer, model }
    })()

    this.pipelines.set(key, promise)
    try {
      return await promise
    } catch (error) {
      this.pipelines.delete(key)
      throw error
    }
  }

  /**
   * Removes cached pipelines.
   *
   * @param profileId When given, only pipelines of this profile are removed.
   * @param device When given, only pipelines of this device are removed.
   */
  private clearPipelines(profileId?: string, device?: EmbeddingDevice): void {
    for (const key of Array.from(this.pipelines.keys())) {
      const matchesProfile = !profileId || key.startsWith(`${profileId}:`)
      const matchesDevice = !device || key.includes(`:${device}:`)
      if (matchesProfile && matchesDevice) {
        this.pipelines.delete(key)
      }
    }
  }

  /**
   * Downloads a profile by loading it once from each remote host in order
   * until one succeeds.
   *
   * @throws Error listing every host error when all hosts failed.
   */
  private async downloadPipelineWithFallback(
    profile: EmbeddingModelProfile,
    onProgress?: (progress: EmbeddingDownloadProgress) => void
  ): Promise<void> {
    const errors: string[] = []

    for (const remoteHost of profile.remoteHosts) {
      try {
        onProgress?.({
          profileId: profile.id,
          displayName: profile.displayName,
          remoteHost,
          status: 'initiate'
        })

        await this.getPipeline(profile, false, 'cpu', remoteHost, (event) => {
          const loaded = Number(event?.loaded || 0)
          const total = Number(event?.total || 0)
          onProgress?.({
            profileId: profile.id,
            displayName: profile.displayName,
            remoteHost,
            file: String(event?.file || event?.name || ''),
            loaded: Number.isFinite(loaded) && loaded > 0 ? loaded : undefined,
            total: Number.isFinite(total) && total > 0 ? total : undefined,
            percent: total > 0 ? Math.min(100, Math.round((loaded / total) * 100)) : undefined,
            status: String(event?.status || '')
          })
        })

        // The pipeline loaded for the download is cached under a remote-host
        // key, but inference only ever looks up the `local` key. Evict it so
        // the cache does not keep a second, unused copy of the model alive.
        this.pipelines.delete(this.pipelineKey(profile.id, 'cpu', false, remoteHost))
        return
      } catch (error) {
        errors.push(`${remoteHost}: ${String(error)}`)
      }
    }

    throw new Error(`语义模型下载失败。已尝试 ModelScope/魔塔社区:${profile.remoteHosts.join('、')}。请检查网络/代理或稍后重试。${errors.length ? ` 原始错误:${errors.join(' | ')}` : ''}`)
  }
}
|
||||
|
||||
export const localEmbeddingModelService = new LocalEmbeddingModelService()
|
||||
Reference in New Issue
Block a user