mirror of
https://github.com/hicccc77/WeFlow.git
synced 2026-03-22 07:32:29 +08:00
487 lines · 15 KiB · TypeScript
import { app } from 'electron'
|
||
import { existsSync, mkdirSync, statSync, unlinkSync, createWriteStream, openSync, writeSync, closeSync } from 'fs'
|
||
import { join } from 'path'
|
||
import * as https from 'https'
|
||
import * as http from 'http'
|
||
import { ConfigService } from './config'
|
||
|
||
// Sherpa-onnx type placeholder. The recognizer is typed `unknown` rather than `any`,
// so any future use has to narrow it explicitly instead of silently bypassing checks.
type OfflineRecognizer = unknown
|
||
|
||
/** Static description of a downloadable speech-recognition model. */
interface ModelInfo {
  /** Display name; reported as `modelName` in download progress updates. */
  name: string
  /** File names of the required files, relative to the model directory. */
  files: {
    /** ONNX model file. */
    model: string
    /** Tokens file. */
    tokens: string
  }
  /** Expected total download size in bytes; reported as `totalBytes` in progress updates. */
  sizeBytes: number
  /** Human-readable form of `sizeBytes`. */
  sizeLabel: string
}
|
||
|
||
/** Progress snapshot delivered to the model download `onProgress` callback. */
interface DownloadProgress {
  /** Name of the model being downloaded. */
  modelName: string
  /** Bytes downloaded so far, cumulative across the model's files. */
  downloadedBytes: number
  /** Expected total size in bytes, when known. */
  totalBytes?: number
  /** Overall completion on a 0-100 scale, when known. */
  percent?: number
  /** Recent throughput in bytes per second, when measured. */
  speed?: number
}
|
||
|
||
/**
 * The SenseVoice int8 model: the file names it occupies in the model directory and
 * the expected download size used for progress reporting.
 */
const SENSEVOICE_MODEL: ModelInfo = {
  name: 'SenseVoiceSmall',
  files: {
    model: 'model.int8.onnx',
    tokens: 'tokens.txt'
  },
  sizeBytes: 245 * 1000 * 1000,
  sizeLabel: '245 MB'
}
|
||
|
||
/** Direct ModelScope download URL for each of the SenseVoice model files. */
const MODEL_DOWNLOAD_URLS: { model: string; tokens: string } = {
  model: 'https://modelscope.cn/models/pengzhendong/sherpa-onnx-sense-voice-zh-en-ja-ko-yue/resolve/master/model.int8.onnx',
  tokens: 'https://modelscope.cn/models/pengzhendong/sherpa-onnx-sense-voice-zh-en-ja-ko-yue/resolve/master/tokens.txt'
}
|
||
|
||
/** Outcome of a model download: file paths on success, `error` on failure. */
type ModelDownloadResult = { success: boolean; modelPath?: string; tokensPath?: string; error?: string }

/** Outcome of a transcription: `transcript` on success, `error` on failure. */
type TranscribeResult = { success: boolean; transcript?: string; error?: string }

/** Browser-like headers sent with every ModelScope request. */
const DOWNLOAD_HEADERS = {
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
  'Referer': 'https://modelscope.cn/'
}

/** HTTP status codes that are followed as redirects. */
const REDIRECT_STATUS_CODES = [301, 302, 303, 307, 308]

/** Files smaller than this (or of unknown size) are downloaded over a single connection. */
const MULTI_THREAD_MIN_BYTES = 2 * 1024 * 1024

/** Number of parallel Range requests used for large files. */
const DOWNLOAD_THREAD_COUNT = 4

/**
 * Socket inactivity limit (ms) for download requests. Without it a stalled connection
 * never settles, and the shared download task stays pending until the app restarts.
 */
const REQUEST_IDLE_TIMEOUT_MS = 60_000

/**
 * Downloads the SenseVoice (sherpa-onnx) model files and transcribes WAV audio in a
 * worker thread (`transcribeWorker.js`).
 */
