feat: 支持远程下载分词词库,并支持繁体中文词库

This commit is contained in:
digua
2026-04-12 17:31:32 +08:00
committed by digua
parent 8c91ff1c5a
commit 17bb3d772e
17 changed files with 571 additions and 90 deletions
+23 -1
View File
@@ -1,12 +1,13 @@
/**
* NLP 功能 IPC 处理器
* 提供词频统计、分词等 NLP 功能
* 提供词频统计、分词等 NLP 功能,以及词库管理
*/
import { ipcMain } from 'electron'
import * as worker from '../worker/workerManager'
import type { IpcContext } from './types'
import type { WordFrequencyParams, WordFrequencyResult, SupportedLocale, PosTagInfo } from '../nlp'
import { getDictList, downloadDict, deleteDict, isDictDownloaded, type DictInfo } from '../nlp/dictManager'
/**
* 注册 NLP 相关 IPC 处理器
@@ -60,4 +61,25 @@ export function registerNlpHandlers(_ctx: IpcContext): void {
return []
}
})
// ==================== 词库管理 ====================
ipcMain.handle('nlp:getDictList', async (): Promise<DictInfo[]> => {
return getDictList()
})
ipcMain.handle('nlp:isDictDownloaded', async (_event, dictId: string): Promise<boolean> => {
return isDictDownloaded(dictId)
})
ipcMain.handle(
'nlp:downloadDict',
async (_event, dictId: string): Promise<{ success: boolean; error?: string }> => {
return downloadDict(dictId)
}
)
ipcMain.handle('nlp:deleteDict', async (_event, dictId: string): Promise<{ success: boolean; error?: string }> => {
return deleteDict(dictId)
})
}
+172
View File
@@ -0,0 +1,172 @@
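/**
 * NLP dictionary manager.
 * Handles downloading, querying and deleting custom dictionaries.
 * Dictionary files are stored under userData/data/nlp/ (see getNlpDir).
 */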
import * as fs from 'fs'
import * as path from 'path'
import { app } from 'electron'
import axios from 'axios'
// Name of the sub-directory that holds dictionary files (joined under `<userData>/data` by getNlpDir).
const NLP_DIR_NAME = 'nlp'
// Base URL for dictionary downloads; getDictDownloadUrl appends `/<dictId>.dict` to it.
const DICT_DOWNLOAD_URL_BASE = 'https://chatlab.fun/assets/nlp'
/** Describes one known dictionary together with its local download state. */
export interface DictInfo {
/** Dictionary identifier; also used as the base name of the `.dict` file and of the download URL. */
id: string
/** Human-readable display name. */
label: string
/** Locale associated with this dictionary (e.g. 'zh-CN'). */
locale: string
/** Whether the dictionary file currently exists on disk. */
downloaded: boolean
/** Size of the dictionary file in bytes; undefined when it is not downloaded or cannot be stat'ed. */
fileSize?: number
}
// Catalogue of dictionaries known to the app. The per-machine fields
// (`downloaded`, `fileSize`) are omitted here and filled in by getDictList().
const AVAILABLE_DICTS: Omit<DictInfo, 'downloaded' | 'fileSize'>[] = [
{ id: 'zh-CN', label: '简体中文', locale: 'zh-CN' },
{ id: 'zh-TW', label: '繁體中文', locale: 'zh-TW' },
]
/**
 * Resolves the directory that holds dictionary files.
 *
 * @returns `<userData>/data/nlp`, where `<userData>` is Electron's userData path.
 */
export function getNlpDir(): string {
  return path.join(app.getPath('userData'), 'data', NLP_DIR_NAME)
}
/** Creates the dictionary directory (including missing parents) when it does not exist yet. */
function ensureNlpDir(): void {
  const nlpDir = getNlpDir()
  if (fs.existsSync(nlpDir)) return
  fs.mkdirSync(nlpDir, { recursive: true })
}
/**
 * Builds the on-disk location of a dictionary.
 *
 * @param dictId Dictionary identifier.
 * @returns `<nlpDir>/<dictId>.dict`.
 */
function getDictFilePath(dictId: string): string {
  const fileName = `${dictId}.dict`
  return path.join(getNlpDir(), fileName)
}
/**
 * Builds the remote URL a dictionary is downloaded from.
 *
 * @param dictId Dictionary identifier.
 * @returns The download base URL followed by `/<dictId>.dict`.
 */
function getDictDownloadUrl(dictId: string): string {
  const remoteName = `${dictId}.dict`
  return [DICT_DOWNLOAD_URL_BASE, remoteName].join('/')
}
/**
 * Reports whether the dictionary file for `dictId` is present on disk.
 *
 * @param dictId Dictionary identifier.
 */
export function isDictDownloaded(dictId: string): boolean {
  const dictFile = getDictFilePath(dictId)
  return fs.existsSync(dictFile)
}
/**
 * Lists every known dictionary together with its local state.
 *
 * @returns One entry per catalogue item, with `downloaded` reflecting whether
 *          the file exists and `fileSize` set when the file could be stat'ed.
 */
export function getDictList(): DictInfo[] {
  const entries: DictInfo[] = []
  for (const def of AVAILABLE_DICTS) {
    const dictFile = getDictFilePath(def.id)
    const entry: DictInfo = { ...def, downloaded: fs.existsSync(dictFile), fileSize: undefined }
    if (entry.downloaded) {
      try {
        entry.fileSize = fs.statSync(dictFile).size
      } catch {
        // Size is informational only; leave it undefined when stat fails.
      }
    }
    entries.push(entry)
  }
  return entries
}
/**
 * Reads a dictionary file into memory.
 *
 * @param dictId Dictionary identifier.
 * @returns The file contents, or null when the file is missing or unreadable
 *          (read failures are logged).
 */
export function loadDictBuffer(dictId: string): Buffer | null {
  const dictFile = getDictFilePath(dictId)
  if (fs.existsSync(dictFile)) {
    try {
      return fs.readFileSync(dictFile)
    } catch (error) {
      console.error(`[NLP DictManager] Failed to read dict file: ${dictFile}`, error)
    }
  }
  return null
}
/**
 * Sanity-checks a downloaded payload before it is installed.
 *
 * @returns An error message when the payload cannot be a dictionary, otherwise null.
 */
function validateDictPayload(buffer: Buffer): string | null {
  // A dictionary file is expected to be larger than 1 MB.
  const MIN_DICT_SIZE = 1_000_000
  if (buffer.length < MIN_DICT_SIZE) {
    const preview = buffer.subarray(0, 200).toString('utf-8')
    console.error(`[NLP DictManager] Downloaded file too small (${buffer.length} bytes), preview: ${preview}`)
    return `Downloaded file is invalid (${buffer.length} bytes). The dictionary URL may not be available yet.`
  }

  // An HTML body means the server answered with a web page instead of the dictionary.
  // Compare case-insensitively so `<!DOCTYPE`, `<HTML>` and `<html>` are all caught.
  const head = buffer.subarray(0, 50).toString('utf-8').trim().toLowerCase()
  if (head.startsWith('<!') || head.startsWith('<html')) {
    console.error(`[NLP DictManager] Downloaded file appears to be HTML, not a dict file`)
    return 'Downloaded file is HTML, not a dictionary file. The URL may not be deployed yet.'
  }

  return null
}

/**
 * Downloads a dictionary and installs it into the dictionary directory.
 *
 * The payload is written to `<file>.tmp` first and then renamed over the final
 * path, so an existing dictionary stays in place until the new one is complete.
 *
 * @param dictId     Identifier of a dictionary listed in AVAILABLE_DICTS.
 * @param onProgress Optional callback receiving the download progress as a
 *                   rounded percentage; only invoked when the total size is known.
 * @returns `{ success: true }` on success, otherwise `{ success: false, error }`.
 *          Failures are reported through the result object rather than thrown.
 */
export async function downloadDict(
  dictId: string,
  onProgress?: (percent: number) => void
): Promise<{ success: boolean; error?: string }> {
  const dictDef = AVAILABLE_DICTS.find((d) => d.id === dictId)
  if (!dictDef) {
    return { success: false, error: `Unknown dict: ${dictId}` }
  }

  const url = getDictDownloadUrl(dictId)
  const filePath = getDictFilePath(dictId)
  const tmpPath = filePath + '.tmp'

  try {
    // Inside the try block so a directory-creation failure is reported as
    // { success: false } instead of rejecting the returned promise.
    ensureNlpDir()

    const response = await axios.get(url, {
      responseType: 'arraybuffer',
      timeout: 120_000,
      onDownloadProgress: (progressEvent) => {
        if (progressEvent.total && onProgress) {
          onProgress(Math.round((progressEvent.loaded / progressEvent.total) * 100))
        }
      },
    })

    const buffer = Buffer.from(response.data)
    const invalidReason = validateDictPayload(buffer)
    if (invalidReason) {
      return { success: false, error: invalidReason }
    }

    fs.writeFileSync(tmpPath, buffer)
    // rename replaces an existing destination, so the previous dictionary is not
    // unlinked first: deleting it up front would leave no dictionary on disk if
    // the rename then failed.
    fs.renameSync(tmpPath, filePath)

    // Log the in-memory size; a stat call here could only turn a completed
    // install into a reported failure.
    console.log(`[NLP DictManager] Dict downloaded: ${dictId} (${buffer.length} bytes)`)
    return { success: true }
  } catch (error) {
    // Best-effort removal of a partially written temp file.
    if (fs.existsSync(tmpPath)) {
      try {
        fs.unlinkSync(tmpPath)
      } catch {
        /* ignore */
      }
    }
    const msg = error instanceof Error ? error.message : String(error)
    console.error(`[NLP DictManager] Download failed for ${dictId}:`, msg)
    return { success: false, error: msg }
  }
}
/**
 * Intended to be called at application start-up: downloads the Simplified
 * Chinese (zh-CN) dictionary when it is not on disk yet. The outcome is only
 * logged; a failed download does not throw.
 */
export async function ensureDefaultDict(): Promise<void> {
  if (isDictDownloaded('zh-CN')) {
    return
  }

  console.log('[NLP DictManager] zh-CN dict not found, starting background download...')
  const { success, error } = await downloadDict('zh-CN')

  if (!success) {
    console.warn('[NLP DictManager] zh-CN dict auto-download failed:', error)
    return
  }
  console.log('[NLP DictManager] zh-CN dict auto-downloaded successfully')
}
/**
 * Deletes a downloaded dictionary file.
 *
 * @param dictId Identifier of a dictionary listed in AVAILABLE_DICTS.
 * @returns `{ success: true }` when the file was removed or was already absent;
 *          `{ success: false, error }` for an unknown id or a failed deletion.
 */
export function deleteDict(dictId: string): { success: boolean; error?: string } {
  // dictId can come from the renderer through the 'nlp:deleteDict' IPC channel and
  // is interpolated into a file path. Only accept ids from the catalogue so that a
  // value containing path segments (e.g. '../..') cannot delete files outside the
  // dictionary directory. Mirrors the check in downloadDict.
  const isKnownDict = AVAILABLE_DICTS.some((d) => d.id === dictId)
  if (!isKnownDict) {
    return { success: false, error: `Unknown dict: ${dictId}` }
  }

  const filePath = getDictFilePath(dictId)
  if (!fs.existsSync(filePath)) {
    // Nothing to delete counts as success.
    return { success: true }
  }

  try {
    fs.unlinkSync(filePath)
    console.log(`[NLP DictManager] Dict deleted: ${dictId}`)
    return { success: true }
  } catch (error) {
    const msg = error instanceof Error ? error.message : String(error)
    console.error(`[NLP DictManager] Delete failed for ${dictId}:`, msg)
    return { success: false, error: msg }
  }
}
+3 -21
View File
@@ -6,29 +6,11 @@
*
* 使用 jieba 处理中文(天然兼容中英混合文本),
* Intl.Segmenter 处理纯英文/日文。
*
* 复用 segmenter 模块的 jieba 实例池,默认使用 zh-CN 词库。
*/
interface JiebaInstance {
cut: (text: string, hmm?: boolean) => string[]
}
let jiebaInstance: JiebaInstance | null = null
function getJieba(): JiebaInstance {
if (!jiebaInstance) {
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { Jieba } = require('@node-rs/jieba')
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { dict } = require('@node-rs/jieba/dict')
jiebaInstance = Jieba.withDict(dict)
} catch (error) {
console.error('[FTS] Failed to load jieba module:', error)
throw new Error('jieba 模块加载失败')
}
}
return jiebaInstance!
}
import { getJieba } from './segmenter'
/**
* 对文本进行 FTS 分词,返回空格分隔的 token 字符串。
+2
View File
@@ -5,3 +5,5 @@
export * from './types'
export * from './stopwords'
export * from './segmenter'
// dictManager 需要 electron app 模块,只能在主进程中直接导入
// import { ... } from './dictManager'
+85 -36
View File
@@ -1,42 +1,85 @@
/**
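 * Segmenter module.
 * Chinese text is segmented with @node-rs/jieba; other languages use Intl.Segmenter.
 *
 * Multiple dictionaries are supported: the default is Simplified Chinese (zh-CN), and
 * others such as Traditional Chinese can be selected via dictType.
 * All dictionary files, including the default one, are loaded from the nlpDir directory
 * (passed in when the Worker initializes); no dictionary is bundled.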
*/
import type { SupportedLocale, PosFilterMode, PosTagInfo } from './types'
import * as fs from 'fs'
import * as path from 'path'
import type { SupportedLocale, PosFilterMode, PosTagInfo, DictType } from './types'
import { isStopword } from './stopwords'
// Jieba 实例类型
export type { DictType }
interface JiebaInstance {
cut: (text: string, hmm?: boolean) => string[]
tag: (text: string) => Array<{ tag: string; word: string }>
}
// Jieba 实例(延迟初始化)
let jiebaInstance: JiebaInstance | null = null
let _nlpDir: string | null = null
const jiebaInstances = new Map<DictType, JiebaInstance>()
/**
* 获取 Jieba 实例(延迟加载)
* 由 Worker 初始化时调用,设置自定义词库目录路径
*/
function getJieba(): JiebaInstance {
if (!jiebaInstance) {
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { Jieba } = require('@node-rs/jieba')
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { dict } = require('@node-rs/jieba/dict')
jiebaInstance = Jieba.withDict(dict)
console.log('[NLP] jieba module loaded')
} catch (error) {
console.error('[NLP] Failed to load jieba module:', error)
throw new Error('jieba 模块加载失败')
export function initNlpDir(nlpDir: string): void {
// Remember the directory; tryLoadDictFromDisk() reads `<nlpDir>/<dictId>.dict` from it.
_nlpDir = nlpDir
}
/**
 * Reads `<nlpDir>/<dictId>.dict` into memory.
 *
 * @returns The file contents, or null when no directory has been configured,
 *          the file does not exist, or the read fails.
 */
function tryLoadDictFromDisk(dictId: string): Buffer | null {
  if (!_nlpDir) {
    return null
  }

  const dictFile = path.join(_nlpDir, `${dictId}.dict`)
  let contents: Buffer | null = null
  if (fs.existsSync(dictFile)) {
    try {
      contents = fs.readFileSync(dictFile)
    } catch {
      contents = null
    }
  }
  return contents
}
/**
* 获取 Jieba 实例(支持多词库)
*
* 所有词库均从 nlpDir 磁盘加载(由应用启动时自动下载)。
* default 和 zh-CN 共用同一实例。
*/
export function getJieba(dictType: DictType = 'default'): JiebaInstance {
const effectiveType = dictType === 'default' ? 'zh-CN' : dictType
const cached = jiebaInstances.get(effectiveType)
if (cached) return cached
try {
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { Jieba } = require('@node-rs/jieba')
const diskDict = tryLoadDictFromDisk(effectiveType)
if (!diskDict) {
throw new Error(`Dict file not found for: ${effectiveType}. Please ensure the dictionary has been downloaded.`)
}
const instance: JiebaInstance = Jieba.withDict(diskDict)
console.log(`[NLP] jieba dict loaded: ${effectiveType} (${diskDict.length} bytes)`)
jiebaInstances.set(effectiveType, instance)
return instance
} catch (error) {
console.error(`[NLP] Failed to load jieba module (dict=${effectiveType}):`, error)
throw new Error(`jieba 模块加载失败 (${effectiveType})`)
}
const instance = jiebaInstance
if (!instance) {
throw new Error('jieba 模块未初始化')
}
return instance
}
/**
 * Drops the cached jieba instance for a dictionary so that the next getJieba()
 * call builds a fresh one (call this after the dictionary file has changed).
 *
 * getJieba() stores instances under the effective dictionary id, where
 * 'default' is an alias of 'zh-CN'. The same mapping is applied here; without
 * it, clearJiebaInstance('default') would not remove anything.
 *
 * @param dictType Dictionary whose cached instance should be discarded.
 */
export function clearJiebaInstance(dictType: DictType): void {
  const effectiveType = dictType === 'default' ? 'zh-CN' : dictType
  jiebaInstances.delete(effectiveType)
  console.log(`[NLP] jieba instance cleared: ${dictType}`)
}
/**
@@ -147,6 +190,8 @@ interface ChineseSegmentOptions {
posFilterMode?: PosFilterMode
/** 自定义词性过滤列表 */
customPosTags?: string[]
/** 词库类型 */
dictType?: DictType
}
/**
@@ -156,12 +201,13 @@ interface ChineseSegmentOptions {
export function collectPosTagStats(
texts: string[],
minWordLength: number = 2,
enableStopwords: boolean = true
enableStopwords: boolean = true,
dictType: DictType = 'default'
): Map<string, number> {
const posStats = new Map<string, number>()
try {
const jieba = getJieba()
const jieba = getJieba(dictType)
for (const text of texts) {
const cleaned = cleanText(text)
@@ -191,36 +237,31 @@ export function collectPosTagStats(
* @param options 分词选项
*/
function segmentChinese(text: string, options: ChineseSegmentOptions = {}): string[] {
const { posFilterMode = 'meaningful', customPosTags } = options
const { posFilterMode = 'meaningful', customPosTags, dictType = 'default' } = options
const cleaned = cleanText(text)
if (!cleaned) return []
try {
const jieba = getJieba()
const jieba = getJieba(dictType)
// 全部模式:直接分词,不做词性过滤
if (posFilterMode === 'all') {
return jieba.cut(cleaned, false)
}
// 使用词性标注
const tagged = jieba.tag(cleaned)
// 根据模式过滤
let allowedTags: Set<string>
if (posFilterMode === 'custom' && customPosTags) {
allowedTags = new Set(customPosTags)
} else {
// meaningful 模式
allowedTags = MEANINGFUL_POS_TAGS
}
return tagged.filter((item) => allowedTags.has(item.tag)).map((item) => item.word)
} catch (error) {
console.error('[NLP] Chinese segmentation failed:', error)
// 降级:使用简单分词
try {
const jieba = getJieba()
const jieba = getJieba('default')
return jieba.cut(cleaned, false)
} catch {
return cleaned.split('')
@@ -277,6 +318,8 @@ export interface SegmentOptions {
customPosTags?: string[]
/** 是否启用停用词过滤 */
enableStopwords?: boolean
/** 词库类型(仅中文有效) */
dictType?: DictType
}
/**
@@ -287,7 +330,13 @@ export interface SegmentOptions {
* @returns 过滤后的分词结果
*/
export function segment(text: string, locale: SupportedLocale, options: SegmentOptions = {}): string[] {
const { minLength, posFilterMode = 'meaningful', customPosTags, enableStopwords = true } = options
const {
minLength,
posFilterMode = 'meaningful',
customPosTags,
enableStopwords = true,
dictType = 'default',
} = options
const isChinese = locale.startsWith('zh')
const isJapanese = locale === 'ja-JP'
const defaultMinLength = isChinese || isJapanese ? 2 : 3
@@ -296,7 +345,7 @@ export function segment(text: string, locale: SupportedLocale, options: SegmentO
let words: string[]
if (isChinese) {
words = segmentChinese(text, { posFilterMode, customPosTags })
words = segmentChinese(text, { posFilterMode, customPosTags, dictType })
} else if (isJapanese) {
words = segmentJapanese(text)
} else {
@@ -326,11 +375,11 @@ export function batchSegmentWithFrequency(
locale: SupportedLocale,
options: BatchSegmentOptions = {}
): Map<string, number> {
const { minLength, minCount = 2, topN = 100, posFilterMode, customPosTags, enableStopwords } = options
const { minLength, minCount = 2, topN = 100, posFilterMode, customPosTags, enableStopwords, dictType } = options
const wordFrequency = new Map<string, number>()
for (const text of texts) {
const words = segment(text, locale, { minLength, posFilterMode, customPosTags, enableStopwords })
const words = segment(text, locale, { minLength, posFilterMode, customPosTags, enableStopwords, dictType })
for (const word of words) {
wordFrequency.set(word, (wordFrequency.get(word) || 0) + 1)
}
+5
View File
@@ -48,6 +48,9 @@ export interface WordFrequencyResult {
/** 词性过滤模式 */
export type PosFilterMode = 'all' | 'meaningful' | 'custom'
/** 词库类型 */
export type DictType = 'default' | 'zh-CN' | 'zh-TW'
/** 词频统计参数 */
export interface WordFrequencyParams {
/** 会话 ID */
@@ -73,6 +76,8 @@ export interface WordFrequencyParams {
customPosTags?: string[]
/** 是否启用停用词过滤,默认 true */
enableStopwords?: boolean
/** Dictionary type: 'default' is an alias of 'zh-CN' (Simplified Chinese); 'zh-TW' = Traditional Chinese */
dictType?: DictType
}
/** 词性标签信息 */
+7
View File
@@ -72,10 +72,16 @@ import {
getPosTags,
} from './query'
import { streamImport, streamParseFileInfo, analyzeIncrementalImport, incrementalImport } from './import'
import { initNlpDir } from '../nlp/segmenter'
// 初始化数据库目录
initDbDir(workerData.dbDir, workerData.cacheDir)
// 初始化 NLP 词库目录
if (workerData.nlpDir) {
initNlpDir(workerData.nlpDir)
}
// ==================== 分析结果缓存 ====================
const ANALYSIS_CACHE_PREFIX = 'analysis:'
@@ -124,6 +130,7 @@ function buildAnalysisCacheKey(type: string, payload: any): string {
if (payload.topN) parts.push(`n${payload.topN}`)
if (payload.minLength) parts.push(`ml${payload.minLength}`)
if (payload.posTags) parts.push(`pt${JSON.stringify(payload.posTags)}`)
if (payload.dictType && payload.dictType !== 'default') parts.push(`dt${payload.dictType}`)
return parts.join(':')
}
+4 -12
View File
@@ -5,7 +5,7 @@
import { openDatabase, buildTimeFilter, type TimeFilter } from '../core'
import { segment, batchSegmentWithFrequency, getPosTagDefinitions, collectPosTagStats } from '../../nlp'
import type { SupportedLocale, WordFrequencyResult, WordFrequencyParams, PosTagInfo, PosTagStat } from '../../nlp'
import type { SupportedLocale, WordFrequencyResult, WordFrequencyParams, PosTagInfo, PosTagStat, DictType } from '../../nlp'
/**
* 获取词频统计
@@ -23,6 +23,7 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
posFilterMode = 'meaningful',
customPosTags,
enableStopwords = true,
dictType = 'default',
} = params
const db = openDatabase(sessionId)
@@ -35,14 +36,12 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
}
}
// 构建时间和成员过滤
const filter: TimeFilter = {
...timeFilter,
memberId,
}
const { clause, params: filterParams } = buildTimeFilter(filter, 'msg')
// 构建 WHERE 子句,排除系统消息
let whereClause = clause
if (whereClause.includes('WHERE')) {
whereClause +=
@@ -52,7 +51,6 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
" WHERE COALESCE(m.account_name, '') != '系统消息' AND msg.type = 0 AND msg.content IS NOT NULL AND TRIM(msg.content) != ''"
}
// 查询消息内容
const messages = db
.prepare(
`
@@ -64,7 +62,6 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
)
.all(...filterParams) as Array<{ content: string }>
// 如果没有消息,返回空结果
if (messages.length === 0) {
return {
words: [],
@@ -74,18 +71,14 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
}
}
// 提取文本内容
const texts = messages.map((m) => m.content)
// 收集词性统计(用于显示每个词性有多少词,仅中文有效)
let posTagStats: PosTagStat[] | undefined
// 词性统计只对中文生效,这里先做类型兜底,避免异常 locale 直接触发 startsWith 报错。
if (typeof locale === 'string' && locale.startsWith('zh')) {
const posStatsMap = collectPosTagStats(texts, minWordLength ?? 2, enableStopwords)
const posStatsMap = collectPosTagStats(texts, minWordLength ?? 2, enableStopwords, dictType as DictType)
posTagStats = [...posStatsMap.entries()].map(([tag, count]) => ({ tag, count }))
}
// 批量分词并统计词频
const wordFrequency = batchSegmentWithFrequency(texts, locale as SupportedLocale, {
minLength: minWordLength,
minCount,
@@ -93,15 +86,14 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
posFilterMode,
customPosTags,
enableStopwords,
dictType: dictType as DictType,
})
// 计算总词数(用于百分比)
let totalWords = 0
for (const count of wordFrequency.values()) {
totalWords += count
}
// 构建结果
const words = [...wordFrequency.entries()].map(([word, count]) => ({
word,
count,
+2
View File
@@ -10,6 +10,7 @@ import type { ParseProgress } from '../parser'
import type { StreamImportResult } from './import'
import { openDatabase } from '../database/core'
import { getDatabaseDir, getCacheDir, ensureDir } from '../paths'
import { getNlpDir } from '../nlp/dictManager'
// Worker 实例
let worker: Worker | null = null
@@ -70,6 +71,7 @@ export function initWorker(): void {
workerData: {
dbDir: getDbDir(),
cacheDir: getCacheDir(),
nlpDir: getNlpDir(),
},
})
+16 -9
View File
@@ -94,26 +94,33 @@ export interface ChatSessionItem {
// ==================== NLP API ====================
/** Renderer-side NLP API; every method forwards to the matching `nlp:*` IPC channel. */
export const nlpApi = {
  /** Fetches word-frequency statistics (used for the word cloud). */
  getWordFrequency: (params: WordFrequencyParams): Promise<WordFrequencyResult> =>
    ipcRenderer.invoke('nlp:getWordFrequency', params),

  /** Segments a single piece of text. */
  segmentText: (text: string, locale: 'zh-CN' | 'en-US', minLength?: number): Promise<string[]> =>
    ipcRenderer.invoke('nlp:segmentText', text, locale, minLength),

  /** Fetches the part-of-speech tag definitions. */
  getPosTags: (): Promise<PosTagInfo[]> => ipcRenderer.invoke('nlp:getPosTags'),

  /** Lists the known dictionaries together with their download state. */
  getDictList: (): Promise<
    Array<{ id: string; label: string; locale: string; downloaded: boolean; fileSize?: number }>
  > => ipcRenderer.invoke('nlp:getDictList'),

  /** Reports whether the given dictionary has been downloaded. */
  isDictDownloaded: (dictId: string): Promise<boolean> => ipcRenderer.invoke('nlp:isDictDownloaded', dictId),

  /** Requests a download of the given dictionary. */
  downloadDict: (dictId: string): Promise<{ success: boolean; error?: string }> =>
    ipcRenderer.invoke('nlp:downloadDict', dictId),

  /** Requests deletion of the given dictionary. */
  deleteDict: (dictId: string): Promise<{ success: boolean; error?: string }> =>
    ipcRenderer.invoke('nlp:deleteDict', dictId),
}
// ==================== Network API ====================
+16
View File
@@ -923,6 +923,16 @@ interface WordFrequencyResult {
posTagStats?: PosTagStat[]
}
type DictType = 'default' | 'zh-CN' | 'zh-TW'
/** One known dictionary and its local download state, as returned by `getDictList`. */
interface DictInfo {
/** Dictionary identifier (e.g. 'zh-CN'). */
id: string
/** Human-readable display name. */
label: string
/** Locale associated with this dictionary. */
locale: string
/** Whether the dictionary file is present locally. */
downloaded: boolean
/** Size of the dictionary file in bytes, when available. */
fileSize?: number
}
interface WordFrequencyParams {
sessionId: string
locale: SupportedLocale
@@ -937,6 +947,8 @@ interface WordFrequencyParams {
customPosTags?: string[]
/** 是否启用停用词过滤,默认 true */
enableStopwords?: boolean
/** Dictionary type: 'default' is an alias of 'zh-CN' (Simplified Chinese); 'zh-TW' = Traditional Chinese */
dictType?: DictType
}
/** 词性标签信息 */
@@ -951,6 +963,10 @@ interface NlpApi {
getWordFrequency: (params: WordFrequencyParams) => Promise<WordFrequencyResult>
segmentText: (text: string, locale: SupportedLocale, minLength?: number) => Promise<string[]>
getPosTags: () => Promise<PosTagInfo[]>
getDictList: () => Promise<DictInfo[]>
isDictDownloaded: (dictId: string) => Promise<boolean>
downloadDict: (dictId: string) => Promise<{ success: boolean; error?: string }>
deleteDict: (dictId: string) => Promise<{ success: boolean; error?: string }>
}
// ChatLab API 服务类型