mirror of
https://github.com/ILoveBingLu/CipherTalk.git
synced 2026-05-12 23:28:30 +08:00
feat: 更新在线嵌入模型维度和配置,优化并发处理逻辑
This commit is contained in:
@@ -63,7 +63,7 @@ export class EmbeddingRuntimeService {
|
||||
|
||||
/**
 * Returns how many texts a caller should hand to the embedding backend per round.
 *
 * When the runtime is not in 'online' mode the caller's own default is
 * returned unchanged. In online mode the size is the online service's
 * per-request batch size multiplied by its request concurrency, never less
 * than 1.
 *
 * @param defaultBatchSize Batch size to use when not in online mode.
 * @returns The batch size to use for the current mode.
 */
getCurrentBatchSize(defaultBatchSize: number): number {
  if (this.getMode() !== 'online') return defaultBatchSize

  // The superseded `Math.min(defaultBatchSize, ...)` return used to sit here and
  // made the concurrency-aware computation below unreachable.
  return Math.max(1, onlineEmbeddingService.getCurrentBatchSize() * onlineEmbeddingService.getCurrentConcurrency())
}
|
||||
|
||||
ensureReady(): void {
|
||||
|
||||
@@ -4,7 +4,7 @@ import type {
|
||||
OnlineEmbeddingProviderInfo
|
||||
} from './onlineEmbeddingTypes'
|
||||
|
||||
/**
 * Common embedding vector dimensions, ordered from largest to smallest.
 * Includes 4096 and 2560, which are the default dimensions of the Qwen3
 * 8B and 4B embedding models registered in PROVIDERS.
 */
export const ONLINE_EMBEDDING_COMMON_DIMS = [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64]
|
||||
|
||||
const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
|
||||
{
|
||||
@@ -50,27 +50,72 @@ const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
|
||||
defaultBaseURL: 'https://api.siliconflow.cn/v1',
|
||||
website: 'https://docs.siliconflow.cn/',
|
||||
models: [
|
||||
{
|
||||
id: 'Qwen/Qwen3-VL-Embedding-8B',
|
||||
displayName: 'Qwen3 VL Embedding 8B(收费)',
|
||||
supportedDims: [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
|
||||
defaultDim: 4096,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 32768,
|
||||
supportsDimensions: true
|
||||
},
|
||||
{
|
||||
id: 'Qwen/Qwen3-Embedding-8B',
|
||||
displayName: 'Qwen3 Embedding 8B(收费)',
|
||||
supportedDims: [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
|
||||
defaultDim: 4096,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 32768,
|
||||
supportsDimensions: true
|
||||
},
|
||||
{
|
||||
id: 'Qwen/Qwen3-Embedding-4B',
|
||||
displayName: 'Qwen3 Embedding 4B(收费)',
|
||||
supportedDims: [2560, 2048, 1536, 1024, 768, 512, 256, 128, 64],
|
||||
defaultDim: 2560,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 32768,
|
||||
supportsDimensions: true
|
||||
},
|
||||
{
|
||||
id: 'Qwen/Qwen3-Embedding-0.6B',
|
||||
displayName: 'Qwen3 Embedding 0.6B',
|
||||
displayName: 'Qwen3 Embedding 0.6B(收费)',
|
||||
supportedDims: [1024, 768, 512, 256, 128, 64],
|
||||
defaultDim: 1024,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 8192,
|
||||
maxTokens: 32768,
|
||||
supportsDimensions: true
|
||||
},
|
||||
{
|
||||
id: 'BAAI/bge-m3',
|
||||
displayName: 'BAAI bge-m3',
|
||||
displayName: 'BAAI bge-m3(免费)',
|
||||
supportedDims: [1024],
|
||||
defaultDim: 1024,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 8192,
|
||||
supportsDimensions: false
|
||||
},
|
||||
{
|
||||
id: 'netease-youdao/bce-embedding-base_v1',
|
||||
displayName: 'netease-youdao bce-embedding-base_v1(免费)',
|
||||
supportedDims: [768],
|
||||
defaultDim: 768,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 512,
|
||||
supportsDimensions: false
|
||||
},
|
||||
{
|
||||
id: 'BAAI/bge-large-zh-v1.5',
|
||||
displayName: 'BAAI bge-large-zh-v1.5',
|
||||
displayName: 'BAAI bge-large-zh-v1.5(免费)',
|
||||
supportedDims: [1024],
|
||||
defaultDim: 1024,
|
||||
maxBatchSize: 10,
|
||||
maxTokens: 512,
|
||||
supportsDimensions: false
|
||||
},
|
||||
{
|
||||
id: 'BAAI/bge-large-en-v1.5',
|
||||
displayName: 'BAAI bge-large-en-v1.5(免费)',
|
||||
supportedDims: [1024],
|
||||
defaultDim: 1024,
|
||||
maxBatchSize: 10,
|
||||
@@ -79,7 +124,7 @@ const PROVIDERS: OnlineEmbeddingProviderInfo[] = [
|
||||
},
|
||||
{
|
||||
id: 'Pro/BAAI/bge-m3',
|
||||
displayName: 'Pro BAAI bge-m3',
|
||||
displayName: 'Pro BAAI bge-m3(免费)',
|
||||
supportedDims: [1024],
|
||||
defaultDim: 1024,
|
||||
maxBatchSize: 10,
|
||||
|
||||
@@ -18,6 +18,14 @@ import {
|
||||
ONLINE_EMBEDDING_COMMON_DIMS
|
||||
} from './onlineEmbeddingRegistry'
|
||||
|
||||
// Number of embedding requests allowed in flight at once; this is the value returned by getCurrentConcurrency().
const ONLINE_EMBEDDING_CONCURRENCY = 6
|
||||
// Smallest per-text character limit that is tried when a single input keeps being rejected with HTTP 413.
const ONLINE_EMBEDDING_MIN_CHARS_ON_413 = 512
|
||||
// Factor applied to the per-text character limit on each retry after an HTTP 413 for a single input.
const ONLINE_EMBEDDING_413_SHRINK_RATIO = 0.5
|
||||
|
||||
/**
 * An Error that may additionally carry the numeric status of the failed
 * embedding request, so callers can branch on it (for example on 413).
 */
type EmbeddingRequestError = Error & {
  /** Numeric status copied from the underlying failure, when one was found. */
  status?: number
}
|
||||
|
||||
function normalizeVector(vector: Float32Array): Float32Array {
|
||||
let norm = 0
|
||||
for (let index = 0; index < vector.length; index += 1) norm += vector[index] * vector[index]
|
||||
@@ -48,7 +56,8 @@ function sleep(ms: number): Promise<void> {
|
||||
/**
 * Extracts a numeric status from an arbitrary thrown value.
 *
 * Looks at `status`, `statusCode` and `code` (in that order) and returns the
 * first one that converts to a finite, non-zero number. Non-numeric codes such
 * as `'ECONNRESET'` are skipped instead of leaking `NaN` to callers that
 * compare the result against HTTP status codes.
 *
 * @param error Any caught value.
 * @returns The numeric status, or 0 when none is available.
 */
function getErrorStatus(error: unknown): number {
  if (typeof error !== 'object' || error === null) return 0

  const record = error as Record<string, unknown>
  for (const candidate of [record.status, record.statusCode, record.code]) {
    const status = Number(candidate || 0)
    // NaN (non-numeric string/object) and 0 mean "no usable status here".
    if (Number.isFinite(status) && status !== 0) return status
  }
  return 0
}
|
||||
@@ -56,7 +65,14 @@ function getErrorStatus(error: unknown): number {
|
||||
/**
 * Builds a human-readable message for a failed embedding request.
 *
 * The message is prefixed with the numeric status (`"413: ..."`) when one is
 * available, unless the message already starts with that prefix. This keeps
 * an error that is wrapped more than once from accumulating `"413: 413: ..."`.
 *
 * @param error Any caught value.
 * @returns The message, optionally prefixed with the status.
 */
function normalizeErrorMessage(error: unknown): string {
  const status = getErrorStatus(error)
  const message = error instanceof Error ? error.message : String(error || '在线向量请求失败')

  // The superseded unconditional-prefix return used to sit here and made the
  // duplicate-prefix guard below unreachable.
  if (!status || message.startsWith(`${status}:`)) return message
  return `${status}: ${message}`
}
|
||||
|
||||
/**
 * Wraps an arbitrary failure in a new Error that keeps the numeric status.
 *
 * @param error The original caught value; its status (if any) is copied over.
 * @param fallbackMessage When non-empty, used as the message instead of the
 *   one derived from `error`. Despite the name it takes precedence.
 * @returns A fresh Error with `status` set only when a non-zero status was found.
 */
function createEmbeddingRequestError(error: unknown, fallbackMessage?: string): EmbeddingRequestError {
  const status = getErrorStatus(error)
  const wrapped: EmbeddingRequestError = new Error(fallbackMessage || normalizeErrorMessage(error))
  if (status) wrapped.status = status
  return wrapped
}
|
||||
|
||||
function limitEmbeddingText(text: string, maxChars: number): string {
|
||||
@@ -67,6 +83,27 @@ function limitEmbeddingText(text: string, maxChars: number): string {
|
||||
return `${value.slice(0, head)}\n${value.slice(-(limit - head))}`
|
||||
}
|
||||
|
||||
/**
 * Maps `items` through an async `worker`, running at most `concurrency`
 * invocations at the same time, and returns the results in input order.
 *
 * Work is pulled from a shared cursor by a fixed number of lanes, so a slow
 * item only blocks its own lane. As soon as any invocation rejects, no lane
 * starts further items (invocations already in flight are left to settle) and
 * the returned promise rejects with that first error.
 *
 * @param items Inputs to process.
 * @param concurrency Desired number of parallel lanes. It is floored and
 *   clamped to the range 1..items.length; `NaN` is treated as 1.
 * @param worker Async mapper receiving the item and its index.
 * @returns Results positioned at the index of their input.
 */
async function mapWithConcurrency<T, R>(
  items: T[],
  concurrency: number,
  worker: (item: T, index: number) => Promise<R>
): Promise<R[]> {
  const results = new Array<R>(items.length)
  if (items.length === 0) return results

  // Math.max/Math.min propagate NaN, which would yield zero lanes and an
  // array of holes; fall back to a single lane instead.
  const requested = Math.floor(concurrency)
  const laneCount = Math.max(1, Math.min(Number.isNaN(requested) ? 1 : requested, items.length))

  let nextIndex = 0
  let failed = false

  const runLane = async (): Promise<void> => {
    while (!failed && nextIndex < items.length) {
      // Claiming the index is synchronous, so lanes never process the same item.
      const index = nextIndex
      nextIndex += 1
      try {
        results[index] = await worker(items[index], index)
      } catch (error) {
        // Stop the other lanes from starting work whose result would be discarded.
        failed = true
        throw error
      }
    }
  }

  await Promise.all(Array.from({ length: laneCount }, () => runLane()))

  return results
}
|
||||
|
||||
export class OnlineEmbeddingService {
|
||||
listProviders(): OnlineEmbeddingProviderInfo[] {
|
||||
return listOnlineEmbeddingProviders()
|
||||
@@ -226,6 +263,10 @@ export class OnlineEmbeddingService {
|
||||
return Math.max(1, Math.min(10, this.getModelInfo(config.providerId, config.model)?.maxBatchSize || 10))
|
||||
}
|
||||
|
||||
/**
 * Returns how many embedding requests may be in flight at the same time.
 * Currently a fixed value taken from ONLINE_EMBEDDING_CONCURRENCY.
 */
getCurrentConcurrency(): number {
  return ONLINE_EMBEDDING_CONCURRENCY
}
|
||||
|
||||
getCurrentProfile() {
|
||||
const config = this.getCurrentConfig()
|
||||
const provider = this.getProvider(config?.providerId)
|
||||
@@ -348,16 +389,58 @@ export class OnlineEmbeddingService {
|
||||
const model = this.getModelInfo(config.providerId, config.model)
|
||||
const batchSize = Math.max(1, Math.min(model?.maxBatchSize || 10, texts.length))
|
||||
const maxChars = model?.maxTokens ? Math.max(1000, model.maxTokens * 2) : 8000
|
||||
const vectors: Float32Array[] = []
|
||||
const batches: string[][] = []
|
||||
|
||||
for (let index = 0; index < texts.length; index += batchSize) {
|
||||
const batch = texts.slice(index, index + batchSize)
|
||||
const cleaned = batch.map((text) => limitEmbeddingText(String(text || ''), maxChars))
|
||||
const batchVectors = await this.requestEmbeddings(config, cleaned)
|
||||
vectors.push(...batchVectors)
|
||||
batches.push(texts.slice(index, index + batchSize).map((text) => String(text || '')))
|
||||
}
|
||||
|
||||
return vectors
|
||||
const batchVectors = await mapWithConcurrency(
|
||||
batches,
|
||||
this.getCurrentConcurrency(),
|
||||
(batch) => this.requestEmbeddingsWithPayloadRecovery(config, batch, maxChars)
|
||||
)
|
||||
|
||||
return batchVectors.flat()
|
||||
}
|
||||
|
||||
/**
 * Requests embeddings for one batch and recovers from HTTP 413 responses.
 *
 * Every text is first limited to `maxChars` characters. If the request fails
 * with status 413:
 *  - a batch of several texts is split in half and each half is retried
 *    (left first, then right), keeping the output in input order;
 *  - a single text is retried with the character limit multiplied by
 *    ONLINE_EMBEDDING_413_SHRINK_RATIO, down to ONLINE_EMBEDDING_MIN_CHARS_ON_413.
 * Any other failure is rethrown untouched.
 *
 * @param config Online embedding configuration forwarded to the request.
 * @param texts Raw texts of the batch.
 * @param maxChars Per-text character limit; floored and clamped to at least 1.
 * @returns One vector per input text, in input order.
 * @throws The original error for non-413 failures, or an error carrying
 *   status 413 when a single text is still rejected at the minimum limit.
 */
private async requestEmbeddingsWithPayloadRecovery(
  config: OnlineEmbeddingConfig,
  texts: string[],
  maxChars: number
): Promise<Float32Array[]> {
  // Nothing to embed: avoid sending an empty request.
  if (texts.length === 0) return []

  const safeMaxChars = Math.max(1, Math.floor(maxChars))
  const cleaned = texts.map((text) => limitEmbeddingText(text, safeMaxChars))

  try {
    return await this.requestEmbeddings(config, cleaned)
  } catch (error) {
    if (getErrorStatus(error) !== 413) {
      throw error
    }

    if (texts.length > 1) {
      // Payload too large for the whole batch: halve it and retry each part.
      const midpoint = Math.max(1, Math.floor(texts.length / 2))
      const left = await this.requestEmbeddingsWithPayloadRecovery(config, texts.slice(0, midpoint), safeMaxChars)
      const right = await this.requestEmbeddingsWithPayloadRecovery(config, texts.slice(midpoint), safeMaxChars)
      return [...left, ...right]
    }

    if (safeMaxChars > ONLINE_EMBEDDING_MIN_CHARS_ON_413) {
      // A single text is still too large: tighten the character limit and retry.
      const nextMaxChars = Math.max(
        ONLINE_EMBEDDING_MIN_CHARS_ON_413,
        Math.floor(safeMaxChars * ONLINE_EMBEDDING_413_SHRINK_RATIO)
      )
      // Guard against a ratio >= 1 that would never make progress.
      if (nextMaxChars < safeMaxChars) {
        return this.requestEmbeddingsWithPayloadRecovery(config, texts, nextMaxChars)
      }
    }

    throw createEmbeddingRequestError(
      error,
      `在线向量服务拒绝单条输入大小,已降到 ${safeMaxChars} 字符仍失败`
    )
  }
}
|
||||
|
||||
private async requestEmbeddings(config: OnlineEmbeddingConfig, texts: string[]): Promise<Float32Array[]> {
|
||||
@@ -411,7 +494,7 @@ export class OnlineEmbeddingService {
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(normalizeErrorMessage(lastError))
|
||||
throw createEmbeddingRequestError(lastError)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ const DEEPSEEK_LEGACY_MODEL_MAP: Record<string, string> = {
|
||||
'deepseek-reasoner': 'deepseek-v4-flash'
|
||||
}
|
||||
|
||||
/**
 * Fallback list of embedding dimensions, largest first.
 * NOTE(review): presumably used by the settings UI when a model does not
 * report its own dimensions, and it has the same entries as the registry's
 * ONLINE_EMBEDDING_COMMON_DIMS. Confirm against the usages in this component.
 */
const ONLINE_EMBEDDING_FALLBACK_DIMS = [4096, 2560, 2048, 1536, 1024, 768, 512, 256, 128, 64]
|
||||
|
||||
function normalizeProviderModel(providerId: string, modelName: string) {
|
||||
if (providerId !== 'deepseek') {
|
||||
@@ -247,11 +247,13 @@ function AISummarySettings({
|
||||
|
||||
useEffect(() => {
|
||||
if (onlineEmbeddingProviders.length === 0) return
|
||||
const selected = onlineEmbeddingConfigs.find((item) => item.id === currentOnlineEmbeddingConfigId) || onlineEmbeddingConfigs[0] || null
|
||||
const selected = currentOnlineEmbeddingConfigId
|
||||
? onlineEmbeddingConfigs.find((item) => item.id === currentOnlineEmbeddingConfigId) || null
|
||||
: (!onlineEmbeddingModel ? onlineEmbeddingConfigs[0] || null : null)
|
||||
if (selected || !onlineEmbeddingModel) {
|
||||
applyOnlineEmbeddingConfig(selected, onlineEmbeddingProviders)
|
||||
}
|
||||
}, [onlineEmbeddingProviders.length])
|
||||
}, [onlineEmbeddingProviders, onlineEmbeddingConfigs, currentOnlineEmbeddingConfigId])
|
||||
|
||||
useEffect(() => {
|
||||
const normalizedModel = normalizeProviderModel(provider, model)
|
||||
@@ -383,9 +385,6 @@ function AISummarySettings({
|
||||
const result = await window.electronAPI.ai.getOnlineEmbeddingProviders()
|
||||
if (result.success && result.result) {
|
||||
setOnlineEmbeddingProviders(result.result)
|
||||
if (!onlineEmbeddingBaseURL && result.result[0]) {
|
||||
applyOnlineEmbeddingConfig(null, result.result)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('加载在线向量厂商失败:', e)
|
||||
@@ -397,8 +396,7 @@ function AISummarySettings({
|
||||
const result = await window.electronAPI.ai.listOnlineEmbeddingConfigs()
|
||||
if (result.success && result.result) {
|
||||
setOnlineEmbeddingConfigs(result.result)
|
||||
const selected = result.result.find((item) => item.id === result.currentConfigId) || result.result[0] || null
|
||||
applyOnlineEmbeddingConfig(selected)
|
||||
setCurrentOnlineEmbeddingConfigId(result.currentConfigId || result.result[0]?.id || '')
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('加载在线向量配置失败:', e)
|
||||
|
||||
Reference in New Issue
Block a user