export class VoiceTranscribeService {
  private configService = new ConfigService()
  /** In-flight download keyed by model id, so concurrent callers share one download. */
  private downloadTasks = new Map<string, Promise<ModelDownloadResult>>()
  private recognizer: OfflineRecognizer | null = null
  private isInitializing = false

  /** Model directory: the configured `whisperModelDir`, else `<documents>/WeFlow/models/sensevoice`. */
  private resolveModelDir(): string {
    const configured = this.configService.get('whisperModelDir') as string | undefined
    if (configured) return configured
    return join(app.getPath('documents'), 'WeFlow', 'models', 'sensevoice')
  }

  /** Path of `fileName` inside the model directory. */
  private resolveModelPath(fileName: string): string {
    return join(this.resolveModelDir(), fileName)
  }

  /**
   * Checks whether both model files are present.
   *
   * @returns `exists: false` plus the expected paths when either file is missing;
   *   otherwise the paths and their combined size in bytes. Filesystem errors are
   *   returned as `success: false` rather than thrown.
   */
  async getModelStatus(): Promise<{
    success: boolean
    exists?: boolean
    modelPath?: string
    tokensPath?: string
    sizeBytes?: number
    error?: string
  }> {
    try {
      const modelPath = this.resolveModelPath(SENSEVOICE_MODEL.files.model)
      const tokensPath = this.resolveModelPath(SENSEVOICE_MODEL.files.tokens)

      if (!existsSync(modelPath) || !existsSync(tokensPath)) {
        return { success: true, exists: false, modelPath, tokensPath }
      }

      const totalSize = statSync(modelPath).size + statSync(tokensPath).size
      return { success: true, exists: true, modelPath, tokensPath, sizeBytes: totalSize }
    } catch (error) {
      return { success: false, error: String(error) }
    }
  }

  /**
   * Downloads the model and tokens files into the model directory.
   *
   * Concurrent calls share a single download. Only the first caller's `onProgress`
   * receives updates. Progress `percent` covers 0-80 for the model file and 80-100 for
   * the tokens file. On failure both files are removed. This method never rejects;
   * errors are reported as `success: false`.
   */
  async downloadModel(
    onProgress?: (progress: DownloadProgress) => void
  ): Promise<ModelDownloadResult> {
    const cacheKey = 'sensevoice'
    const pending = this.downloadTasks.get(cacheKey)
    if (pending) return pending

    const task = this.runModelDownload(onProgress)
    // Register before clearing: if the download fails synchronously, the cleanup below
    // still runs afterwards, in a later microtask, instead of before `set`. Running
    // before `set` would leave a stale failed task cached forever.
    this.downloadTasks.set(cacheKey, task)
    const clear = () => {
      if (this.downloadTasks.get(cacheKey) === task) this.downloadTasks.delete(cacheKey)
    }
    void task.then(clear, clear)
    return task
  }

