From 206beb02e634658495f7b049c62a35b60c0680ea Mon Sep 17 00:00:00 2001 From: digua Date: Mon, 2 Feb 2026 23:51:24 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84session=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- electron/main/worker/query/index.ts | 8 + electron/main/worker/query/session.ts | 1001 ----------------- electron/main/worker/query/session/aiTools.ts | 205 ++++ electron/main/worker/query/session/core.ts | 39 + electron/main/worker/query/session/export.ts | 381 +++++++ electron/main/worker/query/session/filter.ts | 447 ++++++++ electron/main/worker/query/session/index.ts | 41 + .../main/worker/query/session/sessionIndex.ts | 349 ++++++ electron/main/worker/query/session/types.ts | 163 +++ 9 files changed, 1633 insertions(+), 1001 deletions(-) delete mode 100644 electron/main/worker/query/session.ts create mode 100644 electron/main/worker/query/session/aiTools.ts create mode 100644 electron/main/worker/query/session/core.ts create mode 100644 electron/main/worker/query/session/export.ts create mode 100644 electron/main/worker/query/session/filter.ts create mode 100644 electron/main/worker/query/session/index.ts create mode 100644 electron/main/worker/query/session/sessionIndex.ts create mode 100644 electron/main/worker/query/session/types.ts diff --git a/electron/main/worker/query/index.ts b/electron/main/worker/query/index.ts index 812d018..f586d72 100644 --- a/electron/main/worker/query/index.ts +++ b/electron/main/worker/query/index.ts @@ -68,12 +68,16 @@ export { getSessionStats, updateSessionGapThreshold, getSessions, + saveSessionSummary, + getSessionSummary, searchSessions, getSessionMessages, DEFAULT_SESSION_GAP_THRESHOLD, // 自定义筛选 filterMessagesWithContext, getMultipleSessionsMessages, + // 导出功能 + exportFilterResultToFile, } from './session' export type { ChatSessionItem, @@ -82,6 +86,10 @@ export type { ContextBlock, FilterResult, FilterMessage, + PaginationInfo, + FilterResultWithPagination, + ExportFilterParams, + ExportProgress, } from './session' // NLP 查询 diff --git a/electron/main/worker/query/session.ts b/electron/main/worker/query/session.ts deleted file mode 100644 index bc11010..0000000 --- a/electron/main/worker/query/session.ts +++ /dev/null @@ -1,1001 +0,0 @@ -/** - * 会话索引模块 - * 提供基于时间间隔的会话切分算法 - */ - -import Database from 'better-sqlite3' -import { getDbPath, closeDatabase } from '../core' - -/** 默认会话切分阈值:30分钟(秒) */ -export const DEFAULT_SESSION_GAP_THRESHOLD = 1800 - -/** - * 打开数据库(可写模式,不使用缓存) - * 会话索引需要写入数据 - */ -function openWritableDatabase(sessionId: string): Database.Database | null { - const dbPath = getDbPath(sessionId) - try { - const db = new Database(dbPath) - db.pragma('journal_mode = WAL') - return db - } catch { - return null - } -} - -/** - * 打开数据库(只读模式,不使用缓存) - */ -function openReadonlyDatabase(sessionId: string): Database.Database | null { - const dbPath = getDbPath(sessionId) - try { - const db = new Database(dbPath, { readonly: true }) - db.pragma('journal_mode = WAL') - return db - } catch { - return null - } -} - -/** - * 生成会话索引 - * 使用 Gap-based 算法,根据消息时间间隔自动切分会话 - * - * @param sessionId 数据库会话ID - * @param gapThreshold 时间间隔阈值(秒),默认 1800(30分钟) - * @param onProgress 进度回调 - * @returns 生成的会话数量 - */ -export function generateSessions( - sessionId: string, - gapThreshold: number = DEFAULT_SESSION_GAP_THRESHOLD, - onProgress?: (current: number, total: number) => void -): number { - // 先关闭缓存的只读连接 - closeDatabase(sessionId) - - const db = openWritableDatabase(sessionId) - if (!db) { - throw new Error(`无法打开数据库: ${sessionId}`) - } - - try { - // 获取消息总数 - const countResult = db.prepare('SELECT COUNT(*) as count FROM message').get() as { count: number } - const totalMessages = countResult.count - - if (totalMessages === 0) { - return 0 - } - - // 清空已有的会话数据 - clearSessionsInternal(db) - - // 使用窗口函数计算会话边界 - // 步骤1:为每条消息计算与前一条的时间差,标记新会话起点 - const sessionMarkSQL = ` - WITH message_ordered AS ( - SELECT - id, - ts, - LAG(ts) OVER (ORDER BY ts, id) AS prev_ts - FROM message - ), - session_marks AS ( - SELECT - id, - ts, - CASE - WHEN prev_ts IS NULL OR (ts - prev_ts) > ? THEN 1 - ELSE 0 - END AS is_new_session - FROM message_ordered - ), - session_ids AS ( - SELECT - id, - ts, - SUM(is_new_session) OVER (ORDER BY ts, id) AS session_num - FROM session_marks - ) - SELECT id, ts, session_num FROM session_ids - ` - - const messages = db.prepare(sessionMarkSQL).all(gapThreshold) as Array<{ - id: number - ts: number - session_num: number - }> - - if (messages.length === 0) { - return 0 - } - - // 步骤2:计算每个会话的统计信息 - const sessionMap = new Map() - - for (const msg of messages) { - const session = sessionMap.get(msg.session_num) - if (!session) { - sessionMap.set(msg.session_num, { - startTs: msg.ts, - endTs: msg.ts, - messageIds: [msg.id], - }) - } else { - session.endTs = msg.ts - session.messageIds.push(msg.id) - } - } - - // 步骤3:批量写入 chat_session 和 message_context 表 - const insertSession = db.prepare(` - INSERT INTO chat_session (start_ts, end_ts, message_count, is_manual, summary) - VALUES (?, ?, ?, 0, NULL) - `) - - const insertContext = db.prepare(` - INSERT INTO message_context (message_id, session_id, topic_id) - VALUES (?, ?, NULL) - `) - - // 开始事务 - const transaction = db.transaction(() => { - let processedCount = 0 - const totalSessions = sessionMap.size - - for (const [, sessionData] of sessionMap) { - // 插入会话记录 - const result = insertSession.run(sessionData.startTs, sessionData.endTs, sessionData.messageIds.length) - const newSessionId = result.lastInsertRowid as number - - // 批量插入消息上下文 - for (const messageId of sessionData.messageIds) { - insertContext.run(messageId, newSessionId) - } - - processedCount++ - if (onProgress && processedCount % 100 === 0) { - onProgress(processedCount, totalSessions) - } - } - - return totalSessions - }) - - const sessionCount = transaction() - - // 最终进度回调 - if (onProgress) { - onProgress(sessionCount, sessionCount) - } - - return sessionCount - } finally { - db.close() - } -} - -/** - * 清空会话索引数据 - * @param sessionId 数据库会话ID - */ -export function clearSessions(sessionId: string): void { - // 先关闭缓存的只读连接 - closeDatabase(sessionId) - - const db = openWritableDatabase(sessionId) - if (!db) { - throw new Error(`无法打开数据库: ${sessionId}`) - } - - try { - clearSessionsInternal(db) - } finally { - db.close() - } -} - -/** - * 内部清空会话数据函数 - */ -function clearSessionsInternal(db: Database.Database): void { - db.exec('DELETE FROM message_context') - db.exec('DELETE FROM chat_session') -} - -/** - * 检查是否已生成会话索引 - * @param sessionId 数据库会话ID - * @returns 是否有会话索引 - */ -export function hasSessionIndex(sessionId: string): boolean { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return false - } - - try { - // 检查 chat_session 表是否存在且有数据 - const result = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number } - return result.count > 0 - } catch { - // 表可能不存在 - return false - } finally { - db.close() - } -} - -/** - * 获取会话索引统计信息 - * @param sessionId 数据库会话ID - */ -export function getSessionStats(sessionId: string): { - sessionCount: number - hasIndex: boolean - gapThreshold: number -} { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return { sessionCount: 0, hasIndex: false, gapThreshold: DEFAULT_SESSION_GAP_THRESHOLD } - } - - try { - // 获取会话数量 - let sessionCount = 0 - try { - const countResult = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number } - sessionCount = countResult.count - } catch { - // 表可能不存在 - } - - // 获取配置的阈值 - let gapThreshold = DEFAULT_SESSION_GAP_THRESHOLD - try { - const metaResult = db.prepare('SELECT session_gap_threshold FROM meta LIMIT 1').get() as - | { - session_gap_threshold: number | null - } - | undefined - if (metaResult?.session_gap_threshold) { - gapThreshold = metaResult.session_gap_threshold - } - } catch { - // 字段可能不存在 - } - - return { - sessionCount, - hasIndex: sessionCount > 0, - gapThreshold, - } - } finally { - db.close() - } -} - -/** - * 更新单个聊天的会话切分阈值 - * @param sessionId 数据库会话ID - * @param gapThreshold 阈值(秒),null 表示使用全局配置 - */ -export function updateSessionGapThreshold(sessionId: string, gapThreshold: number | null): void { - // 先关闭缓存的只读连接 - closeDatabase(sessionId) - - const db = openWritableDatabase(sessionId) - if (!db) { - throw new Error(`无法打开数据库: ${sessionId}`) - } - - try { - db.prepare('UPDATE meta SET session_gap_threshold = ?').run(gapThreshold) - } finally { - db.close() - } -} - -/** - * 会话列表项类型 - */ -export interface ChatSessionItem { - id: number - startTs: number - endTs: number - messageCount: number - firstMessageId: number - /** 会话摘要(如果有) */ - summary?: string | null -} - -/** - * 获取会话列表(用于时间线导航) - * @param sessionId 数据库会话ID - * @returns 会话列表,按时间排序 - */ -export function getSessions(sessionId: string): ChatSessionItem[] { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return [] - } - - try { - // 查询会话列表,同时获取每个会话的首条消息 ID 和摘要 - const sql = ` - SELECT - cs.id, - cs.start_ts as startTs, - cs.end_ts as endTs, - cs.message_count as messageCount, - cs.summary, - (SELECT mc.message_id FROM message_context mc WHERE mc.session_id = cs.id ORDER BY mc.message_id LIMIT 1) as firstMessageId - FROM chat_session cs - ORDER BY cs.start_ts ASC - ` - const sessions = db.prepare(sql).all() as ChatSessionItem[] - return sessions - } catch { - return [] - } finally { - db.close() - } -} - -// ==================== 会话摘要相关函数 ==================== - -/** - * 保存会话摘要 - * @param sessionId 数据库会话ID - * @param chatSessionId 会话索引中的会话ID - * @param summary 摘要内容 - */ -export function saveSessionSummary(sessionId: string, chatSessionId: number, summary: string): void { - // 先关闭缓存的只读连接 - closeDatabase(sessionId) - - const db = openWritableDatabase(sessionId) - if (!db) { - throw new Error(`无法打开数据库: ${sessionId}`) - } - - try { - db.prepare('UPDATE chat_session SET summary = ? WHERE id = ?').run(summary, chatSessionId) - } finally { - db.close() - } -} - -/** - * 获取会话摘要 - * @param sessionId 数据库会话ID - * @param chatSessionId 会话索引中的会话ID - * @returns 摘要内容 - */ -export function getSessionSummary(sessionId: string, chatSessionId: number): string | null { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return null - } - - try { - const result = db.prepare('SELECT summary FROM chat_session WHERE id = ?').get(chatSessionId) as - | { summary: string | null } - | undefined - return result?.summary || null - } catch { - return null - } finally { - db.close() - } -} - -// ==================== AI 工具专用查询函数 ==================== - -/** - * 会话搜索结果项类型(用于 AI 工具) - */ -export interface SessionSearchResultItem { - id: number - startTs: number - endTs: number - messageCount: number - /** 是否为完整会话(消息数 <= 预览条数) */ - isComplete: boolean - /** 预览消息列表 */ - previewMessages: Array<{ - id: number - senderName: string - content: string | null - timestamp: number - }> -} - -/** - * 搜索会话(用于 AI 工具) - * 支持按关键词和时间范围筛选会话 - * - * @param sessionId 数据库会话ID - * @param keywords 关键词列表(可选,OR 逻辑匹配) - * @param timeFilter 时间过滤器(可选) - * @param limit 返回数量限制,默认 20 - * @param previewCount 预览消息数量,默认 5 - * @returns 匹配的会话列表 - */ -export function searchSessions( - sessionId: string, - keywords?: string[], - timeFilter?: { startTs: number; endTs: number }, - limit: number = 20, - previewCount: number = 5 -): SessionSearchResultItem[] { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return [] - } - - try { - // 1. 构建会话查询 SQL - let sessionSql = ` - SELECT - cs.id, - cs.start_ts as startTs, - cs.end_ts as endTs, - cs.message_count as messageCount - FROM chat_session cs - WHERE 1=1 - ` - const params: unknown[] = [] - - // 时间范围过滤 - if (timeFilter) { - sessionSql += ` AND cs.start_ts >= ? AND cs.end_ts <= ?` - params.push(timeFilter.startTs, timeFilter.endTs) - } - - // 关键词过滤:只返回包含关键词的会话 - if (keywords && keywords.length > 0) { - const keywordConditions = keywords.map(() => `m.content LIKE ?`).join(' OR ') - sessionSql += ` - AND cs.id IN ( - SELECT DISTINCT mc.session_id - FROM message_context mc - JOIN message m ON m.id = mc.message_id - WHERE (${keywordConditions}) - ) - ` - for (const kw of keywords) { - params.push(`%${kw}%`) - } - } - - sessionSql += ` ORDER BY cs.start_ts DESC LIMIT ?` - params.push(limit) - - const sessions = db.prepare(sessionSql).all(...params) as Array<{ - id: number - startTs: number - endTs: number - messageCount: number - }> - - // 2. 为每个会话获取预览消息 - const previewSql = ` - SELECT - m.id, - COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, - m.content, - m.ts as timestamp - FROM message_context mc - JOIN message m ON m.id = mc.message_id - JOIN member mb ON mb.id = m.sender_id - WHERE mc.session_id = ? - ORDER BY m.ts ASC - LIMIT ? - ` - - const results: SessionSearchResultItem[] = [] - for (const session of sessions) { - const previewMessages = db.prepare(previewSql).all(session.id, previewCount) as Array<{ - id: number - senderName: string - content: string | null - timestamp: number - }> - - results.push({ - id: session.id, - startTs: session.startTs, - endTs: session.endTs, - messageCount: session.messageCount, - isComplete: session.messageCount <= previewCount, - previewMessages, - }) - } - - return results - } catch (error) { - console.error('searchSessions error:', error) - return [] - } finally { - db.close() - } -} - -/** - * 会话消息结果类型(用于 AI 工具) - */ -export interface SessionMessagesResult { - sessionId: number - startTs: number - endTs: number - messageCount: number - returnedCount: number - /** 参与者列表 */ - participants: string[] - /** 消息列表 */ - messages: Array<{ - id: number - senderName: string - content: string | null - timestamp: number - }> -} - -/** - * 获取会话的完整消息(用于 AI 工具) - * - * @param sessionId 数据库会话ID - * @param chatSessionId 会话索引中的会话ID - * @param limit 返回数量限制,默认 500 - * @returns 会话的完整消息 - */ -export function getSessionMessages( - sessionId: string, - chatSessionId: number, - limit: number = 500 -): SessionMessagesResult | null { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return null - } - - try { - // 1. 获取会话基本信息 - const sessionSql = ` - SELECT - id, - start_ts as startTs, - end_ts as endTs, - message_count as messageCount - FROM chat_session - WHERE id = ? - ` - const session = db.prepare(sessionSql).get(chatSessionId) as - | { - id: number - startTs: number - endTs: number - messageCount: number - } - | undefined - - if (!session) { - db.close() - return null - } - - // 2. 获取会话消息 - const messagesSql = ` - SELECT - m.id, - COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, - m.content, - m.ts as timestamp - FROM message_context mc - JOIN message m ON m.id = mc.message_id - JOIN member mb ON mb.id = m.sender_id - WHERE mc.session_id = ? - ORDER BY m.ts ASC - LIMIT ? - ` - const messages = db.prepare(messagesSql).all(chatSessionId, limit) as Array<{ - id: number - senderName: string - content: string | null - timestamp: number - }> - - // 3. 统计参与者 - const participantsSet = new Set() - for (const msg of messages) { - participantsSet.add(msg.senderName) - } - - return { - sessionId: session.id, - startTs: session.startTs, - endTs: session.endTs, - messageCount: session.messageCount, - returnedCount: messages.length, - participants: Array.from(participantsSet), - messages, - } - } catch (error) { - console.error('getSessionMessages error:', error) - return null - } finally { - db.close() - } -} - -// ==================== 自定义筛选专用函数 ==================== - -/** - * 自定义筛选消息类型(完整信息,兼容 MessageList 组件) - */ -export interface FilterMessage { - id: number - senderName: string - senderPlatformId: string - senderAliases: string[] - senderAvatar: string | null - content: string - timestamp: number - type: number - replyToMessageId: string | null - replyToContent: string | null - replyToSenderName: string | null - /** 是否为命中的消息(关键词匹配) */ - isHit: boolean -} - -/** - * 上下文块类型(用于自定义筛选) - */ -export interface ContextBlock { - /** 块的时间范围 */ - startTs: number - endTs: number - /** 消息列表 */ - messages: FilterMessage[] - /** 命中的消息数量 */ - hitCount: number -} - -/** - * 筛选结果类型 - */ -export interface FilterResult { - /** 上下文块列表 */ - blocks: ContextBlock[] - /** 统计信息 */ - stats: { - /** 总消息数 */ - totalMessages: number - /** 命中的消息数 */ - hitMessages: number - /** 总字符数 */ - totalChars: number - } -} - -/** - * 按条件筛选消息并扩充上下文 - * - * 核心算法: - * 1. 先搜索匹配条件的消息,获取消息ID列表 - * 2. 为每个命中消息向前后各扩展 contextSize 条消息 - * 3. 合并重叠/相邻的消息范围 - * 4. 按合并后的范围分块返回消息 - * - * @param sessionId 数据库会话ID - * @param keywords 关键词列表(可选,OR 逻辑) - * @param timeFilter 时间过滤器(可选) - * @param senderIds 发送者ID列表(可选) - * @param contextSize 上下文扩展数量(前后各多少条) - * @returns 筛选结果 - */ -export function filterMessagesWithContext( - sessionId: string, - keywords?: string[], - timeFilter?: { startTs: number; endTs: number }, - senderIds?: number[], - contextSize: number = 10 -): FilterResult { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } - - try { - // 1. 构建基础消息查询(完整信息),按时间排序 - // 使用 LEFT JOIN 获取回复消息的信息 - const allMessagesSql = ` - SELECT - msg.id, - msg.ts, - COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, - m.platform_id as senderPlatformId, - COALESCE(m.aliases, '[]') as senderAliasesJson, - m.avatar as senderAvatar, - msg.content, - msg.type, - msg.reply_to_message_id as replyToMessageId, - reply_msg.content as replyToContent, - COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName, - msg.sender_id as senderId - FROM message msg - JOIN member m ON msg.sender_id = m.id - LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id - LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id - ${timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''} - ORDER BY msg.ts ASC, msg.id ASC - ` - - const params: unknown[] = [] - if (timeFilter) { - params.push(timeFilter.startTs, timeFilter.endTs) - } - - const allMessages = db.prepare(allMessagesSql).all(...params) as Array<{ - id: number - ts: number - senderName: string - senderPlatformId: string - senderAliasesJson: string - senderAvatar: string | null - content: string | null - type: number - replyToMessageId: string | null - replyToContent: string | null - replyToSenderName: string | null - senderId: number - }> - - if (allMessages.length === 0) { - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } - - // 2. 标记命中的消息 - const hitIndexes: number[] = [] - for (let i = 0; i < allMessages.length; i++) { - const msg = allMessages[i] - let isHit = true - - // 关键词匹配(OR 逻辑) - if (keywords && keywords.length > 0) { - const content = (msg.content || '').toLowerCase() - isHit = keywords.some((kw) => content.includes(kw.toLowerCase())) - } - - // 发送者匹配 - if (isHit && senderIds && senderIds.length > 0) { - isHit = senderIds.includes(msg.senderId) - } - - if (isHit) { - hitIndexes.push(i) - } - } - - if (hitIndexes.length === 0) { - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } - - // 3. 扩展上下文并合并重叠范围 - const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = [] - - for (const hitIndex of hitIndexes) { - const start = Math.max(0, hitIndex - contextSize) - const end = Math.min(allMessages.length - 1, hitIndex + contextSize) - - // 检查是否能与前一个范围合并 - if (ranges.length > 0) { - const lastRange = ranges[ranges.length - 1] - // 如果当前范围的 start <= 上一个范围的 end + 1,则合并 - if (start <= lastRange.end + 1) { - lastRange.end = Math.max(lastRange.end, end) - lastRange.hitIndexes.push(hitIndex) - continue - } - } - - ranges.push({ start, end, hitIndexes: [hitIndex] }) - } - - // 4. 按范围构建上下文块 - const blocks: ContextBlock[] = [] - let totalMessages = 0 - let totalChars = 0 - - for (const range of ranges) { - const hitIndexSet = new Set(range.hitIndexes) - const blockMessages: FilterMessage[] = [] - - for (let i = range.start; i <= range.end; i++) { - const msg = allMessages[i] - const isHit = hitIndexSet.has(i) - - // 解析别名 JSON - let senderAliases: string[] = [] - try { - senderAliases = JSON.parse(msg.senderAliasesJson || '[]') - } catch { - senderAliases = [] - } - - blockMessages.push({ - id: msg.id, - senderName: msg.senderName, - senderPlatformId: msg.senderPlatformId, - senderAliases, - senderAvatar: msg.senderAvatar, - content: msg.content || '', - timestamp: msg.ts, - type: msg.type, - replyToMessageId: msg.replyToMessageId, - replyToContent: msg.replyToContent, - replyToSenderName: msg.replyToSenderName, - isHit, - }) - totalChars += (msg.content || '').length - } - - blocks.push({ - startTs: allMessages[range.start].ts, - endTs: allMessages[range.end].ts, - messages: blockMessages, - hitCount: range.hitIndexes.length, - }) - - totalMessages += blockMessages.length - } - - return { - blocks, - stats: { - totalMessages, - hitMessages: hitIndexes.length, - totalChars, - }, - } - } catch (error) { - console.error('filterMessagesWithContext error:', error) - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } finally { - db.close() - } -} - -/** - * 获取多个会话的完整消息(用于会话筛选模式) - * - * @param sessionId 数据库会话ID - * @param chatSessionIds 要获取的会话ID列表 - * @returns 合并后的上下文块和统计 - */ -export function getMultipleSessionsMessages(sessionId: string, chatSessionIds: number[]): FilterResult { - const db = openReadonlyDatabase(sessionId) - if (!db) { - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } - - try { - if (chatSessionIds.length === 0) { - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } - - const blocks: ContextBlock[] = [] - let totalMessages = 0 - let totalChars = 0 - - // 先获取会话信息,按时间排序 - const sessionsSql = ` - SELECT id, start_ts as startTs, end_ts as endTs, message_count as messageCount - FROM chat_session - WHERE id IN (${chatSessionIds.map(() => '?').join(',')}) - ORDER BY start_ts ASC - ` - const sessions = db.prepare(sessionsSql).all(...chatSessionIds) as Array<{ - id: number - startTs: number - endTs: number - messageCount: number - }> - - // 为每个会话获取消息(完整信息) - // 使用 LEFT JOIN 获取回复消息的信息 - const messagesSql = ` - SELECT - msg.id, - COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, - m.platform_id as senderPlatformId, - COALESCE(m.aliases, '[]') as senderAliasesJson, - m.avatar as senderAvatar, - msg.content, - msg.type, - msg.reply_to_message_id as replyToMessageId, - reply_msg.content as replyToContent, - COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName, - msg.ts as timestamp - FROM message_context mc - JOIN message msg ON msg.id = mc.message_id - JOIN member m ON msg.sender_id = m.id - LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id - LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id - WHERE mc.session_id = ? - ORDER BY msg.ts ASC - ` - - for (const session of sessions) { - const messages = db.prepare(messagesSql).all(session.id) as Array<{ - id: number - senderName: string - senderPlatformId: string - senderAliasesJson: string - senderAvatar: string | null - content: string | null - type: number - replyToMessageId: string | null - replyToContent: string | null - replyToSenderName: string | null - timestamp: number - }> - - const blockMessages: FilterMessage[] = messages.map((msg) => { - // 解析别名 JSON - let senderAliases: string[] = [] - try { - senderAliases = JSON.parse(msg.senderAliasesJson || '[]') - } catch { - senderAliases = [] - } - - return { - id: msg.id, - senderName: msg.senderName, - senderPlatformId: msg.senderPlatformId, - senderAliases, - senderAvatar: msg.senderAvatar, - content: msg.content || '', - timestamp: msg.timestamp, - type: msg.type, - replyToMessageId: msg.replyToMessageId, - replyToContent: msg.replyToContent, - replyToSenderName: msg.replyToSenderName, - isHit: false, // 会话模式下没有命中高亮 - } - }) - - for (const msg of messages) { - totalChars += (msg.content || '').length - } - - blocks.push({ - startTs: session.startTs, - endTs: session.endTs, - messages: blockMessages, - hitCount: 0, - }) - - totalMessages += messages.length - } - - return { - blocks, - stats: { - totalMessages, - hitMessages: 0, // 会话模式没有命中概念 - totalChars, - }, - } - } catch (error) { - console.error('getMultipleSessionsMessages error:', error) - return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } } - } finally { - db.close() - } -} diff --git a/electron/main/worker/query/session/aiTools.ts b/electron/main/worker/query/session/aiTools.ts new file mode 100644 index 0000000..af1a6b1 --- /dev/null +++ b/electron/main/worker/query/session/aiTools.ts @@ -0,0 +1,205 @@ +/** + * AI 工具专用查询模块 + * 提供搜索会话和获取会话消息等功能,供 AI 工具使用 + */ + +import { openReadonlyDatabase } from './core' +import type { SessionSearchResultItem, SessionMessagesResult } from './types' + +/** + * 搜索会话(用于 AI 工具) + * 支持按关键词和时间范围筛选会话 + * + * @param sessionId 数据库会话ID + * @param keywords 关键词列表(可选,OR 逻辑匹配) + * @param timeFilter 时间过滤器(可选) + * @param limit 返回数量限制,默认 20 + * @param previewCount 预览消息数量,默认 5 + * @returns 匹配的会话列表 + */ +export function searchSessions( + sessionId: string, + keywords?: string[], + timeFilter?: { startTs: number; endTs: number }, + limit: number = 20, + previewCount: number = 5 +): SessionSearchResultItem[] { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return [] + } + + try { + // 1. 构建会话查询 SQL + let sessionSql = ` + SELECT + cs.id, + cs.start_ts as startTs, + cs.end_ts as endTs, + cs.message_count as messageCount + FROM chat_session cs + WHERE 1=1 + ` + const params: unknown[] = [] + + // 时间范围过滤 + if (timeFilter) { + sessionSql += ` AND cs.start_ts >= ? AND cs.end_ts <= ?` + params.push(timeFilter.startTs, timeFilter.endTs) + } + + // 关键词过滤:只返回包含关键词的会话 + if (keywords && keywords.length > 0) { + const keywordConditions = keywords.map(() => `m.content LIKE ?`).join(' OR ') + sessionSql += ` + AND cs.id IN ( + SELECT DISTINCT mc.session_id + FROM message_context mc + JOIN message m ON m.id = mc.message_id + WHERE (${keywordConditions}) + ) + ` + for (const kw of keywords) { + params.push(`%${kw}%`) + } + } + + sessionSql += ` ORDER BY cs.start_ts DESC LIMIT ?` + params.push(limit) + + const sessions = db.prepare(sessionSql).all(...params) as Array<{ + id: number + startTs: number + endTs: number + messageCount: number + }> + + // 2. 为每个会话获取预览消息 + const previewSql = ` + SELECT + m.id, + COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, + m.content, + m.ts as timestamp + FROM message_context mc + JOIN message m ON m.id = mc.message_id + JOIN member mb ON mb.id = m.sender_id + WHERE mc.session_id = ? + ORDER BY m.ts ASC + LIMIT ? + ` + + const results: SessionSearchResultItem[] = [] + for (const session of sessions) { + const previewMessages = db.prepare(previewSql).all(session.id, previewCount) as Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> + + results.push({ + id: session.id, + startTs: session.startTs, + endTs: session.endTs, + messageCount: session.messageCount, + isComplete: session.messageCount <= previewCount, + previewMessages, + }) + } + + return results + } catch (error) { + console.error('searchSessions error:', error) + return [] + } finally { + db.close() + } +} + +/** + * 获取会话的完整消息(用于 AI 工具) + * + * @param sessionId 数据库会话ID + * @param chatSessionId 会话索引中的会话ID + * @param limit 返回数量限制,默认 500 + * @returns 会话的完整消息 + */ +export function getSessionMessages( + sessionId: string, + chatSessionId: number, + limit: number = 500 +): SessionMessagesResult | null { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return null + } + + try { + // 1. 获取会话基本信息 + const sessionSql = ` + SELECT + id, + start_ts as startTs, + end_ts as endTs, + message_count as messageCount + FROM chat_session + WHERE id = ? + ` + const session = db.prepare(sessionSql).get(chatSessionId) as + | { + id: number + startTs: number + endTs: number + messageCount: number + } + | undefined + + if (!session) { + db.close() + return null + } + + // 2. 获取会话消息 + const messagesSql = ` + SELECT + m.id, + COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, + m.content, + m.ts as timestamp + FROM message_context mc + JOIN message m ON m.id = mc.message_id + JOIN member mb ON mb.id = m.sender_id + WHERE mc.session_id = ? + ORDER BY m.ts ASC + LIMIT ? + ` + const messages = db.prepare(messagesSql).all(chatSessionId, limit) as Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> + + // 3. 统计参与者 + const participantsSet = new Set() + for (const msg of messages) { + participantsSet.add(msg.senderName) + } + + return { + sessionId: session.id, + startTs: session.startTs, + endTs: session.endTs, + messageCount: session.messageCount, + returnedCount: messages.length, + participants: Array.from(participantsSet), + messages, + } + } catch (error) { + console.error('getSessionMessages error:', error) + return null + } finally { + db.close() + } +} diff --git a/electron/main/worker/query/session/core.ts b/electron/main/worker/query/session/core.ts new file mode 100644 index 0000000..53cf170 --- /dev/null +++ b/electron/main/worker/query/session/core.ts @@ -0,0 +1,39 @@ +/** + * 会话模块核心工具函数 + * 提供数据库连接等共享功能 + */ + +import Database from 'better-sqlite3' +import { getDbPath, closeDatabase } from '../../core' + +// 重新导出 closeDatabase 供其他模块使用 +export { closeDatabase } + +/** + * 打开数据库(可写模式,不使用缓存) + * 会话索引需要写入数据 + */ +export function openWritableDatabase(sessionId: string): Database.Database | null { + const dbPath = getDbPath(sessionId) + try { + const db = new Database(dbPath) + db.pragma('journal_mode = WAL') + return db + } catch { + return null + } +} + +/** + * 打开数据库(只读模式,不使用缓存) + */ +export function openReadonlyDatabase(sessionId: string): Database.Database | null { + const dbPath = getDbPath(sessionId) + try { + const db = new Database(dbPath, { readonly: true }) + db.pragma('journal_mode = WAL') + return db + } catch { + return null + } +} diff --git a/electron/main/worker/query/session/export.ts b/electron/main/worker/query/session/export.ts new file mode 100644 index 0000000..088f2a6 --- /dev/null +++ b/electron/main/worker/query/session/export.ts @@ -0,0 +1,381 @@ +/** + * 导出功能模块 + * 提供将筛选结果导出为 Markdown 文件的功能 + */ + +import * as fs from 'fs' +import * as path from 'path' +import { parentPort } from 'worker_threads' +import { openReadonlyDatabase } from './core' +import type { ExportFilterParams, ExportProgress } from './types' + +/** + * 发送导出进度到主进程 + */ +function sendExportProgress(requestId: string, progress: ExportProgress): void { + parentPort?.postMessage({ + id: requestId, + type: 'progress', + payload: progress, + }) +} + +/** + * 导出筛选结果到 Markdown 文件(后端生成,支持大数据量) + * 使用流式写入,避免内存溢出 + * + * @param params 导出参数 + * @param requestId 请求 ID(用于发送进度) + * @returns 生成的文件路径 + */ +export function exportFilterResultToFile( + params: ExportFilterParams, + requestId?: string +): { success: boolean; filePath?: string; error?: string } { + const db = openReadonlyDatabase(params.sessionId) + if (!db) { + return { success: false, error: '无法打开数据库' } + } + + try { + const timestamp = Date.now() + const fileName = `${params.sessionName}_筛选结果_${timestamp}.md` + const filePath = path.join(params.outputDir, fileName) + + // 创建写入流 + const writeStream = fs.createWriteStream(filePath, { encoding: 'utf8' }) + + // 写入头部 + writeStream.write(`# ${params.sessionName} - 聊天记录筛选结果\n\n`) + writeStream.write(`> 导出时间: ${new Date().toLocaleString()}\n\n`) + + // 写入筛选条件摘要 + writeStream.write(`## 筛选条件\n\n`) + if (params.filterMode === 'condition') { + if (params.keywords && params.keywords.length > 0) { + writeStream.write(`- 关键词: ${params.keywords.join(', ')}\n`) + } + if (params.timeFilter) { + const start = new Date(params.timeFilter.startTs * 1000).toLocaleString() + const end = new Date(params.timeFilter.endTs * 1000).toLocaleString() + writeStream.write(`- 时间范围: ${start} ~ ${end}\n`) + } + writeStream.write(`- 上下文扩展: ±${params.contextSize || 10} 条消息\n`) + } else { + writeStream.write(`- 模式: 会话筛选\n`) + writeStream.write(`- 选中会话数: ${params.chatSessionIds?.length || 0}\n`) + } + writeStream.write('\n') + + let totalMessages = 0 + let totalHits = 0 + let totalChars = 0 + let blockIndex = 0 + + if (params.filterMode === 'condition') { + // 条件筛选模式:流式处理 + const contextSize = params.contextSize || 10 + + // 第一阶段:获取命中消息的索引 + const lightweightSql = ` + SELECT + id, + ts, + sender_id as senderId, + content + FROM message + ${params.timeFilter ? 'WHERE ts >= ? AND ts <= ?' : ''} + ORDER BY ts ASC, id ASC + ` + const sqlParams: unknown[] = [] + if (params.timeFilter) { + sqlParams.push(params.timeFilter.startTs, params.timeFilter.endTs) + } + + const hitIndexes: number[] = [] + let msgIndex = 0 + const stmt = db.prepare(lightweightSql) + + for (const row of stmt.iterate(...sqlParams) as Iterable<{ + id: number + ts: number + senderId: number + content: string | null + }>) { + let isHit = true + + if (params.keywords && params.keywords.length > 0) { + const content = (row.content || '').toLowerCase() + isHit = params.keywords.some((kw) => content.includes(kw.toLowerCase())) + } + + if (isHit && params.senderIds && params.senderIds.length > 0) { + isHit = params.senderIds.includes(row.senderId) + } + + if (isHit) { + hitIndexes.push(msgIndex) + } + msgIndex++ + } + + totalHits = hitIndexes.length + + // 发送准备阶段进度 + if (requestId) { + sendExportProgress(requestId, { + stage: 'preparing', + currentBlock: 0, + totalBlocks: 0, + percentage: 10, + message: `正在分析数据,找到 ${totalHits} 条匹配消息...`, + }) + } + + if (hitIndexes.length === 0) { + writeStream.write(`## 统计信息\n\n`) + writeStream.write(`- 无匹配结果\n`) + writeStream.end() + if (requestId) { + sendExportProgress(requestId, { + stage: 'done', + currentBlock: 0, + totalBlocks: 0, + percentage: 100, + message: '导出完成(无匹配结果)', + }) + } + return { success: true, filePath } + } + + // 计算上下文范围并合并 + const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = [] + const totalMsgCount = msgIndex + + for (const hitIdx of hitIndexes) { + const start = Math.max(0, hitIdx - contextSize) + const end = Math.min(totalMsgCount - 1, hitIdx + contextSize) + + if (ranges.length > 0) { + const lastRange = ranges[ranges.length - 1] + if (start <= lastRange.end + 1) { + lastRange.end = Math.max(lastRange.end, end) + lastRange.hitIndexes.push(hitIdx) + continue + } + } + ranges.push({ start, end, hitIndexes: [hitIdx] }) + } + + const totalBlocks = ranges.length + + // 发送开始导出进度 + if (requestId) { + sendExportProgress(requestId, { + stage: 'exporting', + currentBlock: 0, + totalBlocks, + percentage: 15, + message: `开始导出 ${totalBlocks} 个对话块...`, + }) + } + + // 写入统计信息 + writeStream.write(`## 统计信息\n\n`) + writeStream.write(`- 对话块数: ${totalBlocks}\n`) + writeStream.write(`- 命中消息: ${totalHits}\n\n`) + + // 第二阶段:流式写入每个块的内容 + writeStream.write(`## 对话内容\n\n`) + + for (const range of ranges) { + blockIndex++ + + // 发送导出进度(每个块) + if (requestId) { + const percentage = Math.round(15 + ((blockIndex - 1) / totalBlocks) * 80) + sendExportProgress(requestId, { + stage: 'exporting', + currentBlock: blockIndex, + totalBlocks, + percentage, + message: `正在导出对话块 ${blockIndex}/${totalBlocks}...`, + }) + } + + const blockSql = ` + SELECT + msg.id, + msg.ts, + COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, + msg.content + FROM message msg + JOIN member m ON msg.sender_id = m.id + ${params.timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''} + ORDER BY msg.ts ASC, msg.id ASC + LIMIT ? OFFSET ? + ` + const blockParams: unknown[] = [] + if (params.timeFilter) { + blockParams.push(params.timeFilter.startTs, params.timeFilter.endTs) + } + blockParams.push(range.end - range.start + 1, range.start) + + const messages = db.prepare(blockSql).all(...blockParams) as Array<{ + id: number + ts: number + senderName: string + content: string | null + }> + + if (messages.length === 0) continue + + const hitIndexSet = new Set(range.hitIndexes.map((idx) => idx - range.start)) + + const startTime = new Date(messages[0].ts * 1000).toLocaleString() + const endTime = new Date(messages[messages.length - 1].ts * 1000).toLocaleString() + writeStream.write(`### 对话块 ${blockIndex} (${startTime} ~ ${endTime})\n\n`) + + for (let i = 0; i < messages.length; i++) { + const msg = messages[i] + const time = new Date(msg.ts * 1000).toLocaleTimeString() + const hitMark = hitIndexSet.has(i) ? ' ⭐' : '' + const content = msg.content || '[非文本消息]' + writeStream.write(`${time} ${msg.senderName}${hitMark}: ${content}\n`) + totalMessages++ + totalChars += (msg.content || '').length + } + writeStream.write('\n') + } + } else { + // 会话筛选模式 + if (!params.chatSessionIds || params.chatSessionIds.length === 0) { + writeStream.write(`## 统计信息\n\n`) + writeStream.write(`- 未选择会话\n`) + writeStream.end() + if (requestId) { + sendExportProgress(requestId, { + stage: 'done', + currentBlock: 0, + totalBlocks: 0, + percentage: 100, + message: '导出完成(未选择会话)', + }) + } + return { success: true, filePath } + } + + // 发送准备阶段进度 + if (requestId) { + sendExportProgress(requestId, { + stage: 'preparing', + currentBlock: 0, + totalBlocks: params.chatSessionIds.length, + percentage: 10, + message: `正在准备导出 ${params.chatSessionIds.length} 个会话...`, + }) + } + + // 获取会话信息 + const sessionsSql = ` + SELECT id, start_ts as startTs, end_ts as endTs + FROM chat_session + WHERE id IN (${params.chatSessionIds.map(() => '?').join(',')}) + ORDER BY start_ts ASC + ` + const sessions = db.prepare(sessionsSql).all(...params.chatSessionIds) as Array<{ + id: number + startTs: number + endTs: number + }> + + const totalBlocks = sessions.length + + writeStream.write(`## 统计信息\n\n`) + writeStream.write(`- 对话块数: ${totalBlocks}\n\n`) + + writeStream.write(`## 对话内容\n\n`) + + const messagesSql = ` + SELECT + msg.id, + COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, + msg.content, + msg.ts as timestamp + FROM message_context mc + JOIN message msg ON msg.id = mc.message_id + JOIN member m ON msg.sender_id = m.id + WHERE mc.session_id = ? + ORDER BY msg.ts ASC + ` + + for (const session of sessions) { + blockIndex++ + + // 发送导出进度(每个会话) + if (requestId) { + const percentage = Math.round(15 + ((blockIndex - 1) / totalBlocks) * 80) + sendExportProgress(requestId, { + stage: 'exporting', + currentBlock: blockIndex, + totalBlocks, + percentage, + message: `正在导出会话 ${blockIndex}/${totalBlocks}...`, + }) + } + + const messages = db.prepare(messagesSql).all(session.id) as Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> + + if (messages.length === 0) continue + + const startTime = new Date(session.startTs * 1000).toLocaleString() + const endTime = new Date(session.endTs * 1000).toLocaleString() + writeStream.write(`### 对话块 ${blockIndex} (${startTime} ~ ${endTime})\n\n`) + + for (const msg of messages) { + const time = new Date(msg.timestamp * 1000).toLocaleTimeString() + const content = msg.content || '[非文本消息]' + writeStream.write(`${time} ${msg.senderName}: ${content}\n`) + totalMessages++ + totalChars += (msg.content || '').length + } + writeStream.write('\n') + } + } + + writeStream.end() + + // 发送完成进度 + if (requestId) { + sendExportProgress(requestId, { + stage: 'done', + currentBlock: blockIndex, + totalBlocks: blockIndex, + percentage: 100, + message: `导出完成,共 ${blockIndex} 个对话块`, + }) + } + + return { success: true, filePath } + } catch (error) { + console.error('exportFilterResultToFile error:', error) + // 发送错误进度 + if (requestId) { + sendExportProgress(requestId, { + stage: 'error', + currentBlock: 0, + totalBlocks: 0, + percentage: 0, + message: `导出失败: ${String(error)}`, + }) + } + return { success: false, error: String(error) } + } finally { + db.close() + } +} diff --git a/electron/main/worker/query/session/filter.ts b/electron/main/worker/query/session/filter.ts new file mode 100644 index 0000000..a4fc466 --- /dev/null +++ b/electron/main/worker/query/session/filter.ts @@ -0,0 +1,447 @@ +/** + * 自定义筛选模块 + * 提供按条件筛选消息和获取多会话消息等功能 + */ + +import { openReadonlyDatabase } from './core' +import type { + FilterMessage, + ContextBlock, + FilterResultWithPagination, +} from './types' + +/** + * 按条件筛选消息并扩充上下文(支持分页) + * + * 两阶段查询架构: + * 1. 第一阶段:轻量级查询获取消息 ID、序号和匹配信息(不加载完整内容) + * 2. 第二阶段:计算上下文范围、合并、分页后只获取当前页的完整消息 + * + * @param sessionId 数据库会话ID + * @param keywords 关键词列表(可选,OR 逻辑) + * @param timeFilter 时间过滤器(可选) + * @param senderIds 发送者ID列表(可选) + * @param contextSize 上下文扩展数量(前后各多少条) + * @param page 页码(从 1 开始,默认 1) + * @param pageSize 每页块数(默认 50) + * @returns 筛选结果(带分页信息) + */ +export function filterMessagesWithContext( + sessionId: string, + keywords?: string[], + timeFilter?: { startTs: number; endTs: number }, + senderIds?: number[], + contextSize: number = 10, + page: number = 1, + pageSize: number = 50 +): FilterResultWithPagination { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } + + try { + // ==================== 第一阶段:轻量级查询 ==================== + // 只获取消息的 ID、时间戳、发送者ID、内容(用于匹配) + // 使用 ROW_NUMBER() 计算全局序号,避免一次性加载所有完整数据 + + const lightweightSql = ` + SELECT + id, + ts, + sender_id as senderId, + content + FROM message + ${timeFilter ? 'WHERE ts >= ? AND ts <= ?' : ''} + ORDER BY ts ASC, id ASC + ` + + const params: unknown[] = [] + if (timeFilter) { + params.push(timeFilter.startTs, timeFilter.endTs) + } + + // 使用 iterate() 流式处理,避免一次性加载所有数据到内存 + const stmt = db.prepare(lightweightSql) + const hitIndexes: number[] = [] + let totalMessageCount = 0 + let estimatedTotalChars = 0 + + // 流式遍历消息,标记命中的索引 + for (const row of stmt.iterate(...params) as Iterable<{ + id: number + ts: number + senderId: number + content: string | null + }>) { + let isHit = true + + // 关键词匹配(OR 逻辑) + if (keywords && keywords.length > 0) { + const content = (row.content || '').toLowerCase() + isHit = keywords.some((kw) => content.includes(kw.toLowerCase())) + } + + // 发送者匹配 + if (isHit && senderIds && senderIds.length > 0) { + isHit = senderIds.includes(row.senderId) + } + + if (isHit) { + hitIndexes.push(totalMessageCount) + } + + totalMessageCount++ + } + + if (hitIndexes.length === 0) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } + + // ==================== 计算上下文范围并合并 ==================== + const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = [] + + for (const hitIndex of hitIndexes) { + const start = Math.max(0, hitIndex - contextSize) + const end = Math.min(totalMessageCount - 1, hitIndex + contextSize) + + // 检查是否能与前一个范围合并 + if (ranges.length > 0) { + const lastRange = ranges[ranges.length - 1] + if (start <= lastRange.end + 1) { + lastRange.end = Math.max(lastRange.end, end) + lastRange.hitIndexes.push(hitIndex) + continue + } + } + + ranges.push({ start, end, hitIndexes: [hitIndex] }) + } + + const totalBlocks = ranges.length + const totalHits = hitIndexes.length + + // ==================== 分页处理 ==================== + const startIndex = (page - 1) * pageSize + const endIndex = Math.min(startIndex + pageSize, totalBlocks) + const pageRanges = ranges.slice(startIndex, endIndex) + const hasMore = endIndex < totalBlocks + + if (pageRanges.length === 0) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: totalHits, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks, totalHits, hasMore: false }, + } + } + + // ==================== 第二阶段:获取当前页的完整消息 ==================== + // 只为当前页的范围获取完整消息数据 + + const blocks: ContextBlock[] = [] + let totalMessages = 0 + let totalChars = 0 + + for (const range of pageRanges) { + // 使用 LIMIT OFFSET 获取指定范围的消息 + const blockSql = ` + SELECT + msg.id, + msg.ts, + COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, + m.platform_id as senderPlatformId, + COALESCE(m.aliases, '[]') as senderAliasesJson, + m.avatar as senderAvatar, + msg.content, + msg.type, + msg.reply_to_message_id as replyToMessageId, + reply_msg.content as replyToContent, + COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName + FROM message msg + JOIN member m ON msg.sender_id = m.id + LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id + LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id + ${timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''} + ORDER BY msg.ts ASC, msg.id ASC + LIMIT ? OFFSET ? + ` + + const blockParams: unknown[] = [] + if (timeFilter) { + blockParams.push(timeFilter.startTs, timeFilter.endTs) + } + blockParams.push(range.end - range.start + 1, range.start) + + const messages = db.prepare(blockSql).all(...blockParams) as Array<{ + id: number + ts: number + senderName: string + senderPlatformId: string + senderAliasesJson: string + senderAvatar: string | null + content: string | null + type: number + replyToMessageId: string | null + replyToContent: string | null + replyToSenderName: string | null + }> + + // 构建 hitIndexSet(相对于 range.start 的偏移) + const hitIndexSet = new Set(range.hitIndexes.map((idx) => idx - range.start)) + + const blockMessages: FilterMessage[] = [] + for (let i = 0; i < messages.length; i++) { + const msg = messages[i] + const isHit = hitIndexSet.has(i) + + // 解析别名 JSON + let senderAliases: string[] = [] + try { + senderAliases = JSON.parse(msg.senderAliasesJson || '[]') + } catch { + senderAliases = [] + } + + blockMessages.push({ + id: msg.id, + senderName: msg.senderName, + senderPlatformId: msg.senderPlatformId, + senderAliases, + senderAvatar: msg.senderAvatar, + content: msg.content || '', + timestamp: msg.ts, + type: msg.type, + replyToMessageId: msg.replyToMessageId, + replyToContent: msg.replyToContent, + replyToSenderName: msg.replyToSenderName, + isHit, + }) + totalChars += (msg.content || '').length + } + + if (blockMessages.length > 0) { + blocks.push({ + startTs: blockMessages[0].timestamp, + endTs: blockMessages[blockMessages.length - 1].timestamp, + messages: blockMessages, + hitCount: range.hitIndexes.length, + }) + totalMessages += blockMessages.length + } + } + + // 如果是第一页,需要估算总字符数(用于统计显示) + // 由于我们不再一次性加载所有数据,这里使用采样估算 + if (page === 1 && totalBlocks > pageSize) { + // 估算:当前页的平均字符数 × 总块数 + const avgCharsPerBlock = totalChars / blocks.length + estimatedTotalChars = Math.round(avgCharsPerBlock * totalBlocks) + } else if (page === 1) { + estimatedTotalChars = totalChars + } + + return { + blocks, + stats: { + totalMessages: page === 1 ? totalMessages : 0, // 只有第一页返回准确的消息数 + hitMessages: totalHits, + totalChars: page === 1 ? (totalBlocks > pageSize ? estimatedTotalChars : totalChars) : 0, + }, + pagination: { + page, + pageSize, + totalBlocks, + totalHits, + hasMore, + }, + } + } catch (error) { + console.error('filterMessagesWithContext error:', error) + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } finally { + db.close() + } +} + +/** + * 获取多个会话的完整消息(用于会话筛选模式,支持分页) + * + * @param sessionId 数据库会话ID + * @param chatSessionIds 要获取的会话ID列表 + * @param page 页码(从 1 开始,默认 1) + * @param pageSize 每页块数(默认 50) + * @returns 合并后的上下文块和统计(带分页信息) + */ +export function getMultipleSessionsMessages( + sessionId: string, + chatSessionIds: number[], + page: number = 1, + pageSize: number = 50 +): FilterResultWithPagination { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } + + try { + if (chatSessionIds.length === 0) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } + + // 先获取会话信息,按时间排序 + const sessionsSql = ` + SELECT id, start_ts as startTs, end_ts as endTs, message_count as messageCount + FROM chat_session + WHERE id IN (${chatSessionIds.map(() => '?').join(',')}) + ORDER BY start_ts ASC + ` + const allSessions = db.prepare(sessionsSql).all(...chatSessionIds) as Array<{ + id: number + startTs: number + endTs: number + messageCount: number + }> + + const totalBlocks = allSessions.length + + // 分页处理 + const startIndex = (page - 1) * pageSize + const endIndex = Math.min(startIndex + pageSize, totalBlocks) + const pageSessions = allSessions.slice(startIndex, endIndex) + const hasMore = endIndex < totalBlocks + + if (pageSessions.length === 0) { + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks, totalHits: 0, hasMore: false }, + } + } + + const blocks: ContextBlock[] = [] + let totalMessages = 0 + let totalChars = 0 + + // 为当前页的会话获取消息(完整信息) + const messagesSql = ` + SELECT + msg.id, + COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName, + m.platform_id as senderPlatformId, + COALESCE(m.aliases, '[]') as senderAliasesJson, + m.avatar as senderAvatar, + msg.content, + msg.type, + msg.reply_to_message_id as replyToMessageId, + reply_msg.content as replyToContent, + COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName, + msg.ts as timestamp + FROM message_context mc + JOIN message msg ON msg.id = mc.message_id + JOIN member m ON msg.sender_id = m.id + LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id + LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id + WHERE mc.session_id = ? + ORDER BY msg.ts ASC + ` + + for (const session of pageSessions) { + const messages = db.prepare(messagesSql).all(session.id) as Array<{ + id: number + senderName: string + senderPlatformId: string + senderAliasesJson: string + senderAvatar: string | null + content: string | null + type: number + replyToMessageId: string | null + replyToContent: string | null + replyToSenderName: string | null + timestamp: number + }> + + const blockMessages: FilterMessage[] = messages.map((msg) => { + // 解析别名 JSON + let senderAliases: string[] = [] + try { + senderAliases = JSON.parse(msg.senderAliasesJson || '[]') + } catch { + senderAliases = [] + } + + return { + id: msg.id, + senderName: msg.senderName, + senderPlatformId: msg.senderPlatformId, + senderAliases, + senderAvatar: msg.senderAvatar, + content: msg.content || '', + timestamp: msg.timestamp, + type: msg.type, + replyToMessageId: msg.replyToMessageId, + replyToContent: msg.replyToContent, + replyToSenderName: msg.replyToSenderName, + isHit: false, // 会话模式下没有命中高亮 + } + }) + + for (const msg of messages) { + totalChars += (msg.content || '').length + } + + blocks.push({ + startTs: session.startTs, + endTs: session.endTs, + messages: blockMessages, + hitCount: 0, + }) + + totalMessages += messages.length + } + + return { + blocks, + stats: { + totalMessages: page === 1 ? totalMessages : 0, // 只有第一页返回准确的消息数 + hitMessages: 0, // 会话模式没有命中概念 + totalChars: page === 1 ? totalChars : 0, + }, + pagination: { + page, + pageSize, + totalBlocks, + totalHits: 0, + hasMore, + }, + } + } catch (error) { + console.error('getMultipleSessionsMessages error:', error) + return { + blocks: [], + stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 }, + pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false }, + } + } finally { + db.close() + } +} diff --git a/electron/main/worker/query/session/index.ts b/electron/main/worker/query/session/index.ts new file mode 100644 index 0000000..e0c6225 --- /dev/null +++ b/electron/main/worker/query/session/index.ts @@ -0,0 +1,41 @@ +/** + * 会话模块统一导出 + * 提供会话索引管理、AI 工具查询、自定义筛选和导出等功能 + */ + +// 类型定义 +export type { + ChatSessionItem, + SessionSearchResultItem, + SessionMessagesResult, + FilterMessage, + ContextBlock, + FilterResult, + PaginationInfo, + FilterResultWithPagination, + ExportFilterParams, + ExportProgress, +} from './types' + +export { DEFAULT_SESSION_GAP_THRESHOLD } from './types' + +// 会话索引管理 +export { + generateSessions, + clearSessions, + hasSessionIndex, + getSessionStats, + updateSessionGapThreshold, + getSessions, + saveSessionSummary, + getSessionSummary, +} from './sessionIndex' + +// AI 工具专用查询 +export { searchSessions, getSessionMessages } from './aiTools' + +// 自定义筛选 +export { filterMessagesWithContext, getMultipleSessionsMessages } from './filter' + +// 导出功能 +export { exportFilterResultToFile } from './export' diff --git a/electron/main/worker/query/session/sessionIndex.ts b/electron/main/worker/query/session/sessionIndex.ts new file mode 100644 index 0000000..4b1c9d0 --- /dev/null +++ b/electron/main/worker/query/session/sessionIndex.ts @@ -0,0 +1,349 @@ +/** + * 会话索引管理模块 + * 提供会话生成、查询、管理等功能 + */ + +import type Database from 'better-sqlite3' +import { openWritableDatabase, openReadonlyDatabase, closeDatabase } from './core' +import { DEFAULT_SESSION_GAP_THRESHOLD, type ChatSessionItem } from './types' + +/** + * 内部清空会话数据函数 + */ +function clearSessionsInternal(db: Database.Database): void { + db.exec('DELETE FROM message_context') + db.exec('DELETE FROM chat_session') +} + +/** + * 生成会话索引 + * 使用 Gap-based 算法,根据消息时间间隔自动切分会话 + * + * @param sessionId 数据库会话ID + * @param gapThreshold 时间间隔阈值(秒),默认 1800(30分钟) + * @param onProgress 进度回调 + * @returns 生成的会话数量 + */ +export function generateSessions( + sessionId: string, + gapThreshold: number = DEFAULT_SESSION_GAP_THRESHOLD, + onProgress?: (current: number, total: number) => void +): number { + // 先关闭缓存的只读连接 + closeDatabase(sessionId) + + const db = openWritableDatabase(sessionId) + if (!db) { + throw new Error(`无法打开数据库: ${sessionId}`) + } + + try { + // 获取消息总数 + const countResult = db.prepare('SELECT COUNT(*) as count FROM message').get() as { count: number } + const totalMessages = countResult.count + + if (totalMessages === 0) { + return 0 + } + + // 清空已有的会话数据 + clearSessionsInternal(db) + + // 使用窗口函数计算会话边界 + // 步骤1:为每条消息计算与前一条的时间差,标记新会话起点 + const sessionMarkSQL = ` + WITH message_ordered AS ( + SELECT + id, + ts, + LAG(ts) OVER (ORDER BY ts, id) AS prev_ts + FROM message + ), + session_marks AS ( + SELECT + id, + ts, + CASE + WHEN prev_ts IS NULL OR (ts - prev_ts) > ? THEN 1 + ELSE 0 + END AS is_new_session + FROM message_ordered + ), + session_ids AS ( + SELECT + id, + ts, + SUM(is_new_session) OVER (ORDER BY ts, id) AS session_num + FROM session_marks + ) + SELECT id, ts, session_num FROM session_ids + ` + + const messages = db.prepare(sessionMarkSQL).all(gapThreshold) as Array<{ + id: number + ts: number + session_num: number + }> + + if (messages.length === 0) { + return 0 + } + + // 步骤2:计算每个会话的统计信息 + const sessionMap = new Map() + + for (const msg of messages) { + const session = sessionMap.get(msg.session_num) + if (!session) { + sessionMap.set(msg.session_num, { + startTs: msg.ts, + endTs: msg.ts, + messageIds: [msg.id], + }) + } else { + session.endTs = msg.ts + session.messageIds.push(msg.id) + } + } + + // 步骤3:批量写入 chat_session 和 message_context 表 + const insertSession = db.prepare(` + INSERT INTO chat_session (start_ts, end_ts, message_count, is_manual, summary) + VALUES (?, ?, ?, 0, NULL) + `) + + const insertContext = db.prepare(` + INSERT INTO message_context (message_id, session_id, topic_id) + VALUES (?, ?, NULL) + `) + + // 开始事务 + const transaction = db.transaction(() => { + let processedCount = 0 + const totalSessions = sessionMap.size + + for (const [, sessionData] of sessionMap) { + // 插入会话记录 + const result = insertSession.run(sessionData.startTs, sessionData.endTs, sessionData.messageIds.length) + const newSessionId = result.lastInsertRowid as number + + // 批量插入消息上下文 + for (const messageId of sessionData.messageIds) { + insertContext.run(messageId, newSessionId) + } + + processedCount++ + if (onProgress && processedCount % 100 === 0) { + onProgress(processedCount, totalSessions) + } + } + + return totalSessions + }) + + const sessionCount = transaction() + + // 最终进度回调 + if (onProgress) { + onProgress(sessionCount, sessionCount) + } + + return sessionCount + } finally { + db.close() + } +} + +/** + * 清空会话索引数据 + * @param sessionId 数据库会话ID + */ +export function clearSessions(sessionId: string): void { + // 先关闭缓存的只读连接 + closeDatabase(sessionId) + + const db = openWritableDatabase(sessionId) + if (!db) { + throw new Error(`无法打开数据库: ${sessionId}`) + } + + try { + clearSessionsInternal(db) + } finally { + db.close() + } +} + +/** + * 检查是否已生成会话索引 + * @param sessionId 数据库会话ID + * @returns 是否有会话索引 + */ +export function hasSessionIndex(sessionId: string): boolean { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return false + } + + try { + // 检查 chat_session 表是否存在且有数据 + const result = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number } + return result.count > 0 + } catch { + // 表可能不存在 + return false + } finally { + db.close() + } +} + +/** + * 获取会话索引统计信息 + * @param sessionId 数据库会话ID + */ +export function getSessionStats(sessionId: string): { + sessionCount: number + hasIndex: boolean + gapThreshold: number +} { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return { sessionCount: 0, hasIndex: false, gapThreshold: DEFAULT_SESSION_GAP_THRESHOLD } + } + + try { + // 获取会话数量 + let sessionCount = 0 + try { + const countResult = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number } + sessionCount = countResult.count + } catch { + // 表可能不存在 + } + + // 获取配置的阈值 + let gapThreshold = DEFAULT_SESSION_GAP_THRESHOLD + try { + const metaResult = db.prepare('SELECT session_gap_threshold FROM meta LIMIT 1').get() as + | { + session_gap_threshold: number | null + } + | undefined + if (metaResult?.session_gap_threshold) { + gapThreshold = metaResult.session_gap_threshold + } + } catch { + // 字段可能不存在 + } + + return { + sessionCount, + hasIndex: sessionCount > 0, + gapThreshold, + } + } finally { + db.close() + } +} + +/** + * 更新单个聊天的会话切分阈值 + * @param sessionId 数据库会话ID + * @param gapThreshold 阈值(秒),null 表示使用全局配置 + */ +export function updateSessionGapThreshold(sessionId: string, gapThreshold: number | null): void { + // 先关闭缓存的只读连接 + closeDatabase(sessionId) + + const db = openWritableDatabase(sessionId) + if (!db) { + throw new Error(`无法打开数据库: ${sessionId}`) + } + + try { + db.prepare('UPDATE meta SET session_gap_threshold = ?').run(gapThreshold) + } finally { + db.close() + } +} + +/** + * 获取会话列表(用于时间线导航) + * @param sessionId 数据库会话ID + * @returns 会话列表,按时间排序 + */ +export function getSessions(sessionId: string): ChatSessionItem[] { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return [] + } + + try { + // 查询会话列表,同时获取每个会话的首条消息 ID 和摘要 + const sql = ` + SELECT + cs.id, + cs.start_ts as startTs, + cs.end_ts as endTs, + cs.message_count as messageCount, + cs.summary, + (SELECT mc.message_id FROM message_context mc WHERE mc.session_id = cs.id ORDER BY mc.message_id LIMIT 1) as firstMessageId + FROM chat_session cs + ORDER BY cs.start_ts ASC + ` + const sessions = db.prepare(sql).all() as ChatSessionItem[] + return sessions + } catch { + return [] + } finally { + db.close() + } +} + +// ==================== 会话摘要相关函数 ==================== + +/** + * 保存会话摘要 + * @param sessionId 数据库会话ID + * @param chatSessionId 会话索引中的会话ID + * @param summary 摘要内容 + */ +export function saveSessionSummary(sessionId: string, chatSessionId: number, summary: string): void { + // 先关闭缓存的只读连接 + closeDatabase(sessionId) + + const db = openWritableDatabase(sessionId) + if (!db) { + throw new Error(`无法打开数据库: ${sessionId}`) + } + + try { + db.prepare('UPDATE chat_session SET summary = ? WHERE id = ?').run(summary, chatSessionId) + } finally { + db.close() + } +} + +/** + * 获取会话摘要 + * @param sessionId 数据库会话ID + * @param chatSessionId 会话索引中的会话ID + * @returns 摘要内容 + */ +export function getSessionSummary(sessionId: string, chatSessionId: number): string | null { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return null + } + + try { + const result = db.prepare('SELECT summary FROM chat_session WHERE id = ?').get(chatSessionId) as + | { summary: string | null } + | undefined + return result?.summary || null + } catch { + return null + } finally { + db.close() + } +} diff --git a/electron/main/worker/query/session/types.ts b/electron/main/worker/query/session/types.ts new file mode 100644 index 0000000..6a905d2 --- /dev/null +++ b/electron/main/worker/query/session/types.ts @@ -0,0 +1,163 @@ +/** + * 会话模块类型定义 + */ + +/** 默认会话切分阈值:30分钟(秒) */ +export const DEFAULT_SESSION_GAP_THRESHOLD = 1800 + +/** + * 会话列表项类型 + */ +export interface ChatSessionItem { + id: number + startTs: number + endTs: number + messageCount: number + firstMessageId: number + /** 会话摘要(如果有) */ + summary?: string | null +} + +/** + * 会话搜索结果项类型(用于 AI 工具) + */ +export interface SessionSearchResultItem { + id: number + startTs: number + endTs: number + messageCount: number + /** 是否为完整会话(消息数 <= 预览条数) */ + isComplete: boolean + /** 预览消息列表 */ + previewMessages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 会话消息结果类型(用于 AI 工具) + */ +export interface SessionMessagesResult { + sessionId: number + startTs: number + endTs: number + messageCount: number + returnedCount: number + /** 参与者列表 */ + participants: string[] + /** 消息列表 */ + messages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 自定义筛选消息类型(完整信息,兼容 MessageList 组件) + */ +export interface FilterMessage { + id: number + senderName: string + senderPlatformId: string + senderAliases: string[] + senderAvatar: string | null + content: string + timestamp: number + type: number + replyToMessageId: string | null + replyToContent: string | null + replyToSenderName: string | null + /** 是否为命中的消息(关键词匹配) */ + isHit: boolean +} + +/** + * 上下文块类型(用于自定义筛选) + */ +export interface ContextBlock { + /** 块的时间范围 */ + startTs: number + endTs: number + /** 消息列表 */ + messages: FilterMessage[] + /** 命中的消息数量 */ + hitCount: number +} + +/** + * 筛选结果类型 + */ +export interface FilterResult { + /** 上下文块列表 */ + blocks: ContextBlock[] + /** 统计信息 */ + stats: { + /** 总消息数 */ + totalMessages: number + /** 命中的消息数 */ + hitMessages: number + /** 总字符数 */ + totalChars: number + } +} + +/** + * 分页信息类型 + */ +export interface PaginationInfo { + /** 当前页码(从 1 开始) */ + page: number + /** 每页块数 */ + pageSize: number + /** 总块数 */ + totalBlocks: number + /** 总命中数 */ + totalHits: number + /** 是否还有更多 */ + hasMore: boolean +} + +/** + * 带分页的筛选结果类型 + */ +export interface FilterResultWithPagination extends FilterResult { + pagination: PaginationInfo +} + +/** + * 导出筛选结果参数 + */ +export interface ExportFilterParams { + sessionId: string + sessionName: string + outputDir: string + filterMode: 'condition' | 'session' + // 条件筛选参数 + keywords?: string[] + timeFilter?: { startTs: number; endTs: number } + senderIds?: number[] + contextSize?: number + // 会话筛选参数 + chatSessionIds?: number[] +} + +/** + * 导出进度类型 + */ +export interface ExportProgress { + /** 阶段 */ + stage: 'preparing' | 'exporting' | 'done' | 'error' + /** 当前处理的块索引(从 1 开始) */ + currentBlock: number + /** 总块数 */ + totalBlocks: number + /** 百分比(0-100) */ + percentage: number + /** 状态消息 */ + message: string +}