mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-02 10:51:18 +08:00
feat: 聊天会话支持摘要功能
This commit is contained in:
456
electron/main/ai/summary/index.ts
Normal file
456
electron/main/ai/summary/index.ts
Normal file
@@ -0,0 +1,456 @@
|
||||
/**
|
||||
* 会话摘要生成服务
|
||||
*
|
||||
* 利用 LLM 为会话生成摘要
|
||||
* - 智能预处理:过滤无意义内容(纯表情、单字回复等)
|
||||
* - 根据消息数量智能调整摘要长度
|
||||
* - 超长会话采用 Map-Reduce 策略
|
||||
*/
|
||||
|
||||
import Database from 'better-sqlite3'
|
||||
import { chat } from '../llm'
|
||||
import { getDbPath, openDatabase } from '../../database/core'
|
||||
import { aiLogger } from '../logger'
|
||||
|
||||
/** 最小消息数阈值(少于此数量不生成摘要) */
|
||||
const MIN_MESSAGE_COUNT = 3
|
||||
|
||||
/** 单次 LLM 调用的最大内容字符数(约 2000 tokens,留安全余量) */
|
||||
const MAX_CONTENT_PER_CALL = 8000
|
||||
|
||||
/** 需要分段处理的阈值 */
|
||||
const SEGMENT_THRESHOLD = 8000
|
||||
|
||||
// ==================== 数据库操作函数(独立于 Worker) ====================
|
||||
|
||||
interface SessionMessagesResult {
|
||||
messageCount: number
|
||||
messages: Array<{
|
||||
senderName: string
|
||||
content: string | null
|
||||
}>
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话消息(主进程版本,使用 database/core)
|
||||
*/
|
||||
function getSessionMessagesForSummary(
|
||||
dbSessionId: string,
|
||||
chatSessionId: number,
|
||||
limit: number = 500
|
||||
): SessionMessagesResult | null {
|
||||
const db = openDatabase(dbSessionId, true)
|
||||
if (!db) {
|
||||
aiLogger.error('Summary', `数据库打开失败: ${dbSessionId}`)
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取会话消息
|
||||
const messagesSql = `
|
||||
SELECT
|
||||
COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName,
|
||||
m.content
|
||||
FROM message_context mc
|
||||
JOIN message m ON m.id = mc.message_id
|
||||
JOIN member mb ON mb.id = m.sender_id
|
||||
WHERE mc.session_id = ?
|
||||
ORDER BY m.ts ASC
|
||||
LIMIT ?
|
||||
`
|
||||
const messages = db.prepare(messagesSql).all(chatSessionId, limit) as Array<{
|
||||
senderName: string
|
||||
content: string | null
|
||||
}>
|
||||
|
||||
return {
|
||||
messageCount: messages.length,
|
||||
messages,
|
||||
}
|
||||
} catch (error) {
|
||||
aiLogger.error('Summary', `获取会话消息失败: ${error}`)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存会话摘要(主进程版本)
|
||||
*/
|
||||
function saveSessionSummaryToDb(dbSessionId: string, chatSessionId: number, summary: string): void {
|
||||
const dbPath = getDbPath(dbSessionId)
|
||||
const db = new Database(dbPath)
|
||||
|
||||
try {
|
||||
db.prepare('UPDATE chat_session SET summary = ? WHERE id = ?').run(summary, chatSessionId)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话摘要(主进程版本)
|
||||
*/
|
||||
function getSessionSummaryFromDb(dbSessionId: string, chatSessionId: number): string | null {
|
||||
const db = openDatabase(dbSessionId, true)
|
||||
if (!db) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const result = db.prepare('SELECT summary FROM chat_session WHERE id = ?').get(chatSessionId) as
|
||||
| { summary: string | null }
|
||||
| undefined
|
||||
return result?.summary || null
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据消息数量计算摘要长度限制
|
||||
* - 3-10 条消息:50 字
|
||||
* - 11-30 条消息:80 字
|
||||
* - 31-100 条消息:120 字
|
||||
* - 100+ 条消息:200 字
|
||||
*/
|
||||
function getSummaryLengthLimit(messageCount: number): number {
|
||||
if (messageCount <= 10) return 50
|
||||
if (messageCount <= 30) return 80
|
||||
if (messageCount <= 100) return 120
|
||||
return 200
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断消息是否有意义(用于过滤)
|
||||
*/
|
||||
function isValidMessage(content: string): boolean {
|
||||
const trimmed = content.trim()
|
||||
|
||||
// 过滤空内容
|
||||
if (!trimmed) return false
|
||||
|
||||
// 过滤单字/双字无意义回复
|
||||
if (trimmed.length <= 2) {
|
||||
// 允许一些有意义的短词
|
||||
const meaningfulShort = ['好的', '不是', '是的', '可以', '不行', '好吧', '明白', '知道', '同意']
|
||||
if (!meaningfulShort.includes(trimmed)) return false
|
||||
}
|
||||
|
||||
// 过滤纯表情消息
|
||||
const emojiOnlyPattern = /^[\p{Emoji}\s[\]()()]+$/u
|
||||
if (emojiOnlyPattern.test(trimmed)) return false
|
||||
|
||||
// 过滤占位符文本
|
||||
const placeholders = ['[图片]', '[语音]', '[视频]', '[文件]', '[表情]', '[动画表情]', '[位置]', '[名片]', '[红包]', '[转账]', '[撤回消息]']
|
||||
if (placeholders.some((p) => trimmed === p)) return false
|
||||
|
||||
// 过滤系统消息(入群、退群等)
|
||||
const systemPatterns = [
|
||||
/^.*邀请.*加入了群聊$/,
|
||||
/^.*退出了群聊$/,
|
||||
/^.*撤回了一条消息$/,
|
||||
/^你撤回了一条消息$/,
|
||||
]
|
||||
if (systemPatterns.some((p) => p.test(trimmed))) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* 预处理消息:过滤无意义内容
|
||||
*/
|
||||
function preprocessMessages(
|
||||
messages: Array<{ senderName: string; content: string | null }>
|
||||
): Array<{ senderName: string; content: string }> {
|
||||
return messages
|
||||
.filter((m) => m.content && isValidMessage(m.content))
|
||||
.map((m) => ({ senderName: m.senderName, content: m.content!.trim() }))
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化消息为文本
|
||||
*/
|
||||
function formatMessages(messages: Array<{ senderName: string; content: string }>): string {
|
||||
return messages.map((m) => `${m.senderName}: ${m.content}`).join('\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 将消息分成多个段落
|
||||
*/
|
||||
function splitIntoSegments(
|
||||
messages: Array<{ senderName: string; content: string }>,
|
||||
maxCharsPerSegment: number
|
||||
): Array<Array<{ senderName: string; content: string }>> {
|
||||
const segments: Array<Array<{ senderName: string; content: string }>> = []
|
||||
let currentSegment: Array<{ senderName: string; content: string }> = []
|
||||
let currentLength = 0
|
||||
|
||||
for (const msg of messages) {
|
||||
const msgLength = msg.senderName.length + msg.content.length + 3 // "name: content\n"
|
||||
|
||||
if (currentLength + msgLength > maxCharsPerSegment && currentSegment.length > 0) {
|
||||
segments.push(currentSegment)
|
||||
currentSegment = []
|
||||
currentLength = 0
|
||||
}
|
||||
|
||||
currentSegment.push(msg)
|
||||
currentLength += msgLength
|
||||
}
|
||||
|
||||
if (currentSegment.length > 0) {
|
||||
segments.push(currentSegment)
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成摘要的 Prompt
|
||||
*/
|
||||
function buildSummaryPrompt(content: string, lengthLimit: number, locale: string): string {
|
||||
if (locale === 'zh-CN') {
|
||||
return `请用简洁的语言(${lengthLimit}字以内)总结以下对话的主要内容或话题。只输出摘要内容,不要添加任何前缀、解释或引号。
|
||||
|
||||
${content}`
|
||||
}
|
||||
return `Summarize the following conversation concisely (max ${lengthLimit} characters). Output only the summary, no prefix, explanation, or quotes.
|
||||
|
||||
${content}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成子摘要的 Prompt
|
||||
*/
|
||||
function buildSubSummaryPrompt(content: string, locale: string): string {
|
||||
if (locale === 'zh-CN') {
|
||||
return `请用一句话(不超过50字)概括以下对话片段的主要内容。只输出摘要内容,不要添加任何前缀、解释或引号。
|
||||
|
||||
${content}`
|
||||
}
|
||||
return `Summarize this conversation segment in one sentence (max 50 characters). Output only the summary, no prefix or quotes.
|
||||
|
||||
${content}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 合并子摘要的 Prompt
|
||||
*/
|
||||
function buildMergePrompt(subSummaries: string[], lengthLimit: number, locale: string): string {
|
||||
const summaryList = subSummaries.map((s, i) => `${i + 1}. ${s}`).join('\n')
|
||||
if (locale === 'zh-CN') {
|
||||
return `以下是一段对话的多个片段摘要,请将它们合并成一个完整的总结(${lengthLimit}字以内)。只输出摘要内容,不要添加任何前缀、解释或引号。
|
||||
|
||||
${summaryList}`
|
||||
}
|
||||
return `Below are summaries of different parts of a conversation. Merge them into one cohesive summary (max ${lengthLimit} characters). Output only the summary, no prefix or quotes.
|
||||
|
||||
${summaryList}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成会话摘要
|
||||
*
|
||||
* @param dbSessionId 数据库会话ID(用于访问数据库)
|
||||
* @param chatSessionId 会话索引中的会话ID
|
||||
* @param locale 语言设置
|
||||
* @param forceRegenerate 是否强制重新生成(忽略缓存)
|
||||
* @returns 摘要内容或错误
|
||||
*/
|
||||
export async function generateSessionSummary(
|
||||
dbSessionId: string,
|
||||
chatSessionId: number,
|
||||
locale: string = 'zh-CN',
|
||||
forceRegenerate: boolean = false
|
||||
): Promise<{ success: boolean; summary?: string; error?: string }> {
|
||||
try {
|
||||
// 1. 检查是否已有摘要(除非强制重新生成)
|
||||
if (!forceRegenerate) {
|
||||
const existing = getSessionSummaryFromDb(dbSessionId, chatSessionId)
|
||||
if (existing) {
|
||||
return { success: true, summary: existing }
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取会话消息
|
||||
const sessionData = getSessionMessagesForSummary(dbSessionId, chatSessionId)
|
||||
if (!sessionData) {
|
||||
return { success: false, error: '会话不存在或数据库打开失败' }
|
||||
}
|
||||
|
||||
// 3. 检查消息数量
|
||||
if (sessionData.messageCount < MIN_MESSAGE_COUNT) {
|
||||
return {
|
||||
success: false,
|
||||
error:
|
||||
locale === 'zh-CN'
|
||||
? `消息数量少于${MIN_MESSAGE_COUNT}条,无需生成摘要`
|
||||
: `Message count less than ${MIN_MESSAGE_COUNT}, no need to generate summary`,
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 预处理:过滤无意义消息
|
||||
const validMessages = preprocessMessages(sessionData.messages)
|
||||
if (validMessages.length < MIN_MESSAGE_COUNT) {
|
||||
return {
|
||||
success: false,
|
||||
error:
|
||||
locale === 'zh-CN'
|
||||
? `有效消息数量少于${MIN_MESSAGE_COUNT}条,无需生成摘要`
|
||||
: `Valid message count less than ${MIN_MESSAGE_COUNT}, no need to generate summary`,
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 计算摘要长度限制
|
||||
const lengthLimit = getSummaryLengthLimit(validMessages.length)
|
||||
|
||||
// 6. 格式化内容
|
||||
const content = formatMessages(validMessages)
|
||||
|
||||
aiLogger.info(
|
||||
'Summary',
|
||||
`生成会话摘要: sessionId=${chatSessionId}, 原始消息=${sessionData.messageCount}, 有效消息=${validMessages.length}, 内容长度=${content.length}`
|
||||
)
|
||||
|
||||
let summary: string
|
||||
|
||||
// 7. 根据内容长度决定处理策略
|
||||
if (content.length <= SEGMENT_THRESHOLD) {
|
||||
// 短会话:直接生成摘要
|
||||
summary = await generateDirectSummary(content, lengthLimit, locale)
|
||||
} else {
|
||||
// 长会话:Map-Reduce 策略
|
||||
summary = await generateMapReduceSummary(validMessages, lengthLimit, locale)
|
||||
}
|
||||
|
||||
// 8. 后处理:移除引号
|
||||
if ((summary.startsWith('"') && summary.endsWith('"')) || (summary.startsWith('「') && summary.endsWith('」'))) {
|
||||
summary = summary.slice(1, -1)
|
||||
}
|
||||
|
||||
// 如果摘要超过限制的 1.5 倍,进行截断
|
||||
const hardLimit = Math.floor(lengthLimit * 1.5)
|
||||
if (summary.length > hardLimit) {
|
||||
summary = summary.slice(0, hardLimit - 3) + '...'
|
||||
}
|
||||
|
||||
// 9. 保存到数据库
|
||||
saveSessionSummaryToDb(dbSessionId, chatSessionId, summary)
|
||||
|
||||
aiLogger.info('Summary', `摘要生成成功: "${summary.slice(0, 50)}..."`)
|
||||
|
||||
return { success: true, summary }
|
||||
} catch (error) {
|
||||
aiLogger.error('Summary', '摘要生成失败', error)
|
||||
return {
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成摘要(适用于短会话)
|
||||
*/
|
||||
async function generateDirectSummary(content: string, lengthLimit: number, locale: string): Promise<string> {
|
||||
const response = await chat(
|
||||
[
|
||||
{ role: 'system', content: '你是一个对话摘要专家,擅长用简洁的语言总结对话内容。' },
|
||||
{ role: 'user', content: buildSummaryPrompt(content, lengthLimit, locale) },
|
||||
],
|
||||
{
|
||||
temperature: 0.3,
|
||||
maxTokens: 300,
|
||||
}
|
||||
)
|
||||
return response.content.trim()
|
||||
}
|
||||
|
||||
/**
|
||||
* Map-Reduce 策略生成摘要(适用于长会话)
|
||||
*/
|
||||
async function generateMapReduceSummary(
|
||||
messages: Array<{ senderName: string; content: string }>,
|
||||
lengthLimit: number,
|
||||
locale: string
|
||||
): Promise<string> {
|
||||
// 1. Map:分段生成子摘要
|
||||
const segments = splitIntoSegments(messages, MAX_CONTENT_PER_CALL)
|
||||
aiLogger.info('Summary', `长会话分段处理: ${segments.length} 个段落`)
|
||||
|
||||
const subSummaries: string[] = []
|
||||
|
||||
for (let i = 0; i < segments.length; i++) {
|
||||
const segmentContent = formatMessages(segments[i])
|
||||
const response = await chat(
|
||||
[
|
||||
{ role: 'system', content: '你是一个对话摘要专家,擅长用简洁的语言总结对话内容。' },
|
||||
{ role: 'user', content: buildSubSummaryPrompt(segmentContent, locale) },
|
||||
],
|
||||
{
|
||||
temperature: 0.3,
|
||||
maxTokens: 100,
|
||||
}
|
||||
)
|
||||
subSummaries.push(response.content.trim())
|
||||
}
|
||||
|
||||
// 2. Reduce:合并子摘要
|
||||
if (subSummaries.length === 1) {
|
||||
return subSummaries[0]
|
||||
}
|
||||
|
||||
const mergeResponse = await chat(
|
||||
[
|
||||
{ role: 'system', content: '你是一个对话摘要专家,擅长将多个摘要合并成一个连贯的总结。' },
|
||||
{ role: 'user', content: buildMergePrompt(subSummaries, lengthLimit, locale) },
|
||||
],
|
||||
{
|
||||
temperature: 0.3,
|
||||
maxTokens: 300,
|
||||
}
|
||||
)
|
||||
|
||||
return mergeResponse.content.trim()
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量生成会话摘要
|
||||
*
|
||||
* @param dbSessionId 数据库会话ID
|
||||
* @param chatSessionIds 会话ID列表
|
||||
* @param locale 语言设置
|
||||
* @param onProgress 进度回调
|
||||
* @returns 生成结果
|
||||
*/
|
||||
export async function generateSessionSummaries(
|
||||
dbSessionId: string,
|
||||
chatSessionIds: number[],
|
||||
locale: string = 'zh-CN',
|
||||
onProgress?: (current: number, total: number) => void
|
||||
): Promise<{ success: number; failed: number; skipped: number }> {
|
||||
let success = 0
|
||||
let failed = 0
|
||||
let skipped = 0
|
||||
|
||||
for (let i = 0; i < chatSessionIds.length; i++) {
|
||||
const chatSessionId = chatSessionIds[i]
|
||||
|
||||
const result = await generateSessionSummary(dbSessionId, chatSessionId, locale, false)
|
||||
|
||||
if (result.success) {
|
||||
success++
|
||||
} else if (result.error?.includes('少于') || result.error?.includes('less than')) {
|
||||
skipped++
|
||||
} else {
|
||||
failed++
|
||||
}
|
||||
|
||||
if (onProgress) {
|
||||
onProgress(i + 1, chatSessionIds.length)
|
||||
}
|
||||
}
|
||||
|
||||
return { success, failed, skipped }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user