  /** Performs the work of `downloadModel`; never rejects. */
  private async runModelDownload(
    onProgress?: (progress: DownloadProgress) => void
  ): Promise<ModelDownloadResult> {
    try {
      const modelDir = this.resolveModelDir()
      if (!existsSync(modelDir)) {
        mkdirSync(modelDir, { recursive: true })
      }

      const modelPath = this.resolveModelPath(SENSEVOICE_MODEL.files.model)
      const tokensPath = this.resolveModelPath(SENSEVOICE_MODEL.files.tokens)

      // Initial progress
      onProgress?.({
        modelName: SENSEVOICE_MODEL.name,
        downloadedBytes: 0,
        totalBytes: SENSEVOICE_MODEL.sizeBytes,
        percent: 0
      })

      // Model file: first 80% of the overall percentage.
      console.info('[VoiceTranscribe] 开始下载模型文件...')
      await this.downloadToFile(
        MODEL_DOWNLOAD_URLS.model,
        modelPath,
        'model',
        (downloaded, total, speed) => {
          onProgress?.({
            modelName: SENSEVOICE_MODEL.name,
            downloadedBytes: downloaded,
            totalBytes: SENSEVOICE_MODEL.sizeBytes,
            percent: total ? (downloaded / total) * 80 : 0,
            speed
          })
        }
      )

      // Tokens file: remaining 20%. The finished model size is read once, not on every tick.
      const modelBytes = existsSync(modelPath) ? statSync(modelPath).size : 0
      console.info('[VoiceTranscribe] 开始下载 tokens 文件...')
      await this.downloadToFile(
        MODEL_DOWNLOAD_URLS.tokens,
        tokensPath,
        'tokens',
        (downloaded, total, speed) => {
          onProgress?.({
            modelName: SENSEVOICE_MODEL.name,
            downloadedBytes: modelBytes + downloaded,
            totalBytes: SENSEVOICE_MODEL.sizeBytes,
            percent: total ? 80 + (downloaded / total) * 20 : 80,
            speed
          })
        }
      )

      console.info('[VoiceTranscribe] 模型下载完成')
      return { success: true, modelPath, tokensPath }
    } catch (error) {
      this.removeModelFiles()
      return { success: false, error: String(error) }
    }
  }

  /** Best-effort deletion of both model files after a failed download; errors are deliberately ignored. */
  private removeModelFiles(): void {
    for (const fileName of [SENSEVOICE_MODEL.files.model, SENSEVOICE_MODEL.files.tokens]) {
      try {
        const filePath = this.resolveModelPath(fileName)
        if (existsSync(filePath)) unlinkSync(filePath)
      } catch { }
    }
  }

  /**
   * Transcribes WAV audio in a worker thread running `transcribeWorker.js`.
   *
   * @param wavData Audio passed to the worker together with `sampleRate: 16000`.
   * @param onPartial Receives the worker's `partial` messages.
   * @param languages Language codes; falls back to the `transcribeLanguages` setting,
   *   then to `['zh', 'yue']`.
   * @returns Never rejects. Always settles, even if the worker exits without a result.
   */
  async transcribeWavBuffer(
    wavData: Buffer,
    onPartial?: (text: string) => void,
    languages?: string[]
  ): Promise<TranscribeResult> {
    return new Promise((resolve) => {
      try {
        const modelPath = this.resolveModelPath(SENSEVOICE_MODEL.files.model)
        const tokensPath = this.resolveModelPath(SENSEVOICE_MODEL.files.tokens)

        if (!existsSync(modelPath) || !existsSync(tokensPath)) {
          resolve({ success: false, error: '模型文件不存在,请先下载模型' })
          return
        }

        let supportedLanguages = languages
        if (!supportedLanguages || supportedLanguages.length === 0) {
          supportedLanguages = this.configService.get('transcribeLanguages')
          if (!supportedLanguages || supportedLanguages.length === 0) {
            supportedLanguages = ['zh', 'yue']
          }
        }

        const { Worker } = require('worker_threads')
        const workerPath = join(__dirname, 'transcribeWorker.js')

        const worker = new Worker(workerPath, {
          workerData: {
            modelPath,
            tokensPath,
            wavData,
            sampleRate: 16000,
            languages: supportedLanguages
          }
        })

        // Settle exactly once and always stop the worker afterwards.
        let settled = false
        const finish = (result: TranscribeResult) => {
          if (settled) return
          settled = true
          resolve(result)
          worker.terminate()
        }

        worker.on('message', (msg: any) => {
          if (settled) return
          if (msg.type === 'partial') {
            onPartial?.(msg.text)
          } else if (msg.type === 'final') {
            finish({ success: true, transcript: msg.text })
          } else if (msg.type === 'error') {
            console.error('[VoiceTranscribe] Worker 错误:', msg.error)
            finish({ success: false, error: msg.error })
          }
        })

        worker.on('error', (err: Error) => finish({ success: false, error: String(err) }))

        // A clean exit without a `final` message previously left the promise pending forever.
        worker.on('exit', (code: number) => {
          finish({
            success: false,
            error: code !== 0 ? `Worker exited with code ${code}` : 'Worker exited without returning a transcript'
          })
        })
      } catch (error) {
        resolve({ success: false, error: String(error) })
      }
    })
  }

