mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-23 06:51:10 +08:00
feat: 支持远程下载分词词库,并支持繁体中文词库
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
import { openDatabase, buildTimeFilter, type TimeFilter } from '../core'
|
||||
import { segment, batchSegmentWithFrequency, getPosTagDefinitions, collectPosTagStats } from '../../nlp'
|
||||
import type { SupportedLocale, WordFrequencyResult, WordFrequencyParams, PosTagInfo, PosTagStat } from '../../nlp'
|
||||
import type { SupportedLocale, WordFrequencyResult, WordFrequencyParams, PosTagInfo, PosTagStat, DictType } from '../../nlp'
|
||||
|
||||
/**
|
||||
* 获取词频统计
|
||||
@@ -23,6 +23,7 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
posFilterMode = 'meaningful',
|
||||
customPosTags,
|
||||
enableStopwords = true,
|
||||
dictType = 'default',
|
||||
} = params
|
||||
|
||||
const db = openDatabase(sessionId)
|
||||
@@ -35,14 +36,12 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
}
|
||||
}
|
||||
|
||||
// 构建时间和成员过滤
|
||||
const filter: TimeFilter = {
|
||||
...timeFilter,
|
||||
memberId,
|
||||
}
|
||||
const { clause, params: filterParams } = buildTimeFilter(filter, 'msg')
|
||||
|
||||
// 构建 WHERE 子句,排除系统消息
|
||||
let whereClause = clause
|
||||
if (whereClause.includes('WHERE')) {
|
||||
whereClause +=
|
||||
@@ -52,7 +51,6 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
" WHERE COALESCE(m.account_name, '') != '系统消息' AND msg.type = 0 AND msg.content IS NOT NULL AND TRIM(msg.content) != ''"
|
||||
}
|
||||
|
||||
// 查询消息内容
|
||||
const messages = db
|
||||
.prepare(
|
||||
`
|
||||
@@ -64,7 +62,6 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
)
|
||||
.all(...filterParams) as Array<{ content: string }>
|
||||
|
||||
// 如果没有消息,返回空结果
|
||||
if (messages.length === 0) {
|
||||
return {
|
||||
words: [],
|
||||
@@ -74,18 +71,14 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
}
|
||||
}
|
||||
|
||||
// 提取文本内容
|
||||
const texts = messages.map((m) => m.content)
|
||||
|
||||
// 收集词性统计(用于显示每个词性有多少词,仅中文有效)
|
||||
let posTagStats: PosTagStat[] | undefined
|
||||
// 词性统计只对中文生效,这里先做类型兜底,避免异常 locale 直接触发 startsWith 报错。
|
||||
if (typeof locale === 'string' && locale.startsWith('zh')) {
|
||||
const posStatsMap = collectPosTagStats(texts, minWordLength ?? 2, enableStopwords)
|
||||
const posStatsMap = collectPosTagStats(texts, minWordLength ?? 2, enableStopwords, dictType as DictType)
|
||||
posTagStats = [...posStatsMap.entries()].map(([tag, count]) => ({ tag, count }))
|
||||
}
|
||||
|
||||
// 批量分词并统计词频
|
||||
const wordFrequency = batchSegmentWithFrequency(texts, locale as SupportedLocale, {
|
||||
minLength: minWordLength,
|
||||
minCount,
|
||||
@@ -93,15 +86,14 @@ export function getWordFrequency(params: WordFrequencyParams): WordFrequencyResu
|
||||
posFilterMode,
|
||||
customPosTags,
|
||||
enableStopwords,
|
||||
dictType: dictType as DictType,
|
||||
})
|
||||
|
||||
// 计算总词数(用于百分比)
|
||||
let totalWords = 0
|
||||
for (const count of wordFrequency.values()) {
|
||||
totalWords += count
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
const words = [...wordFrequency.entries()].map(([word, count]) => ({
|
||||
word,
|
||||
count,
|
||||
|
||||
Reference in New Issue
Block a user