Files
ChatLab/electron/main/ai/summary/index.ts
2026-01-25 18:54:27 +08:00

457 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* 会话摘要生成服务
*
* 利用 LLM 为会话生成摘要
* - 智能预处理:过滤无意义内容(纯表情、单字回复等)
* - 根据消息数量智能调整摘要长度
* - 超长会话采用 Map-Reduce 策略
*/
import Database from 'better-sqlite3'
import { chat } from '../llm'
import { getDbPath, openDatabase } from '../../database/core'
import { aiLogger } from '../logger'
/** Minimum message count; sessions with fewer messages than this get no summary. */
const MIN_MESSAGE_COUNT = 3
/**
 * Maximum number of content characters sent in a single LLM call.
 * The original author estimates this at roughly 2000 tokens, leaving a safety margin.
 */
const MAX_CONTENT_PER_CALL = 8000
/** Formatted-content length above which a conversation is processed in segments. */
const SEGMENT_THRESHOLD = 8000
// ==================== Database helpers (independent of the Worker) ====================
/**
 * Messages of one chat session as loaded for summarization.
 */
interface SessionMessagesResult {
// Number of entries in `messages` (the loader sets it to `messages.length`).
messageCount: number
// One entry per loaded message.
messages: Array<{
// Display name selected by the loader query (group nickname, account name or platform id).
senderName: string
// Message text; typed as nullable, so consumers must handle `null`.
content: string | null
}>
}
/**
 * Loads the messages of one chat session for summarization (main-process version,
 * built on `database/core`).
 *
 * Each row carries the sender's display name (the first non-null of group nickname,
 * account name and platform id) and the message content. Rows are ordered by
 * ascending `ts` and capped at `limit`.
 *
 * @param dbSessionId Identifier passed to `openDatabase` to select the database.
 * @param chatSessionId Value matched against `message_context.session_id`.
 * @param limit Maximum number of rows to read (defaults to 500).
 * @returns The rows plus their count, or `null` when the database could not be
 *          opened or the query threw (both cases are logged).
 */
function getSessionMessagesForSummary(
  dbSessionId: string,
  chatSessionId: number,
  limit: number = 500
): SessionMessagesResult | null {
  const connection = openDatabase(dbSessionId, true)
  if (!connection) {
    aiLogger.error('Summary', `数据库打开失败: ${dbSessionId}`)
    return null
  }

  // NOTE(review): the handle is not closed in this function — confirm whether
  // `openDatabase` owns the connection's lifetime or whether a `close()` is missing.
  try {
    // The template literal is kept flush-left so the SQL text stays unchanged.
    const query = `
SELECT
COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName,
m.content
FROM message_context mc
JOIN message m ON m.id = mc.message_id
JOIN member mb ON mb.id = m.sender_id
WHERE mc.session_id = ?
ORDER BY m.ts ASC
LIMIT ?
`
    const rows = connection.prepare(query).all(chatSessionId, limit) as SessionMessagesResult['messages']
    return { messageCount: rows.length, messages: rows }
  } catch (error) {
    aiLogger.error('Summary', `获取会话消息失败: ${error}`)
    return null
  }
}
/**
 * Persists a session summary (main-process version).
 *
 * Opens its own connection on the path returned by `getDbPath`, updates the
 * `summary` column of the matching `chat_session` row, and always closes the
 * connection afterwards — even when the update throws.
 *
 * @param dbSessionId Identifier used to resolve the database path.
 * @param chatSessionId `chat_session.id` of the row to update.
 * @param summary Summary text to store.
 */
function saveSessionSummaryToDb(dbSessionId: string, chatSessionId: number, summary: string): void {
  const connection = new Database(getDbPath(dbSessionId))
  try {
    const update = connection.prepare('UPDATE chat_session SET summary = ? WHERE id = ?')
    update.run(summary, chatSessionId)
  } finally {
    connection.close()
  }
}
/**
 * Reads the stored summary of a session (main-process version).
 *
 * @param dbSessionId Identifier passed to `openDatabase` to select the database.
 * @param chatSessionId `chat_session.id` of the row to read.
 * @returns The stored summary, or `null` when the database cannot be opened, the
 *          row is missing, the stored value is NULL or empty, or the query throws.
 */
