feat: integrate embedding model functionality and UI enhancements

- Added @huggingface/transformers and sqlite-vec dependencies to package.json.
- Updated electron-builder configuration to include sqlite-vec files.
- Enhanced AISummarySettings component with new styles and layout for embedding model selection and device configuration.
- Implemented embedding model loading, status checking, and download functionality in AISummarySettings.
- Added new types for embedding models and device status in ai.ts and electron.d.ts.
- Updated config service to manage embedding model profile and device settings.
- Modified AISummaryWindow to reflect changes in vector indexing messages and statuses.
This commit is contained in:
ILoveBingLu
2026-04-25 23:02:57 +08:00
parent 43a1608868
commit c1bf514b4a
14 changed files with 1741 additions and 325 deletions
+105
View File
@@ -4145,6 +4145,111 @@ function registerIpcHandlers() {
return { success: false, error: String(e) }
}
})
// IPC: returns the list of embedding model profiles together with the id of
// the currently selected profile. The service module is imported dynamically
// inside the handler rather than at file load.
ipcMain.handle('ai:getEmbeddingModelProfiles', async () => {
  try {
    const embeddingModule = await import('./services/search/embeddingModelService')
    const service = embeddingModule.localEmbeddingModelService
    const result = service.listProfiles()
    const currentProfileId = service.getCurrentProfileId()
    return { success: true, result, currentProfileId }
  } catch (e) {
    // Failures are reported to the renderer as a { success: false } envelope.
    console.error('[AI] 获取语义模型列表失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: selects the embedding model profile identified by `profileId` and
// returns whatever the service reports for the change.
ipcMain.handle('ai:setEmbeddingModelProfile', async (_, profileId: string) => {
  try {
    const { localEmbeddingModelService: service } = await import('./services/search/embeddingModelService')
    return { success: true, result: service.setCurrentProfileId(profileId) }
  } catch (e) {
    console.error('[AI] 设置语义模型失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: reports the embedding device status as provided by the embedding
// model service.
ipcMain.handle('ai:getEmbeddingDeviceStatus', async () => {
  try {
    const { localEmbeddingModelService: service } = await import('./services/search/embeddingModelService')
    const deviceStatus = service.getDeviceStatus()
    return { success: true, result: deviceStatus }
  } catch (e) {
    console.error('[AI] 获取语义向量计算模式失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: stores the requested embedding device and returns both the service's
// result and the device status read immediately after the change.
ipcMain.handle('ai:setEmbeddingDevice', async (_, device: string) => {
  try {
    const { localEmbeddingModelService: service } = await import('./services/search/embeddingModelService')
    // Apply the change first so the status below reflects the new setting.
    const result = service.setCurrentDevice(device)
    const status = service.getDeviceStatus()
    return { success: true, result, status }
  } catch (e) {
    console.error('[AI] 设置语义向量计算模式失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: resolves the status of an embedding model. `profileId` is optional and
// is forwarded to the service unchanged (including when it is undefined).
ipcMain.handle('ai:getEmbeddingModelStatus', async (_, profileId?: string) => {
  try {
    const { localEmbeddingModelService: service } = await import('./services/search/embeddingModelService')
    const modelStatus = await service.getModelStatus(profileId)
    return { success: true, result: modelStatus }
  } catch (e) {
    console.error('[AI] 获取语义模型状态失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: downloads an embedding model and streams progress events back to the
// requesting renderer on the 'ai:embeddingModelDownloadProgress' channel.
// `profileId` is optional and is forwarded to the service unchanged.
ipcMain.handle('ai:downloadEmbeddingModel', async (event, profileId?: string) => {
  try {
    const { localEmbeddingModelService } = await import('./services/search/embeddingModelService')
    const sender = event.sender
    const result = await localEmbeddingModelService.downloadModel(profileId, (progress) => {
      // The requesting window can be closed while the download is still
      // running. Sending to a destroyed webContents throws, which would
      // surface inside the download's progress callback, so drop the event
      // instead once the sender is gone.
      if (!sender.isDestroyed()) {
        sender.send('ai:embeddingModelDownloadProgress', progress)
      }
    })
    return { success: true, result }
  } catch (e) {
    console.error('[AI] 下载语义模型失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: asks the embedding model service to clear a model. `profileId` is
// optional and is forwarded to the service unchanged.
ipcMain.handle('ai:clearEmbeddingModel', async (_, profileId?: string) => {
  try {
    const { localEmbeddingModelService: service } = await import('./services/search/embeddingModelService')
    const cleared = await service.clearModel(profileId)
    return { success: true, result: cleared }
  } catch (e) {
    console.error('[AI] 清理语义模型失败:', e)
    return { success: false, error: String(e) }
  }
})
// IPC: clears the semantic vector index for `vectorModel`. When the argument
// is omitted, the index service applies its own default.
ipcMain.handle('ai:clearSemanticVectorIndex', async (_, vectorModel?: string) => {
  try {
    const { chatSearchIndexService: indexService } = await import('./services/search/chatSearchIndexService')
    const cleared = indexService.clearSemanticVectorIndex(vectorModel)
    return { success: true, result: cleared }
  } catch (e) {
    console.error('[AI] 清理语义向量索引失败:', e)
    return { success: false, error: String(e) }
  }
})
}
// 主窗口引用
+23
View File
@@ -20,6 +20,17 @@ type SessionVectorIndexProgressEvent = {
vectorModel: string
}
/**
 * Payload delivered to `onEmbeddingModelDownloadProgress` callbacks, i.e. the
 * data received on the 'ai:embeddingModelDownloadProgress' IPC channel.
 * Only `profileId` and `displayName` are always present; every other field is
 * optional.
 */
type EmbeddingModelDownloadProgress = {
// Identifier and display label of the model profile being downloaded.
profileId: string
displayName: string
// NOTE(review): presumably the host the file is being fetched from — confirm.
remoteHost?: string
// NOTE(review): presumably the name of the file currently downloading — confirm.
file?: string
// NOTE(review): presumably byte counts for the current transfer — confirm units.
loaded?: number
total?: number
percent?: number
status?: string
}
function getMcpLaunchConfigSafe(): Promise<{
command: string
args: string[]
@@ -543,6 +554,14 @@ contextBridge.exposeInMainWorld('electronAPI', {
getSessionVectorIndexState: (sessionId: string) => ipcRenderer.invoke('ai:getSessionVectorIndexState', sessionId),
prepareSessionVectorIndex: (options: { sessionId: string }) => ipcRenderer.invoke('ai:prepareSessionVectorIndex', options),
cancelSessionVectorIndex: (sessionId: string) => ipcRenderer.invoke('ai:cancelSessionVectorIndex', sessionId),
getEmbeddingModelProfiles: () => ipcRenderer.invoke('ai:getEmbeddingModelProfiles'),
setEmbeddingModelProfile: (profileId: string) => ipcRenderer.invoke('ai:setEmbeddingModelProfile', profileId),
getEmbeddingDeviceStatus: () => ipcRenderer.invoke('ai:getEmbeddingDeviceStatus'),
setEmbeddingDevice: (device: 'cpu' | 'dml') => ipcRenderer.invoke('ai:setEmbeddingDevice', device),
getEmbeddingModelStatus: (profileId?: string) => ipcRenderer.invoke('ai:getEmbeddingModelStatus', profileId),
downloadEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:downloadEmbeddingModel', profileId),
clearEmbeddingModel: (profileId?: string) => ipcRenderer.invoke('ai:clearEmbeddingModel', profileId),
clearSemanticVectorIndex: (vectorModel?: string) => ipcRenderer.invoke('ai:clearSemanticVectorIndex', vectorModel),
onSummaryChunk: (callback: (chunk: string) => void) => {
ipcRenderer.on('ai:summaryChunk', (_, chunk) => callback(chunk))
return () => ipcRenderer.removeAllListeners('ai:summaryChunk')
@@ -558,6 +577,10 @@ contextBridge.exposeInMainWorld('electronAPI', {
onSessionVectorIndexProgress: (callback: (event: SessionVectorIndexProgressEvent) => void) => {
ipcRenderer.on('ai:sessionVectorIndexProgress', (_, event) => callback(event))
return () => ipcRenderer.removeAllListeners('ai:sessionVectorIndexProgress')
},
// Subscribes `callback` to embedding model download progress events and
// returns an unsubscribe function. Only this subscription's listener is
// removed on unsubscribe, so other subscribers on the same channel keep
// receiving events.
onEmbeddingModelDownloadProgress: (callback: (event: EmbeddingModelDownloadProgress) => void) => {
  const listener = (_: unknown, event: EmbeddingModelDownloadProgress) => callback(event)
  ipcRenderer.on('ai:embeddingModelDownloadProgress', listener)
  return () => ipcRenderer.removeListener('ai:embeddingModelDownloadProgress', listener)
}
}
})
+4
View File
@@ -110,6 +110,8 @@ interface ConfigSchema {
aiEnableCache: boolean
aiEnableThinking: boolean // 是否显示思考过程
aiMessageLimit: number // 摘要提取的消息条数限制
aiEmbeddingModelProfile: string
aiEmbeddingDevice: 'cpu' | 'dml'
mcpEnabled: boolean
mcpExposeMediaPaths: boolean
mcpProxyPort: number
@@ -172,6 +174,8 @@ const defaults: ConfigSchema = {
aiEnableCache: true,
aiEnableThinking: true, // 默认显示思考过程
aiMessageLimit: 3000, // 默认3000条,用户可调至5000
aiEmbeddingModelProfile: 'bge-large-zh-v1.5-int8',
aiEmbeddingDevice: 'cpu',
mcpEnabled: false,
mcpExposeMediaPaths: true,
mcpProxyPort: 5032,
+219 -274
View File
@@ -3,6 +3,11 @@ import { existsSync, mkdirSync } from 'fs'
import { join } from 'path'
import { chatService, type Message } from '../chatService'
import { ConfigService } from '../config'
import {
float32ArrayToBuffer,
hashEmbeddingContent,
localEmbeddingModelService
} from './embeddingModelService'
export type ChatSearchIndexProgressStage =
| 'preparing_index'
@@ -56,6 +61,7 @@ export interface ChatSearchSessionResult {
export type ChatVectorIndexProgressStage =
| 'preparing'
| 'downloading_model'
| 'indexing_messages'
| 'vectorizing_messages'
| 'completed'
@@ -84,6 +90,9 @@ export interface ChatVectorIndexState {
isVectorComplete: boolean
isVectorRunning: boolean
vectorModel: string
vectorModelName?: string
vectorProviderAvailable?: boolean
vectorProviderError?: string
}
export interface ChatVectorSearchSessionResult {
@@ -112,17 +121,7 @@ type MessageIndexRow = {
}
type MessageVectorRow = MessageIndexRow & {
vector_json: string
}
type SparseVector = Array<[number, number]>
interface LocalVectorProvider {
id: string
buildVector(text: string): SparseVector
parseVector(value: string): SparseVector
toWeightMap(vector: SparseVector): Map<number, number>
dot(queryWeights: Map<number, number>, vector: SparseVector): number
distance: number
}
type SessionVectorStateRow = {
@@ -141,81 +140,18 @@ type VectorTask = {
}
const INDEX_DB_NAME = 'chat_search_index.db'
const INDEX_SCHEMA_VERSION = '1'
const INDEX_SCHEMA_VERSION = '3'
const INDEX_BATCH_SIZE = 800
const MAX_INDEX_TEXT_CHARS = 8000
const MAX_EXCERPT_RADIUS = 48
const MAX_INDEX_SEARCH_CANDIDATES = 240
const VECTOR_MODEL_ID = 'local-chargram-hash-v1'
const VECTOR_DIMENSIONS = 2048
const VECTOR_BATCH_SIZE = 800
const MAX_VECTOR_TEXT_CHARS = 2400
const MAX_VECTOR_SCAN_ROWS = 120000
const VECTOR_MIN_SCORE = 0.055
const VECTOR_BATCH_SIZE = 32
const VECTOR_SEARCH_OVERFETCH = 8
const VECTOR_MIN_SCORE = 0.45
// Vector hits are recall supplements, so keep them below high-confidence keyword hits.
const VECTOR_SCORE_BASE = 560
const VECTOR_SCORE_SCALE = 420
const VECTOR_INDEX_CANCELLED_ERROR = 'VECTOR_INDEX_CANCELLED'
const VECTOR_STOP_PHRASES = [
'有没有',
'是不是',
'是否',
'什么',
'哪个',
'哪些',
'什么时候',
'为什么',
'怎么',
'如何',
'帮我',
'看看',
'请问',
'问一下',
'聊天记录',
'聊天',
'消息',
'记录',
'内容',
'这个',
'那个',
'我们',
'他们',
'对方',
'最近',
'说过',
'提到',
'关于',
'一下'
]
const VECTOR_STOP_WORDS = new Set([
'the',
'and',
'for',
'with',
'that',
'this',
'what',
'when',
'where',
'which',
'why',
'how',
'have',
'has',
'是否',
'什么',
'哪个',
'哪些',
'怎么',
'如何',
'我们',
'他们',
'对方',
'消息',
'聊天',
'记录',
'内容'
])
function cursorKey(message: Pick<Message, 'localId' | 'createTime' | 'sortSeq'>): string {
return `${Number(message.localId || 0)}:${Number(message.createTime || 0)}:${Number(message.sortSeq || 0)}`
@@ -239,6 +175,15 @@ function compareIndexRowCursorAsc(
|| Number(a.local_id || 0) - Number(b.local_id || 0)
}
/**
 * Derives an unsigned 32-bit integer key from a session id using an
 * FNV-1a style hash over the string's UTF-16 code units.
 *
 * @param sessionId Session identifier to hash.
 * @returns A non-negative 32-bit integer; the empty string yields the offset basis.
 */
function vectorSessionKey(sessionId: string): number {
  const OFFSET_BASIS = 2166136261
  const PRIME = 16777619
  let acc = OFFSET_BASIS
  let position = 0
  while (position < sessionId.length) {
    // XOR in the code unit, then multiply with 32-bit wraparound.
    acc = Math.imul(acc ^ sessionId.charCodeAt(position), PRIME)
    position += 1
  }
  // Reinterpret the signed 32-bit result as unsigned.
  return acc >>> 0
}
function normalizeSearchText(value?: string): string {
return String(value || '')
.toLowerCase()
@@ -338,128 +283,6 @@ function buildSearchTokens(value: string): string {
return uniqueStrings(tokens).join(' ')
}
/**
 * Prepares text for vectorization: normalizes it with `normalizeSearchText`,
 * truncates it to `MAX_VECTOR_TEXT_CHARS`, blanks out every entry of
 * `VECTOR_STOP_PHRASES` (case-insensitively), then collapses whitespace.
 */
function normalizeVectorText(value: string): string {
  const truncated = normalizeSearchText(value).slice(0, MAX_VECTOR_TEXT_CHARS)
  // Phrases are applied in list order; each match is replaced by one space.
  const withoutStopPhrases = VECTOR_STOP_PHRASES.reduce(
    (text, phrase) => text.replace(new RegExp(phrase, 'gi'), ' '),
    truncated
  )
  return withoutStopPhrases.replace(/\s+/g, ' ').trim()
}
/**
 * Hashes a string to an unsigned 32-bit integer with an FNV-1a style hash
 * over its UTF-16 code units.
 *
 * @param value Text to hash.
 * @returns A non-negative 32-bit integer.
 */
function hashString(value: string): number {
  const OFFSET_BASIS = 2166136261
  const PRIME = 16777619
  let acc = OFFSET_BASIS
  let position = 0
  while (position < value.length) {
    acc = Math.imul(acc ^ value.charCodeAt(position), PRIME)
    position += 1
  }
  return acc >>> 0
}
/**
 * Adds one feature to a sparse weight map using signed feature hashing.
 * The feature's hash picks the dimension (modulo `VECTOR_DIMENSIONS`) and the
 * hash's top bit picks the sign of the contribution. Empty features and
 * entries of `VECTOR_STOP_WORDS` are ignored.
 */
function addVectorFeature(weights: Map<number, number>, feature: string, weight: number): void {
  const ignored = !feature || VECTOR_STOP_WORDS.has(feature)
  if (ignored) return

  const featureHash = hashString(feature)
  const bucket = featureHash % VECTOR_DIMENSIONS
  const contribution = (featureHash & 0x80000000) !== 0 ? -weight : weight
  const previous = weights.get(bucket) || 0
  weights.set(bucket, previous + contribution)
}
/**
 * Adds features for one run of CJK characters: the whole segment (when it is
 * 2 to 12 characters long) plus every 2-, 3- and 4-character gram, each size
 * with its own weight. Segments and grams listed in `VECTOR_STOP_WORDS` are
 * skipped.
 */
function addChineseVectorFeatures(weights: Map<number, number>, segment: string): void {
  if (!segment || VECTOR_STOP_WORDS.has(segment)) return

  const length = segment.length
  if (length >= 2 && length <= 12) {
    addVectorFeature(weights, `zh:${segment}`, 1.35)
  }

  // [gram size, weight] pairs, processed in ascending size order.
  const gramWeights: ReadonlyArray<readonly [number, number]> = [
    [2, 0.9],
    [3, 1.1],
    [4, 0.85]
  ]
  for (const [size, weight] of gramWeights) {
    if (length >= size) {
      const lastStart = length - size
      for (let start = 0; start <= lastStart; start += 1) {
        const gram = segment.slice(start, start + size)
        if (!VECTOR_STOP_WORDS.has(gram)) {
          addVectorFeature(weights, `c${size}:${gram}`, weight)
        }
      }
    }
  }
}
/**
 * Builds an L2-normalized sparse vector for a piece of text.
 *
 * Features come from CJK runs (via `addChineseVectorFeatures`), from Latin /
 * digit words of two or more characters, and from adjacent word pairs.
 * Returns an empty vector when the normalized text is empty or no feature
 * produced a non-zero weight.
 */
function buildLocalSearchVector(value: string): SparseVector {
const normalized = normalizeVectorText(value)
if (!normalized) return []
const weights = new Map<number, number>()
const latinWords: string[] = []
// CJK runs contribute whole-segment and character-gram features.
for (const match of normalized.matchAll(/[\u3400-\u9fff]+/g)) {
addChineseVectorFeatures(weights, match[0])
}
// Latin/digit tokens contribute word features; stop words are dropped before
// they can take part in the bigram pass below.
for (const match of normalized.matchAll(/[a-z0-9_@.\-]{2,}/g)) {
const word = match[0]
if (VECTOR_STOP_WORDS.has(word)) continue
latinWords.push(word)
addVectorFeature(weights, `w:${word}`, 1.2)
}
// Adjacent pairs of the retained words contribute bigram features.
for (let index = 0; index < latinWords.length - 1; index += 1) {
addVectorFeature(weights, `wb:${latinWords[index]} ${latinWords[index + 1]}`, 1.35)
}
// Squared L2 norm of the accumulated weights.
let norm = 0
for (const weight of weights.values()) {
norm += weight * weight
}
if (norm <= 0) return []
const scale = Math.sqrt(norm)
// Normalize, round to 6 decimals, drop near-zero entries, and order by
// dimension.
return Array.from(weights.entries())
.map(([dimension, weight]) => [dimension, Number((weight / scale).toFixed(6))] as [number, number])
.filter(([, weight]) => Math.abs(weight) > 0.000001)
.sort((a, b) => a[0] - b[0])
}
/**
 * Parses a JSON-encoded sparse vector of `[dimension, weight]` pairs.
 *
 * Entries that are not arrays of at least two items, whose dimension is not a
 * non-negative integer, or whose weight is not finite are dropped. Any parse
 * or conversion error yields an empty vector.
 */
function parseSparseVector(value: string): SparseVector {
  try {
    const decoded: unknown = JSON.parse(value)
    if (!Array.isArray(decoded)) return []

    const entries: SparseVector = []
    decoded.forEach((candidate) => {
      const wellFormed = Array.isArray(candidate) && candidate.length >= 2
      if (!wellFormed) return
      const dimension = Number(candidate[0])
      const weight = Number(candidate[1])
      const validDimension = Number.isInteger(dimension) && dimension >= 0
      if (validDimension && Number.isFinite(weight)) {
        entries.push([dimension, weight])
      }
    })
    return entries
  } catch {
    return []
  }
}
/**
 * Computes the dot product between a query (as a dimension-to-weight map) and
 * a sparse vector. Dimensions missing from the query, or mapped to zero,
 * contribute nothing.
 */
function dotSparseVector(queryWeights: Map<number, number>, vector: SparseVector): number {
  return vector.reduce((score, [dimension, weight]) => {
    const queryWeight = queryWeights.get(dimension)
    return queryWeight ? score + queryWeight * weight : score
  }, 0)
}
/** Converts a sparse vector of `[dimension, weight]` pairs into a lookup map. */
function sparseVectorToMap(vector: SparseVector): Map<number, number> {
  const weightsByDimension = new Map(vector)
  return weightsByDimension
}
/**
 * Hash-based local vector provider: bundles the model id with the functions
 * that build, parse, index and score sparse vectors.
 */
const localVectorProvider: LocalVectorProvider = {
id: VECTOR_MODEL_ID,
buildVector: buildLocalSearchVector,
parseVector: parseSparseVector,
toWeightMap: sparseVectorToMap,
dot: dotSparseVector
}
function createVectorExcerpt(row: Pick<MessageIndexRow, 'parsed_content' | 'search_text'>, query: string): string {
const text = String(row.parsed_content || row.search_text || '')
if (!text) return ''
@@ -629,6 +452,8 @@ export class ChatSearchIndexService {
private db: Database.Database | null = null
private dbPath: string | null = null
private vectorTasks = new Map<string, VectorTask>()
private sqliteVectorAvailable = false
private sqliteVectorError = ''
private getCacheBasePath(): string {
const configService = new ConfigService()
@@ -662,10 +487,24 @@ export class ChatSearchIndexService {
const db = new Database(nextDbPath)
this.db = db
this.dbPath = nextDbPath
this.loadSqliteVectorExtension(db)
this.ensureSchema(db)
return db
}
/**
 * Tries to load the sqlite-vec extension into `db` and records the outcome in
 * `sqliteVectorAvailable` / `sqliteVectorError`. A load failure is logged as a
 * warning and never thrown.
 */
private loadSqliteVectorExtension(db: Database.Database): void {
  let loaded = false
  let failure: unknown
  try {
    const extension = require('sqlite-vec') as { load: (db: Database.Database) => void }
    extension.load(db)
    loaded = true
  } catch (error) {
    failure = error
  }

  if (loaded) {
    this.sqliteVectorAvailable = true
    this.sqliteVectorError = ''
    return
  }

  this.sqliteVectorAvailable = false
  this.sqliteVectorError = String(failure)
  console.warn('[ChatSearchIndex] sqlite-vec 加载失败,语义向量检索将降级为关键词检索:', failure)
}
private ensureSchema(db: Database.Database): void {
db.pragma('journal_mode = WAL')
db.pragma('synchronous = NORMAL')
@@ -720,12 +559,15 @@ export class ChatSearchIndexService {
);
CREATE TABLE IF NOT EXISTS message_vector_index (
message_id INTEGER PRIMARY KEY,
id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id INTEGER NOT NULL,
session_id TEXT NOT NULL,
vector_model TEXT NOT NULL,
vector_json TEXT NOT NULL,
feature_count INTEGER NOT NULL DEFAULT 0,
indexed_at INTEGER NOT NULL
embedding_blob BLOB NOT NULL,
dim INTEGER NOT NULL,
content_hash TEXT NOT NULL,
indexed_at INTEGER NOT NULL,
UNIQUE(message_id, vector_model)
);
CREATE TABLE IF NOT EXISTS session_vector_state (
@@ -745,16 +587,32 @@ export class ChatSearchIndexService {
ON message_index(session_id, sender_username);
CREATE INDEX IF NOT EXISTS idx_message_vector_session_model
ON message_vector_index(session_id, vector_model);
CREATE INDEX IF NOT EXISTS idx_message_vector_message_model
ON message_vector_index(message_id, vector_model);
CREATE INDEX IF NOT EXISTS idx_session_vector_state_session
ON session_vector_state(session_id);
`)
if (this.sqliteVectorAvailable) {
const dim = localEmbeddingModelService.getProfile().dim
db.exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS message_embedding_vec USING vec0(
vector_id INTEGER PRIMARY KEY,
session_key INTEGER PARTITION KEY,
session_id TEXT,
vector_model TEXT,
embedding FLOAT[${dim}]
);
`)
}
db.prepare('INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)').run('schema_version', INDEX_SCHEMA_VERSION)
}
private resetSchema(db: Database.Database): void {
db.exec(`
DROP TABLE IF EXISTS message_index_fts;
DROP TABLE IF EXISTS message_embedding_vec;
DROP TABLE IF EXISTS message_vector_index;
DROP TABLE IF EXISTS session_vector_state;
DROP TABLE IF EXISTS message_index;
@@ -783,26 +641,36 @@ export class ChatSearchIndexService {
return Number(row?.count || 0)
}
/** Returns the profile reported by `localEmbeddingModelService.getProfile()`. */
private getCurrentVectorProfile() {
  const profile = localEmbeddingModelService.getProfile()
  return profile
}
/** Returns the id of the current vector profile; used as the `vector_model` value in queries. */
private getCurrentVectorModelId(): string {
  const { id } = this.getCurrentVectorProfile()
  return id
}
private getVectorizedCount(db: Database.Database, sessionId: string): number {
const vectorModel = this.getCurrentVectorModelId()
const row = db.prepare(`
SELECT COUNT(*) AS count
FROM message_index m
JOIN message_vector_index v ON v.message_id = m.id
WHERE m.session_id = ? AND v.vector_model = ?
`).get(sessionId, localVectorProvider.id) as { count?: number }
`).get(sessionId, vectorModel) as { count?: number }
return Number(row?.count || 0)
}
private getVectorTaskKey(sessionId: string): string {
return `${sessionId}:${localVectorProvider.id}`
return `${sessionId}:${this.getCurrentVectorModelId()}`
}
private getVectorStateRow(db: Database.Database, sessionId: string): SessionVectorStateRow | null {
const vectorModel = this.getCurrentVectorModelId()
const row = db.prepare(`
SELECT *
FROM session_vector_state
WHERE session_id = ? AND vector_model = ?
`).get(sessionId, localVectorProvider.id) as SessionVectorStateRow | undefined
`).get(sessionId, vectorModel) as SessionVectorStateRow | undefined
return row || null
}
@@ -836,7 +704,7 @@ export class ChatSearchIndexService {
last_error = excluded.last_error
`).run(
input.sessionId,
localVectorProvider.id,
this.getCurrentVectorModelId(),
input.confirmedAt ?? null,
input.completedAt ?? null,
now,
@@ -847,11 +715,13 @@ export class ChatSearchIndexService {
getSessionVectorIndexState(sessionId: string): ChatVectorIndexState {
const db = this.getDb()
const profile = this.getCurrentVectorProfile()
const indexedCount = this.getIndexedCount(db, sessionId)
const vectorizedCount = this.getVectorizedCount(db, sessionId)
const isRunning = this.vectorTasks.has(this.getVectorTaskKey(sessionId))
const row = this.getVectorStateRow(db, sessionId)
const isComplete = Number(row?.is_complete || 0) === 1
const isComplete = this.sqliteVectorAvailable
&& Number(row?.is_complete || 0) === 1
&& vectorizedCount >= indexedCount
return {
@@ -861,7 +731,10 @@ export class ChatSearchIndexService {
pendingCount: Math.max(0, indexedCount - vectorizedCount),
isVectorComplete: isComplete,
isVectorRunning: isRunning,
vectorModel: localVectorProvider.id
vectorModel: profile.id,
vectorModelName: profile.displayName,
vectorProviderAvailable: this.sqliteVectorAvailable,
vectorProviderError: this.sqliteVectorError
}
}
@@ -873,56 +746,78 @@ export class ChatSearchIndexService {
progress: Omit<ChatVectorIndexProgress, 'vectorModel'>,
onProgress?: (progress: ChatVectorIndexProgress) => void | Promise<void>
): Promise<void> {
const profile = this.getCurrentVectorProfile()
await onProgress?.({
...progress,
vectorModel: localVectorProvider.id
vectorModel: profile.id
})
}
private upsertVectorRows(
private async upsertVectorRows(
db: Database.Database,
rows: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
): void {
): Promise<void> {
if (rows.length === 0) return
if (!this.sqliteVectorAvailable) {
throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`)
}
const profile = this.getCurrentVectorProfile()
const embeddings = await localEmbeddingModelService.embedTexts(rows.map((row) => row.search_text), profile.id)
const upsertVector = db.prepare(`
INSERT INTO message_vector_index (
message_id,
session_id,
vector_model,
vector_json,
feature_count,
embedding_blob,
dim,
content_hash,
indexed_at
) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(message_id) DO UPDATE SET
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(message_id, vector_model) DO UPDATE SET
session_id = excluded.session_id,
vector_model = excluded.vector_model,
vector_json = excluded.vector_json,
feature_count = excluded.feature_count,
embedding_blob = excluded.embedding_blob,
dim = excluded.dim,
content_hash = excluded.content_hash,
indexed_at = excluded.indexed_at
`)
const selectVectorId = db.prepare(`
SELECT id FROM message_vector_index
WHERE message_id = ? AND vector_model = ?
`)
const upsertVec = db.prepare(`
INSERT OR REPLACE INTO message_embedding_vec(vector_id, session_key, session_id, vector_model, embedding)
VALUES (?, ?, ?, ?, ?)
`)
const run = db.transaction((items: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>) => {
const now = Date.now()
for (const row of items) {
const vector = localVectorProvider.buildVector(row.search_text)
for (let index = 0; index < items.length; index += 1) {
const row = items[index]
const vector = embeddings[index]
upsertVector.run(
row.id,
row.session_id,
localVectorProvider.id,
JSON.stringify(vector),
profile.id,
float32ArrayToBuffer(vector),
vector.length,
hashEmbeddingContent(row.search_text),
row.indexed_at || now
)
const vectorRow = selectVectorId.get(row.id, profile.id) as { id?: number } | undefined
if (vectorRow?.id) {
upsertVec.run(vectorRow.id, vectorSessionKey(row.session_id), row.session_id, profile.id, float32ArrayToBuffer(vector))
}
}
})
run(rows)
}
private upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: {
private async upsertMessages(db: Database.Database, sessionId: string, messages: Message[], options: {
vectorize?: boolean
} = {}): void {
} = {}): Promise<void> {
if (messages.length === 0) return
const upsert = db.prepare(`
@@ -978,22 +873,7 @@ export class ChatSearchIndexService {
INSERT INTO message_index_fts(rowid, session_id, cursor_key, search_text, token_text)
VALUES (?, ?, ?, ?, ?)
`)
const upsertVector = db.prepare(`
INSERT INTO message_vector_index (
message_id,
session_id,
vector_model,
vector_json,
feature_count,
indexed_at
) VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(message_id) DO UPDATE SET
session_id = excluded.session_id,
vector_model = excluded.vector_model,
vector_json = excluded.vector_json,
feature_count = excluded.feature_count,
indexed_at = excluded.indexed_at
`)
const vectorRows: Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }> = []
const run = db.transaction((items: Message[]) => {
const indexedAt = Date.now()
@@ -1024,20 +904,29 @@ export class ChatSearchIndexService {
deleteFts.run(row.id)
insertFts.run(row.id, sessionId, cursorKey(message), searchText, tokenText)
if (options.vectorize) {
const vector = localVectorProvider.buildVector(searchText)
upsertVector.run(
row.id,
sessionId,
localVectorProvider.id,
JSON.stringify(vector),
vector.length,
indexedAt
)
vectorRows.push({
id: row.id,
session_id: sessionId,
search_text: searchText,
indexed_at: indexedAt
})
}
}
})
run(messages)
if (vectorRows.length > 0) {
try {
await this.upsertVectorRows(db, vectorRows)
} catch (error) {
this.setSessionVectorState(db, {
sessionId,
completedAt: null,
isComplete: false,
lastError: String(error)
})
}
}
}
private updateSessionState(db: Database.Database, sessionId: string, newest: Message | null, isComplete: boolean): ChatSearchIndexState {
@@ -1149,7 +1038,7 @@ export class ChatSearchIndexService {
const messages = result.messages || []
if (messages.length === 0) break
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
scanned += messages.length
newest = messages[messages.length - 1] || newest
cursor = {
@@ -1181,6 +1070,12 @@ export class ChatSearchIndexService {
}
db.prepare('DELETE FROM message_index_fts WHERE session_id = ?').run(sessionId)
if (this.sqliteVectorAvailable) {
db.prepare(`
DELETE FROM message_embedding_vec
WHERE vector_id IN (SELECT id FROM message_vector_index WHERE session_id = ?)
`).run(sessionId)
}
db.prepare('DELETE FROM message_vector_index WHERE session_id = ?').run(sessionId)
db.prepare('DELETE FROM message_index WHERE session_id = ?').run(sessionId)
db.prepare('DELETE FROM session_index_state WHERE session_id = ?').run(sessionId)
@@ -1193,7 +1088,7 @@ export class ChatSearchIndexService {
let messages = firstPage.messages || []
let hasMore = Boolean(firstPage.hasMore)
if (messages.length > 0) {
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
scanned += messages.length
newest = messages[messages.length - 1]
await this.report({
@@ -1220,7 +1115,7 @@ export class ChatSearchIndexService {
messages = result.messages || []
if (messages.length === 0) break
this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
await this.upsertMessages(db, sessionId, messages, { vectorize: vectorizeDuringIndexing })
scanned += messages.length
hasMore = Boolean(result.hasMore)
@@ -1284,6 +1179,21 @@ export class ChatSearchIndexService {
return this.getSessionVectorIndexState(sessionId)
}
/**
 * Deletes every stored vector and per-session vector state for one vector
 * model.
 *
 * The count and all deletes run inside a single transaction so a failure
 * part-way through cannot leave the vec table, the vector index table and the
 * session state table out of sync with each other.
 *
 * @param vectorModel Model id whose vectors are removed; defaults to the
 *   current vector model id.
 * @returns `success: true`, the number of rows that were in
 *   `message_vector_index` for the model, and the model id that was cleared.
 */
clearSemanticVectorIndex(vectorModel = this.getCurrentVectorModelId()): { success: boolean; deletedCount: number; vectorModel: string } {
  const db = this.getDb()
  const purge = db.transaction((model: string): number => {
    const row = db.prepare('SELECT COUNT(*) AS count FROM message_vector_index WHERE vector_model = ?').get(model) as { count?: number }
    // The vec0 virtual table is only touched when sqlite-vec was loaded.
    if (this.sqliteVectorAvailable) {
      db.prepare('DELETE FROM message_embedding_vec WHERE vector_model = ?').run(model)
    }
    db.prepare('DELETE FROM message_vector_index WHERE vector_model = ?').run(model)
    db.prepare('DELETE FROM session_vector_state WHERE vector_model = ?').run(model)
    return Number(row?.count || 0)
  })
  const deletedCount = purge(vectorModel)
  return {
    success: true,
    deletedCount,
    vectorModel
  }
}
private async runPrepareSessionVectorIndex(
sessionId: string,
task: VectorTask,
@@ -1301,6 +1211,35 @@ export class ChatSearchIndexService {
}, onProgress)
try {
if (!this.sqliteVectorAvailable) {
throw new Error(`本地语义检索不可用:${this.sqliteVectorError || 'sqlite-vec 未加载'}`)
}
const profile = this.getCurrentVectorProfile()
const modelStatus = await localEmbeddingModelService.getModelStatus(profile.id)
if (!modelStatus.exists) {
await this.reportVectorProgress({
sessionId,
stage: 'downloading_model',
status: 'running',
processedCount: 0,
totalCount: 0,
message: `正在下载本地语义模型:${profile.displayName}`
}, onProgress)
await localEmbeddingModelService.downloadModel(profile.id, async (progress) => {
await this.reportVectorProgress({
sessionId,
stage: 'downloading_model',
status: 'running',
processedCount: progress.loaded || 0,
totalCount: progress.total || 0,
message: progress.percent !== undefined
? `正在下载 ${profile.displayName}${progress.percent}%`
: `正在下载 ${profile.displayName}`
}, onProgress)
})
}
const searchState = await this.ensureSessionIndexed(sessionId, async (progress) => {
if (task.cancelRequested) {
throw new Error(VECTOR_INDEX_CANCELLED_ERROR)
@@ -1326,7 +1265,7 @@ export class ChatSearchIndexService {
status: 'completed',
processedCount: currentState.vectorizedCount,
totalCount: currentState.indexedCount,
message: `本地向量索引已就绪,共 ${currentState.vectorizedCount} 条消息`
message: `本地语义向量索引已就绪,共 ${currentState.vectorizedCount} 条消息`
}, onProgress)
return currentState
}
@@ -1386,11 +1325,11 @@ export class ChatSearchIndexService {
WHERE m.session_id = ? AND v.message_id IS NULL
ORDER BY m.id ASC
LIMIT ?
`).all(localVectorProvider.id, sessionId, VECTOR_BATCH_SIZE) as Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
`).all(this.getCurrentVectorModelId(), sessionId, VECTOR_BATCH_SIZE) as Array<Pick<MessageIndexRow, 'id' | 'session_id' | 'search_text'> & { indexed_at?: number }>
if (rows.length === 0) break
this.upsertVectorRows(db, rows)
await this.upsertVectorRows(db, rows)
currentState = this.getSessionVectorIndexState(sessionId)
await this.reportVectorProgress({
sessionId,
@@ -1417,7 +1356,7 @@ export class ChatSearchIndexService {
status: 'completed',
processedCount: currentState.vectorizedCount,
totalCount: currentState.indexedCount,
message: `本地向量索引已完成,共 ${currentState.vectorizedCount} 条消息`
message: `本地语义向量索引已完成,共 ${currentState.vectorizedCount} 条消息`
}, onProgress)
return currentState
@@ -1576,23 +1515,26 @@ export class ChatSearchIndexService {
const db = this.getDb()
const state = await this.ensureSessionIndexed(options.sessionId, options.onProgress)
const vectorState = this.getSessionVectorIndexState(options.sessionId)
const queryVector = localVectorProvider.buildVector(options.query)
const profile = this.getCurrentVectorProfile()
const vectorizedCount = vectorState.vectorizedCount
if (!vectorState.isVectorComplete || queryVector.length === 0) {
if (!this.sqliteVectorAvailable || !vectorState.isVectorComplete || !normalizeSearchText(options.query)) {
return {
hits: [],
indexedCount: state.indexedCount,
vectorizedCount,
truncated: false,
model: localVectorProvider.id
model: profile.id
}
}
const queryVector = await localEmbeddingModelService.embedText(options.query, profile.id)
const queryEmbedding = float32ArrayToBuffer(queryVector)
await this.report({
stage: 'searching_index',
sessionId: options.sessionId,
message: `正在进行本地向量检索:${options.query}`,
message: `正在进行本地语义检索:${options.query}`,
indexedCount: state.indexedCount
}, options.onProgress)
@@ -1600,14 +1542,19 @@ export class ChatSearchIndexService {
const endTime = toTimestampSeconds(options.endTimeMs)
const senderUsername = normalizeSearchText(options.senderUsername)
const direction = options.direction
const scanLimit = MAX_VECTOR_SCAN_ROWS
const scanLimit = Math.max(options.limit * VECTOR_SEARCH_OVERFETCH, options.limit + 20)
const sqlFilters: string[] = [
'v.session_id = @sessionId',
'v.vector_model = @vectorModel'
'vec.embedding MATCH @queryEmbedding',
'vec.session_key = @sessionKey',
'vec.session_id = @sessionId',
'vec.vector_model = @vectorModel',
'k = @scanLimit'
]
const params: Record<string, unknown> = {
sessionId: options.sessionId,
vectorModel: localVectorProvider.id,
sessionKey: vectorSessionKey(options.sessionId),
vectorModel: profile.id,
queryEmbedding,
scanLimit: scanLimit + 1
}
@@ -1628,18 +1575,16 @@ export class ChatSearchIndexService {
}
const rows = db.prepare(`
SELECT m.*, v.vector_json
FROM message_vector_index v
SELECT m.*, vec.distance
FROM message_embedding_vec vec
JOIN message_vector_index v ON v.id = vec.vector_id
JOIN message_index m ON m.id = v.message_id
WHERE ${sqlFilters.join(' AND ')}
AND v.feature_count > 0
ORDER BY m.sort_seq DESC, m.create_time DESC, m.local_id DESC
LIMIT @scanLimit
ORDER BY vec.distance ASC
`).all(params) as MessageVectorRow[]
const queryWeights = localVectorProvider.toWeightMap(queryVector)
const scored = rows
.map((row) => {
const vectorScore = localVectorProvider.dot(queryWeights, localVectorProvider.parseVector(row.vector_json))
const vectorScore = Math.max(0, Math.min(1, 1 - Number(row.distance || 0)))
return {
row,
vectorScore
@@ -1662,7 +1607,7 @@ export class ChatSearchIndexService {
indexedCount: state.indexedCount,
vectorizedCount,
truncated: rows.length > scanLimit,
model: localVectorProvider.id
model: profile.id
}
}
}
@@ -0,0 +1,596 @@
import { createHash } from 'crypto'
import { existsSync, mkdirSync, readdirSync, rmSync, statSync } from 'fs'
import { dirname, join } from 'path'
import { ConfigService } from '../config'
/**
 * Identifiers of the built-in local embedding model profiles.
 * Each literal matches the `id` of one entry in `EMBEDDING_MODEL_PROFILES`.
 */
export type EmbeddingModelProfileId =
  | 'bge-large-zh-v1.5-int8'
  | 'bge-large-zh-v1.5-fp32'
  | 'bge-m3'
/** Inference device selector: `'cpu'`, or `'dml'` for the DirectML provider. */
export type EmbeddingDevice = 'cpu' | 'dml'
/** Snapshot of the configured device versus the device that will actually be used, as built by `getDeviceStatus()`. */
export type EmbeddingDeviceStatus = {
  /** Device stored in the configuration. */
  currentDevice: EmbeddingDevice
  /** Device inference will really use; `'cpu'` whenever DirectML is missing or has failed during this run. */
  effectiveDevice: EmbeddingDevice
  /** Whether the DirectML DLL was located on disk (only ever true on Windows). */
  gpuAvailable: boolean
  /** Display label that corresponds to `effectiveDevice`. */
  provider: 'CPU' | 'DirectML'
  /** Human-readable explanation of the current state. */
  info: string
}
/** On-disk status of one embedding model profile, as returned by `getModelStatus()`. */
export type EmbeddingModelStatus = {
  /** Id of the profile this status describes. */
  profileId: string
  /** Profile display name. */
  displayName: string
  /** Model repository id the profile loads. */
  modelId: string
  /** Vector dimension declared by the profile. */
  dim: number
  /** Weight precision declared by the profile. */
  dtype: string
  /** Approximate download size label declared by the profile. */
  sizeLabel: string
  /** Whether the profile is selectable. */
  enabled: boolean
  /** True when the model directory contains an `.onnx` file and a tokenizer file. */
  exists: boolean
  /** Absolute directory that holds (or would hold) the model files. */
  modelDir: string
  /** Total size in bytes of all regular files under `modelDir`. */
  sizeBytes: number
}
/** Progress event passed to the `onProgress` callback while a model is being downloaded. */
export type EmbeddingDownloadProgress = {
  /** Id of the profile being downloaded. */
  profileId: string
  /** Display name of the profile being downloaded. */
  displayName: string
  /** Remote host currently being tried. */
  remoteHost?: string
  /** File name reported by the loader's progress event (empty string when it reports none). */
  file?: string
  /** Bytes loaded so far; omitted when not a positive finite number. */
  loaded?: number
  /** Total bytes expected; omitted when not a positive finite number. */
  total?: number
  /** Rounded percentage capped at 100; omitted when the total is unknown. */
  percent?: number
  /** `'initiate'` before each host attempt, otherwise the status string forwarded from the loader's event. */
  status?: string
}
/** Static description of one downloadable local embedding model. */
export type EmbeddingModelProfile = {
  /** Stable identifier; also used as the model's directory name under the models root. */
  id: EmbeddingModelProfileId
  /** Name shown to the user. */
  displayName: string
  /** Short user-facing description. */
  description: string
  /** Model repository id passed to `from_pretrained`. */
  modelId: string
  /** Download hosts, tried in order until one succeeds. */
  remoteHosts: string[]
  /** Value assigned to the loader's `env.remotePathTemplate` when downloading. */
  remotePathTemplate: string
  /** Repository revision passed to `from_pretrained`. */
  revision: string
  /** Declared vector dimension (reported in status; not validated against model output in this file). */
  dim: number
  /** Tokenizer `max_length` used with truncation. */
  maxTokens: number
  /** Character budget applied to each text before tokenization. */
  maxTextChars: number
  /** Weight precision passed to `from_pretrained`. */
  dtype: 'q8' | 'fp32'
  /** Approximate download size label for display. */
  sizeLabel: string
  /** Only enabled profiles can be selected; others fall back to the default profile. */
  enabled: boolean
}
// Base URL of the ModelScope mirror that model files are downloaded from.
const MODELSCOPE_HOST = 'https://www.modelscope.cn/'
// Path template assigned to the loader's `env.remotePathTemplate`;
// `{model}` and `{revision}` are presumably substituted by the loader — TODO confirm.
const MODELSCOPE_PATH_TEMPLATE = 'models/{model}/resolve/{revision}/'
// Repository revision requested for every built-in profile.
const MODELSCOPE_REVISION = 'master'
/** Profile used whenever the configured profile id is missing, unknown or disabled. */
export const DEFAULT_EMBEDDING_MODEL_PROFILE: EmbeddingModelProfileId = 'bge-large-zh-v1.5-int8'
// Built-in model catalogue. All entries download from ModelScope and declare 1024-dimensional vectors.
const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfile[] = [
  // Default profile: BGE Large Chinese with q8 (quantized) weights.
  {
    id: 'bge-large-zh-v1.5-int8',
    displayName: 'BGE Large 中文 · 推荐',
    description: '默认档位,1024 维中文语义向量,优先兼顾召回质量和本地 CPU 性能。',
    modelId: 'Xenova/bge-large-zh-v1.5',
    remoteHosts: [MODELSCOPE_HOST],
    remotePathTemplate: MODELSCOPE_PATH_TEMPLATE,
    revision: MODELSCOPE_REVISION,
    dim: 1024,
    maxTokens: 512,
    maxTextChars: 480,
    dtype: 'q8',
    sizeLabel: '约 330 MB',
    enabled: true
  },
  // Same model repository as above, loaded with fp32 weights (larger download).
  {
    id: 'bge-large-zh-v1.5-fp32',
    displayName: 'BGE Large 中文 · 高质量',
    description: '同模型 FP32 推理,精度更完整,下载和内存占用更高。',
    modelId: 'Xenova/bge-large-zh-v1.5',
    remoteHosts: [MODELSCOPE_HOST],
    remotePathTemplate: MODELSCOPE_PATH_TEMPLATE,
    revision: MODELSCOPE_REVISION,
    dim: 1024,
    maxTokens: 512,
    maxTextChars: 480,
    dtype: 'fp32',
    sizeLabel: '约 1.2 GB',
    enabled: true
  },
  // Multilingual BGE-M3 with a much larger token and character budget.
  {
    id: 'bge-m3',
    displayName: 'BGE-M3 · 多语言',
    description: '更强的多语言和长文本语义召回,资源占用更高。',
    modelId: 'Xenova/bge-m3',
    remoteHosts: [MODELSCOPE_HOST],
    remotePathTemplate: MODELSCOPE_PATH_TEMPLATE,
    revision: MODELSCOPE_REVISION,
    dim: 1024,
    maxTokens: 8192,
    maxTextChars: 2400,
    dtype: 'q8',
    sizeLabel: '约 600 MB',
    enabled: true
  }
]
/**
 * Coerces an arbitrary value into a selectable profile id.
 *
 * The value is stringified and trimmed; it is accepted only when it names an
 * enabled profile. Anything else yields `DEFAULT_EMBEDDING_MODEL_PROFILE`.
 */
function safeProfileId(value: unknown): EmbeddingModelProfileId {
  const candidate = String(value || '').trim()
  for (const profile of EMBEDDING_MODEL_PROFILES) {
    if (profile.enabled && profile.id === candidate) {
      return profile.id
    }
  }
  return DEFAULT_EMBEDDING_MODEL_PROFILE
}
/**
 * Coerces an arbitrary value into a device selector.
 *
 * Only the exact (trimmed, case-sensitive) string `'dml'` selects DirectML;
 * every other value selects `'cpu'`.
 */
function safeEmbeddingDevice(value: unknown): EmbeddingDevice {
  const normalized = String(value || '').trim()
  if (normalized === 'dml') {
    return 'dml'
  }
  return 'cpu'
}
/**
 * Sums the sizes of all regular files below `dir`, descending into sub-directories.
 *
 * @param dir Directory to measure.
 * @returns Total size in bytes; 0 when the directory does not exist.
 */
function directorySize(dir: string): number {
  let bytes = 0
  const pending: string[] = [dir]
  for (let current = pending.pop(); current !== undefined; current = pending.pop()) {
    // A directory that is absent (or vanished since it was listed) contributes nothing.
    if (!existsSync(current)) continue
    const entries = readdirSync(current, { withFileTypes: true })
    for (const entry of entries) {
      const entryPath = join(current, entry.name)
      if (entry.isDirectory()) {
        pending.push(entryPath)
      } else if (entry.isFile()) {
        bytes += statSync(entryPath).size
      }
    }
  }
  return bytes
}
/**
 * Reports whether `dir` looks like a downloaded model.
 *
 * The whole tree is walked; the result is true only when at least one entry
 * name ends with `.onnx` and at least one entry is named `tokenizer.json` or
 * `tokenizer_config.json`.
 *
 * @param dir Directory to inspect.
 * @returns False when the directory is missing or either kind of file is absent.
 */
function hasModelFiles(dir: string): boolean {
  if (!existsSync(dir)) return false

  const tokenizerFileNames = ['tokenizer.json', 'tokenizer_config.json']
  let onnxSeen = false
  let tokenizerSeen = false

  const pending: string[] = [dir]
  for (let current = pending.pop(); current !== undefined; current = pending.pop()) {
    for (const entry of readdirSync(current, { withFileTypes: true })) {
      if (entry.isDirectory()) {
        pending.push(join(current, entry.name))
        continue
      }
      onnxSeen = onnxSeen || entry.name.endsWith('.onnx')
      tokenizerSeen = tokenizerSeen || tokenizerFileNames.includes(entry.name)
    }
  }

  return onnxSeen && tokenizerSeen
}
/**
 * Returns Electron's `app` object when it can be loaded and exposes `getPath`.
 *
 * @returns The `app` object, or null when `require('electron')` throws, does
 * not yield an object, or its `app` has no `getPath`.
 */
function getElectronAppSafe(): any | null {
  try {
    const loaded = require('electron')
    if (!loaded || typeof loaded !== 'object') {
      return null
    }
    const app = loaded.app
    if (app?.getPath) {
      return app
    }
    return null
  } catch {
    return null
  }
}
/**
 * Resolves the application's data/cache root.
 *
 * Resolution order:
 * 1. a non-empty `cachePath` from the configuration;
 * 2. with Electron available: `<documents>/CipherTalkData` in dev mode
 *    (`VITE_DEV_SERVER_URL` set) or when the install directory is on drive C:
 *    or a UNC path, otherwise `<install dir>/CipherTalkData`;
 * 3. without Electron: `<cwd>/CipherTalkData`.
 */
function getEffectiveCachePathFromConfig(): string {
  const dataDirName = 'CipherTalkData'

  let configuredPath = ''
  const configService = new ConfigService()
  try {
    configuredPath = String(configService.get('cachePath' as any) || '').trim()
  } finally {
    configService.close()
  }
  if (configuredPath) {
    return configuredPath
  }

  const electronApp = getElectronAppSafe()
  if (!electronApp?.getPath) {
    return join(process.cwd(), dataDirName)
  }

  const documentsPath = electronApp.getPath('documents')
  if (process.env.VITE_DEV_SERVER_URL) {
    return join(documentsPath, dataDirName)
  }

  const installDir = dirname(electronApp.getPath('exe'))
  const isOnCDrive = /^[cC]:/i.test(installDir) || installDir.startsWith('\\\\')
  if (isOnCDrive) {
    return join(documentsPath, dataDirName)
  }
  return join(installDir, dataDirName)
}
/**
 * Computes where `DirectML.dll` is expected inside the `onnxruntime-node` package.
 *
 * The path is derived from the package's resolved entry file and the process
 * architecture (`arm64`, otherwise `x64`). It is not checked for existence here.
 *
 * @returns The expected path, or null off Windows or when the package cannot be resolved.
 */
function getDirectMLDllPath(): string | null {
  if (process.platform === 'win32') {
    try {
      const entryFile = require.resolve('onnxruntime-node')
      const archDir = process.arch === 'arm64' ? 'arm64' : 'x64'
      const relativeSegments = ['..', 'bin', 'napi-v6', 'win32', archDir, 'DirectML.dll']
      return join(dirname(entryFile), ...relativeSegments)
    } catch {
      // Package not resolvable: fall through to null.
    }
  }
  return null
}
/**
 * Shortens over-long text before tokenization by keeping its head and tail.
 *
 * Text within the budget is returned unchanged. Longer text becomes
 * `<head>\n<tail>`, where the head is about 75% of the budget and the tail the
 * remainder. Because of the inserted newline the result can be one character
 * longer than the budget.
 *
 * Cuts are measured in UTF-16 code units. A cut that would land inside a
 * surrogate pair (e.g. an emoji) is moved so the pair is dropped whole rather
 * than split, which would otherwise leave a lone surrogate in the model input.
 *
 * @param text Text to limit; non-string or falsy values are treated as empty.
 * @param maxChars Character budget; non-finite or non-positive values fall back to 480.
 * @returns The original text, or the head/tail excerpt.
 */
function limitEmbeddingText(text: string, maxChars: number): string {
  const value = String(text || '')
  const limit = Number.isFinite(maxChars) && maxChars > 0 ? Math.floor(maxChars) : 480
  if (value.length <= limit) return value

  const isHighSurrogate = (code: number): boolean => code >= 0xd800 && code <= 0xdbff
  const isLowSurrogate = (code: number): boolean => code >= 0xdc00 && code <= 0xdfff

  const headLength = Math.max(1, Math.floor(limit * 0.75))
  const tailLength = Math.max(1, limit - headLength)

  // Head cut: if the last kept unit is the first half of a pair, drop it too.
  let headEnd = headLength
  if (isHighSurrogate(value.charCodeAt(headEnd - 1)) && isLowSurrogate(value.charCodeAt(headEnd))) {
    headEnd -= 1
  }

  // Tail cut: if the first kept unit is the second half of a pair, skip it.
  let tailStart = value.length - tailLength
  if (isLowSurrogate(value.charCodeAt(tailStart)) && isHighSurrogate(value.charCodeAt(tailStart - 1))) {
    tailStart += 1
  }

  return `${value.slice(0, headEnd)}\n${value.slice(tailStart)}`
}
/**
 * Splits a flat tensor (`data` + `dims`) into one `Float32Array` per row.
 *
 * The row width is the last entry of `dims`. The row count is the first entry
 * when `dims` has two or more axes, otherwise `expectedCount`.
 *
 * @param output Object exposing `data` (array-like with `slice`) and `dims`.
 * @param expectedCount Number of vectors the caller expects.
 * @returns Exactly `expectedCount` vectors.
 * @throws Error when the tensor is empty, the width is invalid, or the number
 * of complete rows differs from `expectedCount`.
 */
function tensorToVectors(output: any, expectedCount: number): Float32Array[] {
  const flat = output?.data
  const shape: number[] = Array.isArray(output?.dims)
    ? output.dims.map((item: unknown) => Number(item))
    : []
  const hasData = Boolean(flat) && typeof flat.length === 'number'
  if (!hasData || shape.length === 0) {
    throw new Error('Embedding 模型输出为空')
  }

  const width = Number(shape[shape.length - 1] || 0)
  const rowCount = shape.length >= 2 ? Number(shape[0] || expectedCount) : expectedCount
  if (!Number.isInteger(width) || width <= 0) {
    throw new Error('Embedding 模型输出维度无效')
  }

  // Collect rows until the declared count is reached or the data runs out.
  const rows: Float32Array[] = []
  let row = 0
  while (row < rowCount) {
    const from = row * width
    const to = from + width
    if (to > flat.length) break
    rows.push(Float32Array.from(flat.slice(from, to)))
    row += 1
  }

  if (rows.length !== expectedCount) {
    throw new Error(`Embedding 输出数量不匹配:${rows.length}/${expectedCount}`)
  }
  return rows
}
/**
 * Computes the SHA-256 digest of a string.
 *
 * @param value Text to hash; a falsy value is hashed as the empty string.
 * @returns Lower-case hexadecimal digest.
 */
export function hashEmbeddingContent(value: string): string {
  const hasher = createHash('sha256')
  hasher.update(value || '')
  return hasher.digest('hex')
}
/**
 * Copies the bytes viewed by a `Float32Array` into a new `Buffer`.
 *
 * Only the view's own byte range is copied, so sub-array views are handled
 * correctly and later changes to the vector do not affect the buffer.
 *
 * @param vector Source vector.
 * @returns Buffer of `vector.byteLength` bytes in platform byte order.
 */
export function float32ArrayToBuffer(vector: Float32Array): Buffer {
  const { buffer, byteOffset, byteLength } = vector
  const copiedBytes = buffer.slice(byteOffset, byteOffset + byteLength)
  return Buffer.from(copiedBytes)
}
/**
 * Mask-weighted mean pooling followed by L2 normalization.
 *
 * The hidden states are read from `last_hidden_state`, then
 * `token_embeddings`, then `logits` (first one present) and must have three
 * axes, read here as `[batch, sequence, dim]`. For every batch row the token
 * vectors whose mask value is positive are averaged (weighted by the mask
 * value) and the result is scaled to unit length. Rows with no positive mask
 * value, or with a zero norm, are left as all zeros.
 *
 * @param output Model output object.
 * @param attentionMask Tensor-like object whose `data` holds one mask value per token.
 * @param expectedCount Number of texts in the batch; must equal the first axis.
 * @returns One normalized vector per batch row.
 * @throws Error when the output or mask is missing or the shape is invalid.
 */
function meanPoolNormalize(output: any, attentionMask: any, expectedCount: number): Float32Array[] {
  const hiddenTensor = output?.last_hidden_state || output?.token_embeddings || output?.logits
  const values = hiddenTensor?.data
  const shape: number[] = Array.isArray(hiddenTensor?.dims)
    ? hiddenTensor.dims.map((item: unknown) => Number(item))
    : []
  const maskValues = attentionMask?.data
  if (!values || shape.length !== 3 || !maskValues) {
    throw new Error('Embedding 模型输出为空')
  }

  const [batchSize, seqLength, dim] = shape
  const shapeIsValid =
    batchSize === expectedCount && Number.isInteger(seqLength) && Number.isInteger(dim) && dim > 0
  if (!shapeIsValid) {
    throw new Error(`Embedding 输出维度无效:${shape.join('x')}`)
  }

  // Divides by the accumulated mask weight, then scales to unit length (in place).
  const averageAndNormalize = (pooled: Float32Array, weightSum: number): void => {
    const divisor = weightSum > 0 ? weightSum : 1
    let squaredSum = 0
    for (let d = 0; d < dim; d += 1) {
      pooled[d] /= divisor
      squaredSum += pooled[d] * pooled[d]
    }
    const length = Math.sqrt(squaredSum) || 1
    for (let d = 0; d < dim; d += 1) {
      pooled[d] /= length
    }
  }

  const pooledRows: Float32Array[] = []
  for (let row = 0; row < batchSize; row += 1) {
    const pooled = new Float32Array(dim)
    let weightSum = 0
    for (let position = 0; position < seqLength; position += 1) {
      const flatIndex = row * seqLength + position
      const weight = Number(maskValues[flatIndex] || 0)
      if (weight <= 0) continue
      weightSum += weight
      const base = flatIndex * dim
      for (let d = 0; d < dim; d += 1) {
        pooled[d] += Number(values[base + d]) * weight
      }
    }
    averageAndNormalize(pooled, weightSum)
    pooledRows.push(pooled)
  }
  return pooledRows
}
/**
 * Manages local embedding models: profile and device selection (persisted via
 * `ConfigService`), model download and removal on disk, and text embedding
 * through `@huggingface/transformers`.
 */
export class LocalEmbeddingModelService {
  /** Loaded (or loading) tokenizer/model pairs, keyed by profile id, device and source. */
  private pipelines = new Map<string, Promise<{ tokenizer: any; model: any }>>()
  /** In-flight downloads keyed by profile id, so concurrent requests share one task. */
  private downloadTasks = new Map<string, Promise<EmbeddingModelStatus>>()
  /** Message of the last DirectML inference failure in this run; non-null forces CPU. */
  private dmlFailureReason: string | null = null

  /** Returns shallow copies of all built-in profiles. */
  listProfiles(): EmbeddingModelProfile[] {
    return EMBEDDING_MODEL_PROFILES.map((profile) => ({ ...profile }))
  }

  /**
   * Resolves a profile by id.
   *
   * @param profileId Profile id; when omitted the configured profile is used.
   * @returns The matching profile. Unknown or disabled ids resolve to the default profile.
   */
  getProfile(profileId?: string): EmbeddingModelProfile {
    // NOTE(review): the silent fallback means callers such as `clearModel` act on
    // the default profile when given an unknown id — confirm that is intended.
    const id = safeProfileId(profileId || this.getCurrentProfileId())
    return EMBEDDING_MODEL_PROFILES.find((profile) => profile.id === id)!
  }

  /** Reads the configured profile id, sanitized to a selectable profile. */
  getCurrentProfileId(): EmbeddingModelProfileId {
    const configService = new ConfigService()
    try {
      return safeProfileId(configService.get('aiEmbeddingModelProfile' as any))
    } finally {
      configService.close()
    }
  }

  /**
   * Persists the selected profile.
   *
   * @param profileId Requested profile id.
   * @returns The id actually stored (the default when the request is not selectable).
   */
  setCurrentProfileId(profileId: string): EmbeddingModelProfileId {
    const id = safeProfileId(profileId)
    const configService = new ConfigService()
    try {
      configService.set('aiEmbeddingModelProfile' as any, id)
      return id
    } finally {
      configService.close()
    }
  }

  /** Reads the configured device, sanitized to `'cpu'` or `'dml'`. */
  getCurrentDevice(): EmbeddingDevice {
    const configService = new ConfigService()
    try {
      return safeEmbeddingDevice(configService.get('aiEmbeddingDevice' as any))
    } finally {
      configService.close()
    }
  }

  /**
   * Persists the selected device, forgets any earlier DirectML failure and drops
   * all cached pipelines so the next embedding reloads on the new device.
   *
   * @param device Requested device.
   * @returns The device actually stored.
   */
  setCurrentDevice(device: string): EmbeddingDevice {
    const nextDevice = safeEmbeddingDevice(device)
    const configService = new ConfigService()
    try {
      configService.set('aiEmbeddingDevice' as any, nextDevice)
      this.dmlFailureReason = null
      this.clearPipelines()
      return nextDevice
    } finally {
      configService.close()
    }
  }

  /**
   * Describes which device is configured and which one will actually be used.
   *
   * DirectML counts as available only on Windows when its DLL exists on disk.
   * A recorded DirectML failure forces CPU until the device is set again.
   */
  getDeviceStatus(): EmbeddingDeviceStatus {
    const currentDevice = this.getCurrentDevice()
    const directMLDll = getDirectMLDllPath()
    const directMLAvailable = process.platform === 'win32' && !!directMLDll && existsSync(directMLDll)

    // DirectML requested but it already failed during this run.
    if (currentDevice === 'dml' && this.dmlFailureReason) {
      return {
        currentDevice,
        effectiveDevice: 'cpu',
        gpuAvailable: directMLAvailable,
        provider: 'CPU',
        info: `DirectML 本次运行失败,已自动回退 CPU:${this.dmlFailureReason}`
      }
    }

    // DirectML requested and usable.
    if (currentDevice === 'dml' && directMLAvailable) {
      return {
        currentDevice,
        effectiveDevice: 'dml',
        gpuAvailable: true,
        provider: 'DirectML',
        info: 'DirectML 组件已就绪,将优先使用 GPU;推理失败时自动回退 CPU'
      }
    }

    // DirectML requested but the component is missing or the OS is not Windows.
    if (currentDevice === 'dml') {
      return {
        currentDevice,
        effectiveDevice: 'cpu',
        gpuAvailable: false,
        provider: 'CPU',
        info: process.platform === 'win32'
          ? '缺少 DirectML 组件,将使用 CPU'
          : '当前系统不支持 DirectML,将使用 CPU'
      }
    }

    // CPU requested.
    return {
      currentDevice,
      effectiveDevice: 'cpu',
      gpuAvailable: directMLAvailable,
      provider: 'CPU',
      info: directMLAvailable ? '当前使用 CPU,可切换到 DirectML GPU 实验模式' : '当前使用 CPU'
    }
  }

  /** Directory under the effective cache path that holds all embedding models. */
  getModelsRoot(): string {
    return join(getEffectiveCachePathFromConfig(), 'models', 'embeddings')
  }

  /** Directory of one profile's model files (named after the resolved profile id). */
  getProfileDir(profileId?: string): string {
    return join(this.getModelsRoot(), this.getProfile(profileId).id)
  }

  /**
   * Reports the on-disk state of a profile's model.
   *
   * @param profileId Profile id; defaults to the configured profile.
   */
  async getModelStatus(profileId?: string): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    const modelDir = this.getProfileDir(profile.id)
    const exists = hasModelFiles(modelDir)
    return {
      profileId: profile.id,
      displayName: profile.displayName,
      modelId: profile.modelId,
      dim: profile.dim,
      dtype: profile.dtype,
      sizeLabel: profile.sizeLabel,
      enabled: profile.enabled,
      exists,
      modelDir,
      sizeBytes: directorySize(modelDir)
    }
  }

  /**
   * Downloads a profile's model files, trying each remote host in order.
   *
   * Concurrent calls for the same profile share the in-flight task; only the
   * first caller's `onProgress` receives events.
   *
   * @param profileId Profile id; defaults to the configured profile.
   * @param onProgress Optional progress listener.
   * @returns The model status after the download.
   * @throws Error when every remote host fails.
   */
  async downloadModel(
    profileId?: string,
    onProgress?: (progress: EmbeddingDownloadProgress) => void
  ): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    const existing = this.downloadTasks.get(profile.id)
    if (existing) return existing

    const task = (async () => {
      mkdirSync(this.getProfileDir(profile.id), { recursive: true })
      await this.downloadPipelineWithFallback(profile, onProgress)
      return this.getModelStatus(profile.id)
    })()
    this.downloadTasks.set(profile.id, task)
    try {
      return await task
    } finally {
      this.downloadTasks.delete(profile.id)
    }
  }

  /**
   * Drops the profile's cached pipelines and deletes its model directory.
   *
   * @param profileId Profile id; defaults to the configured profile.
   * @returns The model status after removal.
   */
  async clearModel(profileId?: string): Promise<EmbeddingModelStatus> {
    const profile = this.getProfile(profileId)
    this.clearPipelines(profile.id)
    rmSync(this.getProfileDir(profile.id), { recursive: true, force: true })
    return this.getModelStatus(profile.id)
  }

  /**
   * Verifies that the profile's model files are present.
   *
   * @throws Error when the model has not been downloaded.
   */
  async ensureModelReady(profileId?: string): Promise<EmbeddingModelStatus> {
    const status = await this.getModelStatus(profileId)
    if (!status.exists) {
      throw new Error(`本地语义模型未下载:${status.displayName}`)
    }
    return status
  }

  /**
   * Embeds a batch of texts with the given profile.
   *
   * Each text is first limited to the profile's character budget. When
   * DirectML is the effective device it is tried first; on any failure the
   * reason is recorded, the DirectML pipeline is dropped and CPU is used.
   *
   * @param texts Texts to embed.
   * @param profileId Profile id; defaults to the configured profile.
   * @returns One normalized vector per input text, in order; `[]` for an empty batch.
   * @throws Error when the model is not downloaded or CPU inference fails.
   */
  async embedTexts(texts: string[], profileId?: string): Promise<Float32Array[]> {
    // Nothing to embed. Returning early also keeps an error raised for an empty
    // batch from being recorded as a DirectML failure by the catch below.
    if (texts.length === 0) return []

    const profile = this.getProfile(profileId)
    const cleaned = texts.map((text) => limitEmbeddingText(String(text || ''), profile.maxTextChars))
    await this.ensureModelReady(profile.id)

    const deviceStatus = this.getDeviceStatus()
    if (deviceStatus.effectiveDevice === 'dml') {
      try {
        return await this.runEmbedding(profile, cleaned, 'dml')
      } catch (error) {
        console.warn('[Embedding] DirectML 推理失败,回退 CPU:', error)
        this.dmlFailureReason = String(error instanceof Error ? error.message : error)
        this.clearPipelines(profile.id, 'dml')
      }
    }
    return this.runEmbedding(profile, cleaned, 'cpu')
  }

  /** Embeds a single text; see `embedTexts`. */
  async embedText(text: string, profileId?: string): Promise<Float32Array> {
    const [vector] = await this.embedTexts([text], profileId)
    return vector
  }

  /** Tokenizes, runs the model from local files on `device`, then mean-pools and normalizes. */
  private async runEmbedding(
    profile: EmbeddingModelProfile,
    texts: string[],
    device: EmbeddingDevice
  ): Promise<Float32Array[]> {
    const runtime = await this.getPipeline(profile, true, device)
    const modelInputs = runtime.tokenizer(texts, {
      padding: true,
      truncation: true,
      max_length: profile.maxTokens
    })
    const output = await runtime.model(modelInputs)
    return meanPoolNormalize(output, modelInputs.attention_mask, texts.length)
  }

  /** Builds the cache key used by `pipelines`. */
  private pipelineKey(
    profileId: string,
    device: EmbeddingDevice,
    localOnly: boolean,
    remoteHost?: string
  ): string {
    return `${profileId}:${device}:${localOnly ? 'local' : remoteHost || 'remote'}`
  }

  /**
   * Loads (or returns the cached) tokenizer and model for a profile.
   *
   * @param profile Profile to load.
   * @param localOnly When true only local files are used; otherwise files may be downloaded.
   * @param device Device the model is created on.
   * @param remoteHost Download host; only applied when provided.
   * @param progressCallback Forwarded to the loader as `progress_callback`.
   * @throws Whatever the loader throws; a failed load is removed from the cache.
   */
  private async getPipeline(
    profile: EmbeddingModelProfile,
    localOnly: boolean,
    device: EmbeddingDevice = 'cpu',
    remoteHost?: string,
    progressCallback?: (event: any) => void
  ): Promise<{ tokenizer: any; model: any }> {
    const key = this.pipelineKey(profile.id, device, localOnly, remoteHost)
    const existing = this.pipelines.get(key)
    if (existing) return existing

    const promise = (async () => {
      const transformers = await import('@huggingface/transformers')
      // NOTE(review): `env` is shared module state; a local-only load that starts
      // while a download is still running flips `allowRemoteModels` for both —
      // confirm whether loads need to be serialized.
      transformers.env.allowLocalModels = true
      transformers.env.allowRemoteModels = !localOnly
      transformers.env.cacheDir = this.getProfileDir(profile.id)
      if (remoteHost) {
        transformers.env.remoteHost = remoteHost
        transformers.env.remotePathTemplate = profile.remotePathTemplate
      }
      const commonOptions = {
        cache_dir: this.getProfileDir(profile.id),
        local_files_only: localOnly,
        revision: profile.revision,
        progress_callback: progressCallback
      }
      const tokenizer = await transformers.AutoTokenizer.from_pretrained(profile.modelId, commonOptions as any)
      const model = await transformers.AutoModel.from_pretrained(profile.modelId, {
        ...commonOptions,
        device,
        dtype: profile.dtype
      } as any)
      return { tokenizer, model }
    })()

    this.pipelines.set(key, promise)
    try {
      return await promise
    } catch (error) {
      // Do not keep a rejected promise cached, so a later call can retry.
      this.pipelines.delete(key)
      throw error
    }
  }

  /**
   * Removes cached pipelines.
   *
   * @param profileId Restrict to this profile; all profiles when omitted.
   * @param device Restrict to this device; all devices when omitted.
   */
  private clearPipelines(profileId?: string, device?: EmbeddingDevice): void {
    for (const key of Array.from(this.pipelines.keys())) {
      const matchesProfile = !profileId || key.startsWith(`${profileId}:`)
      const matchesDevice = !device || key.includes(`:${device}:`)
      if (matchesProfile && matchesDevice) {
        this.pipelines.delete(key)
      }
    }
  }

  /**
   * Removes one cached pipeline and, best effort, disposes its model.
   * Never throws: failing to dispose must not fail the caller.
   */
  private async releasePipeline(key: string): Promise<void> {
    const pending = this.pipelines.get(key)
    if (!pending) return
    this.pipelines.delete(key)
    try {
      const runtime = await pending
      if (typeof runtime?.model?.dispose === 'function') {
        await runtime.model.dispose()
      }
    } catch (error) {
      console.warn('[Embedding] 释放下载用模型实例失败:', error)
    }
  }

  /**
   * Downloads the model by loading it once from each remote host in turn until
   * one succeeds, forwarding loader progress events to `onProgress`.
   *
   * @throws Error listing every host's failure when all hosts fail.
   */
  private async downloadPipelineWithFallback(
    profile: EmbeddingModelProfile,
    onProgress?: (progress: EmbeddingDownloadProgress) => void
  ): Promise<void> {
    const errors: string[] = []
    for (const remoteHost of profile.remoteHosts) {
      try {
        onProgress?.({
          profileId: profile.id,
          displayName: profile.displayName,
          remoteHost,
          status: 'initiate'
        })
        await this.getPipeline(profile, false, 'cpu', remoteHost, (event) => {
          const loaded = Number(event?.loaded || 0)
          const total = Number(event?.total || 0)
          onProgress?.({
            profileId: profile.id,
            displayName: profile.displayName,
            remoteHost,
            file: String(event?.file || event?.name || ''),
            loaded: Number.isFinite(loaded) && loaded > 0 ? loaded : undefined,
            total: Number.isFinite(total) && total > 0 ? total : undefined,
            percent: total > 0 ? Math.min(100, Math.round((loaded / total) * 100)) : undefined,
            status: String(event?.status || '')
          })
        })
        // The instance loaded for downloading is cached under a remote-host key
        // that inference (which uses the ':local' key) never reads; release it so
        // a whole model is not kept in memory for nothing.
        await this.releasePipeline(this.pipelineKey(profile.id, 'cpu', false, remoteHost))
        return
      } catch (error) {
        errors.push(`${remoteHost}: ${String(error)}`)
      }
    }
    throw new Error(`语义模型下载失败。已尝试 ModelScope/魔塔社区:${profile.remoteHosts.join('、')}。请检查网络/代理或稍后重试。${errors.length ? ` 原始错误:${errors.join(' | ')}` : ''}`)
  }
}
/** Shared module-level instance of the service. */
export const localEmbeddingModelService = new LocalEmbeddingModelService()