  /**
   * Downloads `url` to `targetPath`, replacing any existing file.
   *
   * Large files on servers that honour Range requests are fetched with parallel ranges.
   * Otherwise, or if probing fails, a single streaming request is used.
   *
   * @param fileName Label used in log messages only.
   * @param onProgress Called about once per second with bytes so far, total size and
   *   bytes/s, and once more on completion.
   * @throws When a request fails, times out, or yields a truncated body.
   */
  private async downloadToFile(
    url: string,
    targetPath: string,
    fileName: string,
    onProgress?: (downloaded: number, total?: number, speed?: number) => void
  ): Promise<void> {
    if (existsSync(targetPath)) {
      unlinkSync(targetPath)
    }

    console.info(`[VoiceTranscribe] 准备下载 ${fileName}: ${url}`)

    // 1. Probe size and Range support.
    let probeResult
    try {
      probeResult = await this.probeUrl(url)
    } catch (err) {
      console.warn(`[VoiceTranscribe] ${fileName} 探测失败,使用单线程`, err)
      return this.downloadSingleThread(url, targetPath, fileName, onProgress)
    }

    const { totalSize, acceptRanges, finalUrl } = probeResult

    // Unknown size, small files, or no Range support: single connection. NaN (for example
    // from "Content-Range: bytes 0-0/*") must be excluded explicitly, because
    // `NaN < MULTI_THREAD_MIN_BYTES` is false.
    if (!Number.isFinite(totalSize) || totalSize < MULTI_THREAD_MIN_BYTES || !acceptRanges) {
      return this.downloadSingleThread(finalUrl, targetPath, fileName, onProgress)
    }

    const threadCount = DOWNLOAD_THREAD_COUNT
    console.info(`[VoiceTranscribe] ${fileName} 开始多线程下载 (${threadCount} 线程), 大小: ${(totalSize / 1024 / 1024).toFixed(2)} MB`)

    const chunkSize = Math.ceil(totalSize / threadCount)
    const fd = openSync(targetPath, 'w')
    const activeRequests = new Set<http.ClientRequest>()

    let downloadedTotal = 0
    let lastDownloaded = 0
    let lastTime = Date.now()

    const speedInterval = setInterval(() => {
      const now = Date.now()
      const duration = (now - lastTime) / 1000
      if (duration > 0) {
        const speed = (downloadedTotal - lastDownloaded) / duration
        lastDownloaded = downloadedTotal
        lastTime = now
        onProgress?.(downloadedTotal, totalSize, speed)
      }
    }, 1000)

    try {
      const chunks = Array.from({ length: threadCount }, (_, i) => {
        const start = i * chunkSize
        const end = i === threadCount - 1 ? totalSize - 1 : (i + 1) * chunkSize - 1
        return this.downloadChunk(finalUrl, fd, start, end, (bytes) => {
          downloadedTotal += bytes
        }, activeRequests)
      })

      await Promise.all(chunks)
      // Final progress update
      onProgress?.(totalSize, totalSize, 0)
      console.info(`[VoiceTranscribe] ${fileName} 多线程下载完成`)
    } catch (err) {
      // Abort the sibling ranges before `fd` is closed below. Otherwise they keep writing
      // to a closed descriptor whose number the OS may already have reused.
      for (const req of activeRequests) req.destroy()
      console.error(`[VoiceTranscribe] ${fileName} 多线程下载失败:`, err)
      throw err
    } finally {
      clearInterval(speedInterval)
      closeSync(fd)
    }
  }