function getSessionSummaryFromDb(dbSessionId: string, chatSessionId: number): string | null {
  const connection = openDatabase(dbSessionId, true)
  if (!connection) return null

  // NOTE(review): the handle is not closed in this function — confirm whether
  // `openDatabase` owns the connection's lifetime.
  try {
    const row = connection.prepare('SELECT summary FROM chat_session WHERE id = ?').get(chatSessionId) as
      | { summary: string | null }
      | undefined
    // A missing row, a NULL column and an empty string all count as "no summary".
    if (!row || !row.summary) return null
    return row.summary
  } catch {
    // Query failures are swallowed without logging; the caller just sees "no summary".
    return null
  }
}
/**
 * Maps a message count to the summary length limit (in characters).
 *
 * - up to 10 messages: 50
 * - 11 to 30 messages: 80
 * - 31 to 100 messages: 120
 * - more than 100 messages: 200
 *
 * @param messageCount Number of messages the summary will cover.
 * @returns The length limit for that tier.
 */
function getSummaryLengthLimit(messageCount: number): number {
  // [inclusive upper bound, limit] pairs, checked in ascending order.
  const tiers: ReadonlyArray<readonly [number, number]> = [
    [10, 50],
    [30, 80],
    [100, 120],
  ]
  for (const [upTo, limit] of tiers) {
    if (messageCount <= upTo) return limit
  }
  return 200
}
/** Two-character replies that still carry meaning and must survive the short-message filter. */
const MEANINGFUL_SHORT_REPLIES: ReadonlySet<string> = new Set([
  '好的',
  '不是',
  '是的',
  '可以',
  '不行',
  '好吧',
  '明白',
  '知道',
  '同意',
])

/**
 * Matches text made up solely of emoji, whitespace and bracket characters.
 *
 * `\p{Emoji}` is deliberately not used: that property also covers the ASCII digits,
 * `#` and `*` (they are keycap bases), so purely numeric messages would be treated
 * as emoji. It also excludes the zero-width joiner and variation selector-16, so
 * joined or emoji-presentation sequences would slip through. Pictographs, skin-tone
 * modifiers, regional indicators, ZWJ and VS16 are listed explicitly instead, and
 * digits/`#`/`*` only count when they form a keycap (followed by U+20E3).
 */
