feat: update online embedding model dimensions and config; optimize concurrency handling

ILoveBingLu
2026-04-28 22:58:29 +08:00
parent 68998fa45d
commit de2ccee1a4
4 changed files with 150 additions and 24 deletions
@@ -63,7 +63,7 @@ export class EmbeddingRuntimeService {
getCurrentBatchSize(defaultBatchSize: number): number {
if (this.getMode() !== 'online') return defaultBatchSize
- return Math.max(1, Math.min(defaultBatchSize, onlineEmbeddingService.getCurrentBatchSize()))
+ return Math.max(1, onlineEmbeddingService.getCurrentBatchSize() * onlineEmbeddingService.getCurrentConcurrency())
}
ensureReady(): void {
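Note: the runtime now hands the online path a chunk of batchSize × concurrency texts and lets the service split it into parallel per-request batches, instead of clamping to a single request's batch size. A minimal sketch of the arithmetic, assuming the maxBatchSize of 10 and the concurrency constant of 6 introduced later in this commit:

```ts
// Effective scheduling chunk under the new logic (illustrative values).
const perRequestBatch = 10 // provider maxBatchSize cap
const concurrency = 6      // ONLINE_EMBEDDING_CONCURRENCY
// Old: Math.min(defaultBatchSize, 10) -> at most 10 texts per round.
// New: 10 * 6 = 60 texts per round, dispatched as 6 parallel requests of 10.
const effectiveChunk = Math.max(1, perRequestBatch * concurrency) // 60
```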
@@ -4,7 +4,7 @@ import type {
OnlineEmbeddingProviderInfo
} from './onlineEmbeddingTypes'
- export const ONLINE_EMBEDDING_COMMON_DIMS = [2048, 1536, 1024, 768, 512, 256, 128, 64]
+ export const ONLINE_EMBEDDING_COMMON_DIMS = [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64]
const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
{
@@ -50,27 +50,72 @@ const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
defaultBaseURL: 'https://api.siliconflow.cn/v1',
website: 'https://docs.siliconflow.cn/',
models: [
+ {
+ id: 'Qwen/Qwen3-VL-Embedding-8B',
+ displayName: 'Qwen3 VL Embedding 8B(收费)',
+ supportedDims: [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
+ defaultDim: 4096,
+ maxBatchSize: 10,
+ maxTokens: 32768,
+ supportsDimensions: true
+ },
+ {
+ id: 'Qwen/Qwen3-Embedding-8B',
+ displayName: 'Qwen3 Embedding 8B(收费)',
+ supportedDims: [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
+ defaultDim: 4096,
+ maxBatchSize: 10,
+ maxTokens: 32768,
+ supportsDimensions: true
+ },
+ {
+ id: 'Qwen/Qwen3-Embedding-4B',
+ displayName: 'Qwen3 Embedding 4B(收费)',
+ supportedDims: [2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
+ defaultDim: 2560,
+ maxBatchSize: 10,
+ maxTokens: 32768,
+ supportsDimensions: true
+ },
{
id: 'Qwen/Qwen3-Embedding-0.6B',
- displayName: 'Qwen3 Embedding 0.6B',
+ displayName: 'Qwen3 Embedding 0.6B(收费)',
supportedDims: [1024, 768, 512, 256, 128, 64],
defaultDim: 1024,
maxBatchSize: 10,
- maxTokens: 8192,
+ maxTokens: 32768,
supportsDimensions: true
},
{
id: 'BAAI/bge-m3',
- displayName: 'BAAI bge-m3',
+ displayName: 'BAAI bge-m3(免费)',
supportedDims: [1024],
defaultDim: 1024,
maxBatchSize: 10,
maxTokens: 8192,
supportsDimensions: false
},
+ {
+ id: 'netease-youdao/bce-embedding-base_v1',
+ displayName: 'netease-youdao bce-embedding-base_v1(免费)',
+ supportedDims: [768],
+ defaultDim: 768,
+ maxBatchSize: 10,
+ maxTokens: 512,
+ supportsDimensions: false
+ },
{
id: 'BAAI/bge-large-zh-v1.5',
- displayName: 'BAAI bge-large-zh-v1.5',
+ displayName: 'BAAI bge-large-zh-v1.5(免费)',
supportedDims: [1024],
defaultDim: 1024,
maxBatchSize: 10,
maxTokens: 512,
supportsDimensions: false
},
+ {
+ id: 'BAAI/bge-large-en-v1.5',
+ displayName: 'BAAI bge-large-en-v1.5(免费)',
+ supportedDims: [1024],
+ defaultDim: 1024,
+ maxBatchSize: 10,
@@ -79,7 +124,7 @@ const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
},
{
id: 'Pro/BAAI/bge-m3',
- displayName: 'Pro BAAI bge-m3',
+ displayName: 'Pro BAAI bge-m3(免费)',
supportedDims: [1024],
defaultDim: 1024,
maxBatchSize: 10,
@@ -18,6 +18,14 @@ import {
ONLINE_EMBEDDING_COMMON_DIMS
} from './onlineEmbeddingRegistry'
+ const ONLINE_EMBEDDING_CONCURRENCY = 6
+ const ONLINE_EMBEDDING_MIN_CHARS_ON_413 = 512
+ const ONLINE_EMBEDDING_413_SHRINK_RATIO = 0.5
+ type EmbeddingRequestError = Error & {
+ status?: number
+ }
function normalizeVector(vector: Float32Array): Float32Array {
let norm = 0
for (let index = 0; index < vector.length; index += 1) norm += vector[index] * vector[index]
@@ -48,7 +56,8 @@ function sleep(ms: number): Promise<void> {
function getErrorStatus(error: unknown): number {
if (typeof error === 'object' && error) {
const record = error as Record<string, unknown>
- return Number(record.status || record.statusCode || record.code || 0)
+ const status = Number(record.status || record.statusCode || record.code || 0)
+ return Number.isFinite(status) ? status : 0
}
return 0
}
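Note: the Number.isFinite guard matters because Node network errors expose string codes rather than HTTP statuses, so the old version could return NaN as a "status". A quick illustration with a hypothetical error object:

```ts
// 'ECONNRESET' is a string code, not an HTTP status:
const err = Object.assign(new Error('socket hang up'), { code: 'ECONNRESET' })
Number(err.code)                  // NaN -> previously leaked out as the status
Number.isFinite(Number(err.code)) // false -> now normalized to 0
```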
@@ -56,7 +65,14 @@ function getErrorStatus(error: unknown): number {
function normalizeErrorMessage(error: unknown): string {
const status = getErrorStatus(error)
const message = error instanceof Error ? error.message : String(error || '在线向量请求失败')
- return status ? `${status}: ${message}` : message
+ return status && !message.startsWith(`${status}:`) ? `${status}: ${message}` : message
}
+ function createEmbeddingRequestError(error: unknown, fallbackMessage?: string): EmbeddingRequestError {
+ const status = getErrorStatus(error)
+ const wrapped = new Error(fallbackMessage || normalizeErrorMessage(error)) as EmbeddingRequestError
+ if (status) wrapped.status = status
+ return wrapped
+ }
function limitEmbeddingText(text: string, maxChars: number): string {
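Note: the startsWith check keeps re-wrapped errors from accumulating prefixes like "413: 413: …", and createEmbeddingRequestError carries the numeric status through so callers can still branch on it after a re-throw. A small sketch with an assumed provider error:

```ts
// Hypothetical provider error carrying an HTTP status:
const raw = Object.assign(new Error('Payload Too Large'), { status: 413 })
const first = createEmbeddingRequestError(raw)
first.message  // "413: Payload Too Large"
first.status   // 413 -- getErrorStatus(first) still works upstream
const second = createEmbeddingRequestError(first)
second.message // unchanged: the prefix is not applied twice
```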
@@ -67,6 +83,27 @@ function limitEmbeddingText(text: string, maxChars: number): string {
return `${value.slice(0, head)}\n${value.slice(-(limit - head))}`
}
+ async function mapWithConcurrency<T, R>(
+ items: T[],
+ concurrency: number,
+ worker: (item: T, index: number) => Promise<R>
+ ): Promise<R[]> {
+ const results = new Array<R>(items.length)
+ let nextIndex = 0
+ const workerCount = Math.max(1, Math.min(Math.floor(concurrency), items.length))
+ await Promise.all(Array.from({ length: workerCount }, async () => {
+ while (true) {
+ const index = nextIndex
+ nextIndex += 1
+ if (index >= items.length) break
+ results[index] = await worker(items[index], index)
+ }
+ }))
+ return results
+ }
export class OnlineEmbeddingService {
listProviders(): OnlineEmbeddingProviderInfo[] {
return listOnlineEmbeddingProviders()
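Note: mapWithConcurrency is a plain worker pool: each worker synchronously claims the next index before awaiting, so at most `concurrency` requests are in flight and results keep input order. The shared nextIndex needs no locking because the read-and-increment completes before the first await in each iteration. A usage sketch (fetchOne is an illustrative stand-in for the request call):

```ts
declare function fetchOne(batch: string[]): Promise<Float32Array[]>

// 25 batches, at most 6 requests in flight at any moment:
const batches: string[][] = Array.from({ length: 25 }, (_, i) => [`text ${i}`])
const vectors = await mapWithConcurrency(batches, 6, (batch) => fetchOne(batch))
// vectors[i] always corresponds to batches[i], regardless of completion order.
```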
@@ -226,6 +263,10 @@ export class OnlineEmbeddingService {
return Math.max(1, Math.min(10, this.getModelInfo(config.providerId, config.model)?.maxBatchSize || 10))
}
+ getCurrentConcurrency(): number {
+ return ONLINE_EMBEDDING_CONCURRENCY
+ }
getCurrentProfile() {
const config = this.getCurrentConfig()
const provider = this.getProvider(config?.providerId)
@@ -348,16 +389,58 @@ export class OnlineEmbeddingService {
const model = this.getModelInfo(config.providerId, config.model)
const batchSize = Math.max(1, Math.min(model?.maxBatchSize || 10, texts.length))
const maxChars = model?.maxTokens ? Math.max(1000, model.maxTokens * 2) : 8000
- const vectors: Float32Array[] = []
+ const batches: string[][] = []
for (let index = 0; index < texts.length; index += batchSize) {
- const batch = texts.slice(index, index + batchSize)
- const cleaned = batch.map((text) => limitEmbeddingText(String(text || ''), maxChars))
- const batchVectors = await this.requestEmbeddings(config, cleaned)
- vectors.push(...batchVectors)
+ batches.push(texts.slice(index, index + batchSize).map((text) => String(text || '')))
}
- return vectors
+ const batchVectors = await mapWithConcurrency(
+ batches,
+ this.getCurrentConcurrency(),
+ (batch) => this.requestEmbeddingsWithPayloadRecovery(config, batch, maxChars)
+ )
+ return batchVectors.flat()
}
+ private async requestEmbeddingsWithPayloadRecovery(
+ config: OnlineEmbeddingConfig,
+ texts: string[],
+ maxChars: number
+ ): Promise<Float32Array[]> {
+ const safeMaxChars = Math.max(1, Math.floor(maxChars))
+ const cleaned = texts.map((text) => limitEmbeddingText(text, safeMaxChars))
+ try {
+ return await this.requestEmbeddings(config, cleaned)
+ } catch (error) {
+ if (getErrorStatus(error) !== 413) {
+ throw error
+ }
+ if (texts.length > 1) {
+ const midpoint = Math.max(1, Math.floor(texts.length / 2))
+ const left = await this.requestEmbeddingsWithPayloadRecovery(config, texts.slice(0, midpoint), safeMaxChars)
+ const right = await this.requestEmbeddingsWithPayloadRecovery(config, texts.slice(midpoint), safeMaxChars)
+ return [...left, ...right]
+ }
+ if (safeMaxChars > ONLINE_EMBEDDING_MIN_CHARS_ON_413) {
+ const nextMaxChars = Math.max(
+ ONLINE_EMBEDDING_MIN_CHARS_ON_413,
+ Math.floor(safeMaxChars * ONLINE_EMBEDDING_413_SHRINK_RATIO)
+ )
+ if (nextMaxChars < safeMaxChars) {
+ return this.requestEmbeddingsWithPayloadRecovery(config, texts, nextMaxChars)
+ }
+ }
+ throw createEmbeddingRequestError(
+ error,
+ `在线向量服务拒绝单条输入大小,已降到 ${safeMaxChars} 字符仍失败`
+ )
+ }
+ }
private async requestEmbeddings(config: OnlineEmbeddingConfig, texts: string[]): Promise<Float32Array[]> {
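Note: the 413 recovery first bisects multi-text batches (an oversized combined payload is the common case) and only starts shrinking maxChars once a single text still overflows, halving down to the 512-char floor before giving up with a wrapped 413. For a model without maxTokens the budget starts at the 8000-char fallback, so a single stubborn text would be retried at:

```ts
// Shrink schedule for one text that keeps returning 413
// (ratio 0.5, floor 512): 8000 -> 4000 -> 2000 -> 1000 -> 512 -> throw.
let chars = 8000
const steps = [chars]
while (chars > 512) {
  chars = Math.max(512, Math.floor(chars * 0.5))
  steps.push(chars)
}
// steps: [8000, 4000, 2000, 1000, 512]
```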
@@ -411,7 +494,7 @@ export class OnlineEmbeddingService {
}
}
- throw new Error(normalizeErrorMessage(lastError))
+ throw createEmbeddingRequestError(lastError)
}
}
@@ -126,7 +126,7 @@ const DEEPSEEK_LEGACY_MODEL_MAP: Record<string, string> = {
'deepseek-reasoner': 'deepseek-v4-flash'
}
- const ONLINE_EMBEDDING_FALLBACK_DIMS = [2048, 1536, 1024, 768, 512, 256, 128, 64]
+ const ONLINE_EMBEDDING_FALLBACK_DIMS = [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64]
function normalizeProviderModel(providerId: string, modelName: string) {
if (providerId !== 'deepseek') {
@@ -247,11 +247,13 @@ function AISummarySettings({
useEffect(() => {
if (onlineEmbeddingProviders.length === 0) return
- const selected = onlineEmbeddingConfigs.find((item) => item.id === currentOnlineEmbeddingConfigId) || onlineEmbeddingConfigs[0] || null
+ const selected = currentOnlineEmbeddingConfigId
+ ? onlineEmbeddingConfigs.find((item) => item.id === currentOnlineEmbeddingConfigId) || null
+ : (!onlineEmbeddingModel ? onlineEmbeddingConfigs[0] || null : null)
if (selected || !onlineEmbeddingModel) {
applyOnlineEmbeddingConfig(selected, onlineEmbeddingProviders)
}
- }, [onlineEmbeddingProviders.length])
+ }, [onlineEmbeddingProviders, onlineEmbeddingConfigs, currentOnlineEmbeddingConfigId])
useEffect(() => {
const normalizedModel = normalizeProviderModel(provider, model)
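Note: the reworked effect only falls back to the first saved config when there is neither an explicit config id nor a hand-entered model, so a manually typed model is no longer clobbered on re-render, and the dependency array now tracks the actual inputs rather than just the provider count. The precedence, restated as a pure function (illustrative; minimal Config type assumed):

```ts
interface Config { id: string }

function pickConfig(configs: Config[], currentId: string, manualModel: string): Config | null {
  if (currentId) return configs.find((c) => c.id === currentId) ?? null // explicit selection wins
  if (!manualModel) return configs[0] ?? null // auto-select only when nothing was typed
  return null // otherwise keep the user's hand-entered model untouched
}
```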
@@ -383,9 +385,6 @@ function AISummarySettings({
const result = await window.electronAPI.ai.getOnlineEmbeddingProviders()
if (result.success && result.result) {
setOnlineEmbeddingProviders(result.result)
- if (!onlineEmbeddingBaseURL && result.result[0]) {
- applyOnlineEmbeddingConfig(null, result.result)
- }
}
} catch (e) {
console.error('加载在线向量厂商失败:', e)
@@ -397,8 +396,7 @@ function AISummarySettings({
const result = await window.electronAPI.ai.listOnlineEmbeddingConfigs()
if (result.success && result.result) {
setOnlineEmbeddingConfigs(result.result)
- const selected = result.result.find((item) => item.id === result.currentConfigId) || result.result[0] || null
- applyOnlineEmbeddingConfig(selected)
+ setCurrentOnlineEmbeddingConfigId(result.currentConfigId || result.result[0]?.id || '')
}
} catch (e) {
console.error('加载在线向量配置失败:', e)