  /**
   * Issues a `Range: bytes=0-0` GET to learn the file size and whether ranged requests
   * work, following up to `remainingRedirects` redirects.
   *
   * @returns `totalSize` from Content-Range (else Content-Length). It may be NaN when the
   *   server reports an unknown length. `finalUrl` is the URL after redirects.
   */
  private async probeUrl(url: string, remainingRedirects = 5): Promise<{ totalSize: number, acceptRanges: boolean, finalUrl: string }> {
    return new Promise((resolve, reject) => {
      const protocol = url.startsWith('https') ? https : http
      const options = {
        method: 'GET',
        headers: { ...DOWNLOAD_HEADERS, 'Range': 'bytes=0-0' }
      }

      let responded = false
      const req = protocol.get(url, options, (res) => {
        responded = true
        const location = res.headers.location
        if (REDIRECT_STATUS_CODES.includes(res.statusCode || 0) && location && remainingRedirects > 0) {
          // Drain the (small) redirect body so the socket is released.
          res.resume()
          const nextUrl = new URL(location, url).href
          this.probeUrl(nextUrl, remainingRedirects - 1).then(resolve).catch(reject)
          return
        }

        if (res.statusCode !== 206 && res.statusCode !== 200) {
          res.destroy()
          reject(new Error(`Probe failed: HTTP ${res.statusCode}`))
          return
        }

        const contentRange = res.headers['content-range']
        let totalSize = 0
        if (contentRange) {
          const parts = contentRange.split('/')
          totalSize = parseInt(parts[parts.length - 1], 10)
        } else {
          totalSize = parseInt(res.headers['content-length'] || '0', 10)
        }

        const acceptRanges = res.headers['accept-ranges'] === 'bytes' || !!contentRange
        resolve({ totalSize, acceptRanges, finalUrl: url })
        // Only headers were needed; drop the body, which is the whole file if Range was ignored.
        res.destroy()
      })

      req.setTimeout(REQUEST_IDLE_TIMEOUT_MS, () => {
        if (!responded) req.destroy(new Error(`Probe timed out: ${url}`))
      })
      req.on('error', reject)
    })
  }

  /**
   * Downloads the inclusive byte range [start, end] of `url` and writes it into `fd`
   * at the same offsets.
   *
   * @param onData Called with the size of each piece written.
   * @param activeRequests Optional registry holding the request while it is in flight,
   *   so the caller can abort it.
   * @throws If the server does not answer 206, a write fails, the request times out, or
   *   the body length differs from the requested range.
   */
  private async downloadChunk(
    url: string,
    fd: number,
    start: number,
    end: number,
    onData: (bytes: number) => void,
    activeRequests?: Set<http.ClientRequest>
  ): Promise<void> {
    return new Promise((resolve, reject) => {
      const protocol = url.startsWith('https') ? https : http
      const options = {
        headers: { ...DOWNLOAD_HEADERS, 'Range': `bytes=${start}-${end}` }
      }

      let settled = false
      const settle = (err?: unknown): void => {
        if (settled) return
        settled = true
        activeRequests?.delete(req)
        if (err === undefined) resolve()
        else reject(err)
      }

      const req = protocol.get(url, options, (res) => {
        if (res.statusCode !== 206) {
          // Destroy rather than drain: a 200 here would be the entire file.
          res.destroy()
          settle(new Error(`Chunk download failed: HTTP ${res.statusCode}`))
          return
        }

        let offset = start
        res.on('data', (chunk: Buffer) => {
          if (settled) return
          try {
            writeSync(fd, chunk, 0, chunk.length, offset)
            offset += chunk.length
            onData(chunk.length)
          } catch (err) {
            settle(err)
            res.destroy()
          }
        })

        // Reject short or long bodies instead of silently leaving a corrupt file.
        res.on('end', () => {
          const expected = end - start + 1
          const received = offset - start
          settle(received === expected
            ? undefined
            : new Error(`Chunk ${start}-${end} incomplete: received ${received} of ${expected} bytes`))
        })
        res.on('error', (err) => settle(err))
      })

      activeRequests?.add(req)
      req.setTimeout(REQUEST_IDLE_TIMEOUT_MS, () => {
        if (!settled) req.destroy(new Error(`Chunk ${start}-${end} timed out`))
      })
      req.on('error', (err) => settle(err))
    })
  }