const EMOJI_ONLY_PATTERN =
  /^(?:[\p{Extended_Pictographic}\p{Emoji_Modifier}\p{Regional_Indicator}\u200D\uFE0F\s[\]()]|[0-9#*]\uFE0F?\u20E3)+$/u

/** Placeholder texts that stand in for non-text content (image, voice, video, ...). */
const PLACEHOLDER_TEXTS: ReadonlySet<string> = new Set([
  '[图片]',
  '[语音]',
  '[视频]',
  '[文件]',
  '[表情]',
  '[动画表情]',
  '[位置]',
  '[名片]',
  '[红包]',
  '[转账]',
  '[撤回消息]',
])

/** System notices (joining or leaving a group, recalling a message). */
const SYSTEM_MESSAGE_PATTERNS: readonly RegExp[] = [
  /^.*邀请.*加入了群聊$/,
  /^.*退出了群聊$/,
  /^.*撤回了一条消息$/,
  /^你撤回了一条消息$/,
]

/**
 * Decides whether a message is meaningful enough to be included in a summary.
 *
 * A message is rejected when, after trimming, it is empty; is one or two characters
 * long and not on the allow-list of meaningful short replies; consists only of
 * emoji, whitespace and brackets; equals a media placeholder; or matches a
 * system-notice pattern.
 *
 * @param content Raw message text.
 * @returns `true` when the message should be kept.
 */
function isValidMessage(content: string): boolean {
  const trimmed = content.trim()

  if (!trimmed) return false

  // Very short replies are noise unless explicitly allow-listed.
  if (trimmed.length <= 2 && !MEANINGFUL_SHORT_REPLIES.has(trimmed)) return false

  if (EMOJI_ONLY_PATTERN.test(trimmed)) return false

  if (PLACEHOLDER_TEXTS.has(trimmed)) return false

  if (SYSTEM_MESSAGE_PATTERNS.some((pattern) => pattern.test(trimmed))) return false

  return true
}
/**
 * Pre-processes messages by dropping those without meaningful content.
 *
 * Messages whose content is null or empty, or that `isValidMessage` rejects, are
 * removed; the content of the remaining ones is trimmed.
 *
 * @param messages Raw messages as loaded from the database.
 * @returns A new array containing only the kept messages, in the original order.
 */
function preprocessMessages(
  messages: Array<{ senderName: string; content: string | null }>
): Array<{ senderName: string; content: string }> {
  const kept: Array<{ senderName: string; content: string }> = []
  for (const { senderName, content } of messages) {
    if (!content || !isValidMessage(content)) continue
    kept.push({ senderName, content: content.trim() })
  }
  return kept
}
/**
 * Renders messages as plain text, one `sender: content` line per message.
 *
 * @param messages Messages to render.
 * @returns The lines joined with `\n` (an empty string for an empty list).
 */
function formatMessages(messages: Array<{ senderName: string; content: string }>): string {
  const lines: string[] = []
  for (const { senderName, content } of messages) {
    lines.push(senderName + ': ' + content)
  }
  return lines.join('\n')
}
/**
 * Splits messages into consecutive segments of bounded size.
 *
 * Each message is costed as sender length + content length + 3 (the `": "`
 * separator plus the trailing newline of its formatted line). A new segment starts
 * when adding a message would push the running cost past `maxCharsPerSegment`.
 * A message that exceeds the budget on its own still gets a segment of its own.
 *
 * @param messages Messages in conversation order.
 * @param maxCharsPerSegment Character budget per segment.
 * @returns The segments in order; empty when `messages` is empty.
 */
function splitIntoSegments(
  messages: Array<{ senderName: string; content: string }>,
  maxCharsPerSegment: number
): Array<Array<{ senderName: string; content: string }>> {
  const result: Array<Array<{ senderName: string; content: string }>> = []
  let bucket: Array<{ senderName: string; content: string }> = []
  let used = 0

  messages.forEach((entry) => {
    const cost = entry.senderName.length + entry.content.length + 3
    const overflows = used + cost > maxCharsPerSegment
    // Never flush an empty bucket: an oversized message must still land somewhere.
    if (overflows && bucket.length > 0) {
      result.push(bucket)
      bucket = []
      used = 0
    }
    bucket.push(entry)
    used += cost
  })

  if (bucket.length > 0) {
    result.push(bucket)
  }
  return result
}
/**
 * Builds the user prompt for summarizing a whole conversation.
 *
 * @param content Formatted conversation text, appended after the instruction on a new line.
 * @param lengthLimit Maximum summary length quoted in the instruction.
 * @param locale `'zh-CN'` selects the Chinese instruction; any other value selects English.
 * @returns The instruction line followed by `content`.
 */
function buildSummaryPrompt(content: string, lengthLimit: number, locale: string): string {
  const instruction =
    locale === 'zh-CN'
      ? `请用简洁的语言(${lengthLimit}字以内)总结以下对话的主要内容或话题。只输出摘要内容,不要添加任何前缀、解释或引号。`
      : `Summarize the following conversation concisely (max ${lengthLimit} characters). Output only the summary, no prefix, explanation, or quotes.`
  return [instruction, content].join('\n')
}
/**
 * Builds the user prompt for summarizing one segment of a long conversation
 * in a single sentence.
 *
 * @param content Formatted text of the segment, appended after the instruction on a new line.
 * @param locale `'zh-CN'` selects the Chinese instruction; any other value selects English.
 * @returns The instruction line followed by `content`.
 */
function buildSubSummaryPrompt(content: string, locale: string): string {
  const instruction =
    locale === 'zh-CN'
      ? `请用一句话不超过50字概括以下对话片段的主要内容。只输出摘要内容不要添加任何前缀、解释或引号。`
      : `Summarize this conversation segment in one sentence (max 50 characters). Output only the summary, no prefix or quotes.`
  return [instruction, content].join('\n')
}
/**
 * Builds the user prompt that merges per-segment summaries into one summary.
 *
 * The sub-summaries are listed after the instruction as a numbered list
 * (`1. ...`, `2. ...`), one per line.
 *
 * @param subSummaries Segment summaries in conversation order.
 * @param lengthLimit Maximum length of the merged summary quoted in the instruction.
 * @param locale `'zh-CN'` selects the Chinese instruction; any other value selects English.
 * @returns The instruction line followed by the numbered list.
 */
function buildMergePrompt(subSummaries: string[], lengthLimit: number, locale: string): string {
  const numbered: string[] = []
  subSummaries.forEach((text, position) => {
    numbered.push(`${position + 1}. ${text}`)
  })

  const instruction =
    locale === 'zh-CN'
      ? `以下是一段对话的多个片段摘要,请将它们合并成一个完整的总结(${lengthLimit}字以内)。只输出摘要内容,不要添加任何前缀、解释或引号。`
      : `Below are summaries of different parts of a conversation. Merge them into one cohesive summary (max ${lengthLimit} characters). Output only the summary, no prefix or quotes.`
  return [instruction, numbered.join('\n')].join('\n')
}
/**
 * Generates (or returns the cached) summary of one chat session.
 *
 * Steps: return the stored summary unless `forceRegenerate` is set; load the
 * session's messages; bail out when there are too few raw or meaningful messages;
 * pick a length limit from the number of meaningful messages; summarize directly or
 * via map-reduce depending on the formatted length; strip one pair of wrapping
 * quotes; reject an empty result; truncate anything longer than 1.5x the limit;
 * persist and return the summary.
 *
 * This function does not throw: any exception is logged and reported through the
 * `error` field of the result.
 *
 * @param dbSessionId Database session id, used to open the database.
 * @param chatSessionId Id of the chat session inside that database.
 * @param locale Language setting; `'zh-CN'` selects Chinese prompts and messages,
 *               anything else English.
 * @param forceRegenerate When `true`, ignores a stored summary and regenerates.
 * @returns `{ success: true, summary }` on success, otherwise `{ success: false, error }`.
 */
export async function generateSessionSummary(
  dbSessionId: string,
  chatSessionId: number,
  locale: string = 'zh-CN',
  forceRegenerate: boolean = false
): Promise<{ success: boolean; summary?: string; error?: string }> {
  try {
    // 1. Reuse the stored summary unless regeneration is forced.
    if (!forceRegenerate) {
      const existing = getSessionSummaryFromDb(dbSessionId, chatSessionId)
      if (existing) {
        return { success: true, summary: existing }
      }
    }

    // 2. Load the session's messages.
    const sessionData = getSessionMessagesForSummary(dbSessionId, chatSessionId)
    if (!sessionData) {
      return { success: false, error: '会话不存在或数据库打开失败' }
    }

    // 3. Too few raw messages: nothing worth summarizing.
    if (sessionData.messageCount < MIN_MESSAGE_COUNT) {
      return {
        success: false,
        error:
          locale === 'zh-CN'
            ? `消息数量少于${MIN_MESSAGE_COUNT}条,无需生成摘要`
            : `Message count less than ${MIN_MESSAGE_COUNT}, no need to generate summary`,
      }
    }

    // 4. Drop meaningless messages, then re-check the threshold.
    const validMessages = preprocessMessages(sessionData.messages)
    if (validMessages.length < MIN_MESSAGE_COUNT) {
      return {
        success: false,
        error:
          locale === 'zh-CN'
            ? `有效消息数量少于${MIN_MESSAGE_COUNT}条,无需生成摘要`
            : `Valid message count less than ${MIN_MESSAGE_COUNT}, no need to generate summary`,
      }
    }

    // 5. Length limit scales with the number of meaningful messages.
    const lengthLimit = getSummaryLengthLimit(validMessages.length)

    // 6. Format the conversation as plain text.
    const content = formatMessages(validMessages)
    aiLogger.info(
      'Summary',
      `生成会话摘要: sessionId=${chatSessionId}, 原始消息=${sessionData.messageCount}, 有效消息=${validMessages.length}, 内容长度=${content.length}`
    )

    let summary: string

    // 7. Short conversations go through one call, long ones through map-reduce.
    if (content.length <= SEGMENT_THRESHOLD) {
      summary = await generateDirectSummary(content, lengthLimit, locale)
    } else {
      summary = await generateMapReduceSummary(validMessages, lengthLimit, locale)
    }

    // 8. Post-process: strip one pair of wrapping quotes.
    if ((summary.startsWith('"') && summary.endsWith('"')) || (summary.startsWith('「') && summary.endsWith('」'))) {
      summary = summary.slice(1, -1).trim()
    }

    // An empty summary must not be stored or reported as success: the cache lookup
    // treats an empty stored value as "no summary", so the caller would see a
    // successful call with nothing to show and the next call would regenerate.
    if (!summary) {
      aiLogger.error('Summary', `摘要生成失败: 模型返回内容为空, sessionId=${chatSessionId}`)
      return {
        success: false,
        error: locale === 'zh-CN' ? '模型返回了空摘要' : 'The model returned an empty summary',
      }
    }

    // Truncate when the summary exceeds 1.5x the requested limit.
    const hardLimit = Math.floor(lengthLimit * 1.5)
    if (summary.length > hardLimit) {
      summary = summary.slice(0, hardLimit - 3) + '...'
    }

    // 9. Persist the summary.
    saveSessionSummaryToDb(dbSessionId, chatSessionId, summary)
    aiLogger.info('Summary', `摘要生成成功: "${summary.slice(0, 50)}..."`)
    return { success: true, summary }
  } catch (error) {
    aiLogger.error('Summary', '摘要生成失败', error)
    return {
      success: false,
      error: error instanceof Error ? error.message : String(error),
    }
  }
}
/**
 * Summarizes a conversation with a single LLM call (used for short sessions).
 *
 * @param content Formatted conversation text.
 * @param lengthLimit Maximum summary length quoted in the prompt.
 * @param locale Language passed to `buildSummaryPrompt`.
 * @returns The model's reply with surrounding whitespace removed.
 */
async function generateDirectSummary(content: string, lengthLimit: number, locale: string): Promise<string> {
  const userPrompt = buildSummaryPrompt(content, lengthLimit, locale)
  const reply = await chat(
    [
      { role: 'system', content: '你是一个对话摘要专家,擅长用简洁的语言总结对话内容。' },
      { role: 'user', content: userPrompt },
    ],
    { temperature: 0.3, maxTokens: 300 }
  )
  const text = reply.content
  return text.trim()
}
/**
 * Summarizes a long conversation with a map-reduce strategy.
 *
 * Map: the messages are split into segments of at most `MAX_CONTENT_PER_CALL`
 * characters and each segment is summarized by its own LLM call, one after another.
 * Reduce: when there is exactly one sub-summary it is returned as is; otherwise a
 * final LLM call merges the sub-summaries.
 *
 * @param messages Pre-processed messages in conversation order.
 * @param lengthLimit Maximum length of the merged summary quoted in the merge prompt.
 * @param locale Language passed to the prompt builders.
 * @returns The trimmed summary text.
 */
async function generateMapReduceSummary(
  messages: Array<{ senderName: string; content: string }>,
  lengthLimit: number,
  locale: string
): Promise<string> {
  // 1. Map: one sub-summary per segment.
  const segments = splitIntoSegments(messages, MAX_CONTENT_PER_CALL)
  aiLogger.info('Summary', `长会话分段处理: ${segments.length} 个段落`)

  const partials: string[] = []
  for (const segment of segments) {
    const segmentPrompt = buildSubSummaryPrompt(formatMessages(segment), locale)
    const reply = await chat(
      [
        { role: 'system', content: '你是一个对话摘要专家,擅长用简洁的语言总结对话内容。' },
        { role: 'user', content: segmentPrompt },
      ],
      { temperature: 0.3, maxTokens: 100 }
    )
    partials.push(reply.content.trim())
  }

  // 2. Reduce: a single sub-summary needs no merge call.
  if (partials.length === 1) {
    return partials[0]
  }

  const mergePrompt = buildMergePrompt(partials, lengthLimit, locale)
  const merged = await chat(
    [
      { role: 'system', content: '你是一个对话摘要专家,擅长将多个摘要合并成一个连贯的总结。' },
      { role: 'user', content: mergePrompt },
    ],
    { temperature: 0.3, maxTokens: 300 }
  )
  return merged.content.trim()
}
/**
 * Generates summaries for several chat sessions, one after another.
 *
 * Each session goes through `generateSessionSummary` without forced regeneration.
 * A failed result whose error text contains `少于` or `less than` is counted as
 * skipped (too few messages); every other failure is counted as failed.
 *
 * @param dbSessionId Database session id.
 * @param chatSessionIds Ids of the chat sessions to summarize, processed in order.
 * @param locale Language setting forwarded to `generateSessionSummary`.
 * @param onProgress Optional callback invoked after each session with
 *                   (number processed so far, total number).
 * @returns Counts of successful, failed and skipped sessions.
 */
export async function generateSessionSummaries(
  dbSessionId: string,
  chatSessionIds: number[],
  locale: string = 'zh-CN',
  onProgress?: (current: number, total: number) => void
): Promise<{ success: number; failed: number; skipped: number }> {
  const tally = { success: 0, failed: 0, skipped: 0 }
  let processed = 0

  for (const sessionId of chatSessionIds) {
    const outcome = await generateSessionSummary(dbSessionId, sessionId, locale, false)

    if (outcome.success) {
      tally.success += 1
    } else {
      // "Too few messages" is recognised by its wording in either language.
      const reason = outcome.error
      const tooFewMessages = reason?.includes('少于') || reason?.includes('less than')
      if (tooFewMessages) {
        tally.skipped += 1
      } else {
        tally.failed += 1
      }
    }

    processed += 1
    if (onProgress) {
      onProgress(processed, chatSessionIds.length)
    }
  }

  return { success: tally.success, failed: tally.failed, skipped: tally.skipped }
}