  /**
   * Streams `url` to `targetPath` over one connection, following up to
   * `remainingRedirects` redirects.
   *
   * Resolves after the file handle is closed. Emits a final progress update on success.
   * Rejects on HTTP errors, request or stream errors, timeouts, and bodies shorter or
   * longer than Content-Length.
   */
  private async downloadSingleThread(
    url: string,
    targetPath: string,
    fileName: string,
    onProgress?: (downloaded: number, total?: number, speed?: number) => void,
    remainingRedirects = 5
  ): Promise<void> {
    return new Promise((resolve, reject) => {
      const protocol = url.startsWith('https') ? https : http
      const options = { headers: { ...DOWNLOAD_HEADERS } }

      let settled = false
      // Replaced once the response arrives and there is an interval and writer to release.
      let cleanup: () => void = () => { }
      const fail = (err: unknown) => {
        if (settled) return
        settled = true
        cleanup()
        reject(err)
      }

      const request = protocol.get(url, options, (response) => {
        const location = response.headers.location
        if (REDIRECT_STATUS_CODES.includes(response.statusCode || 0) && location && remainingRedirects > 0) {
          // The redirected request owns the outcome from here on.
          settled = true
          response.resume()
          const nextUrl = new URL(location, url).href
          this.downloadSingleThread(nextUrl, targetPath, fileName, onProgress, remainingRedirects - 1).then(resolve).catch(reject)
          return
        }

        if (response.statusCode !== 200) {
          response.destroy()
          fail(new Error(`Fallback download failed: HTTP ${response.statusCode}`))
          return
        }

        const totalBytes = Number(response.headers['content-length'] || 0) || undefined
        let downloadedBytes = 0
        let lastDownloaded = 0
        let lastTime = Date.now()

        const speedInterval = setInterval(() => {
          const now = Date.now()
          const duration = (now - lastTime) / 1000
          if (duration > 0) {
            const speed = (downloadedBytes - lastDownloaded) / duration
            lastDownloaded = downloadedBytes
            lastTime = now
            onProgress?.(downloadedBytes, totalBytes, speed)
          }
        }, 1000)

        const writer = createWriteStream(targetPath)
        cleanup = () => {
          clearInterval(speedInterval)
          response.destroy()
          // Make sure the file handle is released on failure.
          writer.destroy()
        }

        response.on('data', (chunk: Buffer) => {
          downloadedBytes += chunk.length
        })
        response.on('error', fail)
        writer.on('error', fail)

        // 'close' follows 'finish' once the fd is closed, so the file is complete on disk.
        writer.on('close', () => {
          if (settled) return
          settled = true
          clearInterval(speedInterval)
          if (totalBytes !== undefined && downloadedBytes !== totalBytes) {
            reject(new Error(`${fileName} download incomplete: received ${downloadedBytes} of ${totalBytes} bytes`))
            return
          }
          onProgress?.(downloadedBytes, totalBytes, 0)
          resolve()
        })

        response.pipe(writer)
      })

      request.setTimeout(REQUEST_IDLE_TIMEOUT_MS, () => {
        if (!settled) request.destroy(new Error(`Request timed out: ${url}`))
      })
      request.on('error', fail)
    })
  }

  /** Drops the reference to the recognizer, if one is held. */
  dispose() {
    if (this.recognizer) {
      this.recognizer = null
    }
  }
}
|
||
|
||
/** Shared module-level instance of VoiceTranscribeService. */
export const voiceTranscribeService: VoiceTranscribeService = new VoiceTranscribeService()
|