mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-04-23 09:42:59 +08:00
refactor: 重构session查询文件
This commit is contained in:
@@ -68,12 +68,16 @@ export {
|
||||
getSessionStats,
|
||||
updateSessionGapThreshold,
|
||||
getSessions,
|
||||
saveSessionSummary,
|
||||
getSessionSummary,
|
||||
searchSessions,
|
||||
getSessionMessages,
|
||||
DEFAULT_SESSION_GAP_THRESHOLD,
|
||||
// 自定义筛选
|
||||
filterMessagesWithContext,
|
||||
getMultipleSessionsMessages,
|
||||
// 导出功能
|
||||
exportFilterResultToFile,
|
||||
} from './session'
|
||||
export type {
|
||||
ChatSessionItem,
|
||||
@@ -82,6 +86,10 @@ export type {
|
||||
ContextBlock,
|
||||
FilterResult,
|
||||
FilterMessage,
|
||||
PaginationInfo,
|
||||
FilterResultWithPagination,
|
||||
ExportFilterParams,
|
||||
ExportProgress,
|
||||
} from './session'
|
||||
|
||||
// NLP 查询
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
205
electron/main/worker/query/session/aiTools.ts
Normal file
205
electron/main/worker/query/session/aiTools.ts
Normal file
@@ -0,0 +1,205 @@
|
||||
/**
|
||||
* AI 工具专用查询模块
|
||||
* 提供搜索会话和获取会话消息等功能,供 AI 工具使用
|
||||
*/
|
||||
|
||||
import { openReadonlyDatabase } from './core'
|
||||
import type { SessionSearchResultItem, SessionMessagesResult } from './types'
|
||||
|
||||
/**
|
||||
* 搜索会话(用于 AI 工具)
|
||||
* 支持按关键词和时间范围筛选会话
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param keywords 关键词列表(可选,OR 逻辑匹配)
|
||||
* @param timeFilter 时间过滤器(可选)
|
||||
* @param limit 返回数量限制,默认 20
|
||||
* @param previewCount 预览消息数量,默认 5
|
||||
* @returns 匹配的会话列表
|
||||
*/
|
||||
export function searchSessions(
|
||||
sessionId: string,
|
||||
keywords?: string[],
|
||||
timeFilter?: { startTs: number; endTs: number },
|
||||
limit: number = 20,
|
||||
previewCount: number = 5
|
||||
): SessionSearchResultItem[] {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return []
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. 构建会话查询 SQL
|
||||
let sessionSql = `
|
||||
SELECT
|
||||
cs.id,
|
||||
cs.start_ts as startTs,
|
||||
cs.end_ts as endTs,
|
||||
cs.message_count as messageCount
|
||||
FROM chat_session cs
|
||||
WHERE 1=1
|
||||
`
|
||||
const params: unknown[] = []
|
||||
|
||||
// 时间范围过滤
|
||||
if (timeFilter) {
|
||||
sessionSql += ` AND cs.start_ts >= ? AND cs.end_ts <= ?`
|
||||
params.push(timeFilter.startTs, timeFilter.endTs)
|
||||
}
|
||||
|
||||
// 关键词过滤:只返回包含关键词的会话
|
||||
if (keywords && keywords.length > 0) {
|
||||
const keywordConditions = keywords.map(() => `m.content LIKE ?`).join(' OR ')
|
||||
sessionSql += `
|
||||
AND cs.id IN (
|
||||
SELECT DISTINCT mc.session_id
|
||||
FROM message_context mc
|
||||
JOIN message m ON m.id = mc.message_id
|
||||
WHERE (${keywordConditions})
|
||||
)
|
||||
`
|
||||
for (const kw of keywords) {
|
||||
params.push(`%${kw}%`)
|
||||
}
|
||||
}
|
||||
|
||||
sessionSql += ` ORDER BY cs.start_ts DESC LIMIT ?`
|
||||
params.push(limit)
|
||||
|
||||
const sessions = db.prepare(sessionSql).all(...params) as Array<{
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
}>
|
||||
|
||||
// 2. 为每个会话获取预览消息
|
||||
const previewSql = `
|
||||
SELECT
|
||||
m.id,
|
||||
COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName,
|
||||
m.content,
|
||||
m.ts as timestamp
|
||||
FROM message_context mc
|
||||
JOIN message m ON m.id = mc.message_id
|
||||
JOIN member mb ON mb.id = m.sender_id
|
||||
WHERE mc.session_id = ?
|
||||
ORDER BY m.ts ASC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
const results: SessionSearchResultItem[] = []
|
||||
for (const session of sessions) {
|
||||
const previewMessages = db.prepare(previewSql).all(session.id, previewCount) as Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
|
||||
results.push({
|
||||
id: session.id,
|
||||
startTs: session.startTs,
|
||||
endTs: session.endTs,
|
||||
messageCount: session.messageCount,
|
||||
isComplete: session.messageCount <= previewCount,
|
||||
previewMessages,
|
||||
})
|
||||
}
|
||||
|
||||
return results
|
||||
} catch (error) {
|
||||
console.error('searchSessions error:', error)
|
||||
return []
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话的完整消息(用于 AI 工具)
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param chatSessionId 会话索引中的会话ID
|
||||
* @param limit 返回数量限制,默认 500
|
||||
* @returns 会话的完整消息
|
||||
*/
|
||||
export function getSessionMessages(
|
||||
sessionId: string,
|
||||
chatSessionId: number,
|
||||
limit: number = 500
|
||||
): SessionMessagesResult | null {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. 获取会话基本信息
|
||||
const sessionSql = `
|
||||
SELECT
|
||||
id,
|
||||
start_ts as startTs,
|
||||
end_ts as endTs,
|
||||
message_count as messageCount
|
||||
FROM chat_session
|
||||
WHERE id = ?
|
||||
`
|
||||
const session = db.prepare(sessionSql).get(chatSessionId) as
|
||||
| {
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
}
|
||||
| undefined
|
||||
|
||||
if (!session) {
|
||||
db.close()
|
||||
return null
|
||||
}
|
||||
|
||||
// 2. 获取会话消息
|
||||
const messagesSql = `
|
||||
SELECT
|
||||
m.id,
|
||||
COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName,
|
||||
m.content,
|
||||
m.ts as timestamp
|
||||
FROM message_context mc
|
||||
JOIN message m ON m.id = mc.message_id
|
||||
JOIN member mb ON mb.id = m.sender_id
|
||||
WHERE mc.session_id = ?
|
||||
ORDER BY m.ts ASC
|
||||
LIMIT ?
|
||||
`
|
||||
const messages = db.prepare(messagesSql).all(chatSessionId, limit) as Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
|
||||
// 3. 统计参与者
|
||||
const participantsSet = new Set<string>()
|
||||
for (const msg of messages) {
|
||||
participantsSet.add(msg.senderName)
|
||||
}
|
||||
|
||||
return {
|
||||
sessionId: session.id,
|
||||
startTs: session.startTs,
|
||||
endTs: session.endTs,
|
||||
messageCount: session.messageCount,
|
||||
returnedCount: messages.length,
|
||||
participants: Array.from(participantsSet),
|
||||
messages,
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('getSessionMessages error:', error)
|
||||
return null
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
39
electron/main/worker/query/session/core.ts
Normal file
39
electron/main/worker/query/session/core.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
/**
|
||||
* 会话模块核心工具函数
|
||||
* 提供数据库连接等共享功能
|
||||
*/
|
||||
|
||||
import Database from 'better-sqlite3'
|
||||
import { getDbPath, closeDatabase } from '../../core'
|
||||
|
||||
// 重新导出 closeDatabase 供其他模块使用
|
||||
export { closeDatabase }
|
||||
|
||||
/**
|
||||
* 打开数据库(可写模式,不使用缓存)
|
||||
* 会话索引需要写入数据
|
||||
*/
|
||||
export function openWritableDatabase(sessionId: string): Database.Database | null {
|
||||
const dbPath = getDbPath(sessionId)
|
||||
try {
|
||||
const db = new Database(dbPath)
|
||||
db.pragma('journal_mode = WAL')
|
||||
return db
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 打开数据库(只读模式,不使用缓存)
|
||||
*/
|
||||
export function openReadonlyDatabase(sessionId: string): Database.Database | null {
|
||||
const dbPath = getDbPath(sessionId)
|
||||
try {
|
||||
const db = new Database(dbPath, { readonly: true })
|
||||
db.pragma('journal_mode = WAL')
|
||||
return db
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
381
electron/main/worker/query/session/export.ts
Normal file
381
electron/main/worker/query/session/export.ts
Normal file
@@ -0,0 +1,381 @@
|
||||
/**
|
||||
* 导出功能模块
|
||||
* 提供将筛选结果导出为 Markdown 文件的功能
|
||||
*/
|
||||
|
||||
import * as fs from 'fs'
|
||||
import * as path from 'path'
|
||||
import { parentPort } from 'worker_threads'
|
||||
import { openReadonlyDatabase } from './core'
|
||||
import type { ExportFilterParams, ExportProgress } from './types'
|
||||
|
||||
/**
|
||||
* 发送导出进度到主进程
|
||||
*/
|
||||
function sendExportProgress(requestId: string, progress: ExportProgress): void {
|
||||
parentPort?.postMessage({
|
||||
id: requestId,
|
||||
type: 'progress',
|
||||
payload: progress,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出筛选结果到 Markdown 文件(后端生成,支持大数据量)
|
||||
* 使用流式写入,避免内存溢出
|
||||
*
|
||||
* @param params 导出参数
|
||||
* @param requestId 请求 ID(用于发送进度)
|
||||
* @returns 生成的文件路径
|
||||
*/
|
||||
export function exportFilterResultToFile(
|
||||
params: ExportFilterParams,
|
||||
requestId?: string
|
||||
): { success: boolean; filePath?: string; error?: string } {
|
||||
const db = openReadonlyDatabase(params.sessionId)
|
||||
if (!db) {
|
||||
return { success: false, error: '无法打开数据库' }
|
||||
}
|
||||
|
||||
try {
|
||||
const timestamp = Date.now()
|
||||
const fileName = `${params.sessionName}_筛选结果_${timestamp}.md`
|
||||
const filePath = path.join(params.outputDir, fileName)
|
||||
|
||||
// 创建写入流
|
||||
const writeStream = fs.createWriteStream(filePath, { encoding: 'utf8' })
|
||||
|
||||
// 写入头部
|
||||
writeStream.write(`# ${params.sessionName} - 聊天记录筛选结果\n\n`)
|
||||
writeStream.write(`> 导出时间: ${new Date().toLocaleString()}\n\n`)
|
||||
|
||||
// 写入筛选条件摘要
|
||||
writeStream.write(`## 筛选条件\n\n`)
|
||||
if (params.filterMode === 'condition') {
|
||||
if (params.keywords && params.keywords.length > 0) {
|
||||
writeStream.write(`- 关键词: ${params.keywords.join(', ')}\n`)
|
||||
}
|
||||
if (params.timeFilter) {
|
||||
const start = new Date(params.timeFilter.startTs * 1000).toLocaleString()
|
||||
const end = new Date(params.timeFilter.endTs * 1000).toLocaleString()
|
||||
writeStream.write(`- 时间范围: ${start} ~ ${end}\n`)
|
||||
}
|
||||
writeStream.write(`- 上下文扩展: ±${params.contextSize || 10} 条消息\n`)
|
||||
} else {
|
||||
writeStream.write(`- 模式: 会话筛选\n`)
|
||||
writeStream.write(`- 选中会话数: ${params.chatSessionIds?.length || 0}\n`)
|
||||
}
|
||||
writeStream.write('\n')
|
||||
|
||||
let totalMessages = 0
|
||||
let totalHits = 0
|
||||
let totalChars = 0
|
||||
let blockIndex = 0
|
||||
|
||||
if (params.filterMode === 'condition') {
|
||||
// 条件筛选模式:流式处理
|
||||
const contextSize = params.contextSize || 10
|
||||
|
||||
// 第一阶段:获取命中消息的索引
|
||||
const lightweightSql = `
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
sender_id as senderId,
|
||||
content
|
||||
FROM message
|
||||
${params.timeFilter ? 'WHERE ts >= ? AND ts <= ?' : ''}
|
||||
ORDER BY ts ASC, id ASC
|
||||
`
|
||||
const sqlParams: unknown[] = []
|
||||
if (params.timeFilter) {
|
||||
sqlParams.push(params.timeFilter.startTs, params.timeFilter.endTs)
|
||||
}
|
||||
|
||||
const hitIndexes: number[] = []
|
||||
let msgIndex = 0
|
||||
const stmt = db.prepare(lightweightSql)
|
||||
|
||||
for (const row of stmt.iterate(...sqlParams) as Iterable<{
|
||||
id: number
|
||||
ts: number
|
||||
senderId: number
|
||||
content: string | null
|
||||
}>) {
|
||||
let isHit = true
|
||||
|
||||
if (params.keywords && params.keywords.length > 0) {
|
||||
const content = (row.content || '').toLowerCase()
|
||||
isHit = params.keywords.some((kw) => content.includes(kw.toLowerCase()))
|
||||
}
|
||||
|
||||
if (isHit && params.senderIds && params.senderIds.length > 0) {
|
||||
isHit = params.senderIds.includes(row.senderId)
|
||||
}
|
||||
|
||||
if (isHit) {
|
||||
hitIndexes.push(msgIndex)
|
||||
}
|
||||
msgIndex++
|
||||
}
|
||||
|
||||
totalHits = hitIndexes.length
|
||||
|
||||
// 发送准备阶段进度
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'preparing',
|
||||
currentBlock: 0,
|
||||
totalBlocks: 0,
|
||||
percentage: 10,
|
||||
message: `正在分析数据,找到 ${totalHits} 条匹配消息...`,
|
||||
})
|
||||
}
|
||||
|
||||
if (hitIndexes.length === 0) {
|
||||
writeStream.write(`## 统计信息\n\n`)
|
||||
writeStream.write(`- 无匹配结果\n`)
|
||||
writeStream.end()
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'done',
|
||||
currentBlock: 0,
|
||||
totalBlocks: 0,
|
||||
percentage: 100,
|
||||
message: '导出完成(无匹配结果)',
|
||||
})
|
||||
}
|
||||
return { success: true, filePath }
|
||||
}
|
||||
|
||||
// 计算上下文范围并合并
|
||||
const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = []
|
||||
const totalMsgCount = msgIndex
|
||||
|
||||
for (const hitIdx of hitIndexes) {
|
||||
const start = Math.max(0, hitIdx - contextSize)
|
||||
const end = Math.min(totalMsgCount - 1, hitIdx + contextSize)
|
||||
|
||||
if (ranges.length > 0) {
|
||||
const lastRange = ranges[ranges.length - 1]
|
||||
if (start <= lastRange.end + 1) {
|
||||
lastRange.end = Math.max(lastRange.end, end)
|
||||
lastRange.hitIndexes.push(hitIdx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
ranges.push({ start, end, hitIndexes: [hitIdx] })
|
||||
}
|
||||
|
||||
const totalBlocks = ranges.length
|
||||
|
||||
// 发送开始导出进度
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'exporting',
|
||||
currentBlock: 0,
|
||||
totalBlocks,
|
||||
percentage: 15,
|
||||
message: `开始导出 ${totalBlocks} 个对话块...`,
|
||||
})
|
||||
}
|
||||
|
||||
// 写入统计信息
|
||||
writeStream.write(`## 统计信息\n\n`)
|
||||
writeStream.write(`- 对话块数: ${totalBlocks}\n`)
|
||||
writeStream.write(`- 命中消息: ${totalHits}\n\n`)
|
||||
|
||||
// 第二阶段:流式写入每个块的内容
|
||||
writeStream.write(`## 对话内容\n\n`)
|
||||
|
||||
for (const range of ranges) {
|
||||
blockIndex++
|
||||
|
||||
// 发送导出进度(每个块)
|
||||
if (requestId) {
|
||||
const percentage = Math.round(15 + ((blockIndex - 1) / totalBlocks) * 80)
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'exporting',
|
||||
currentBlock: blockIndex,
|
||||
totalBlocks,
|
||||
percentage,
|
||||
message: `正在导出对话块 ${blockIndex}/${totalBlocks}...`,
|
||||
})
|
||||
}
|
||||
|
||||
const blockSql = `
|
||||
SELECT
|
||||
msg.id,
|
||||
msg.ts,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
|
||||
msg.content
|
||||
FROM message msg
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
${params.timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''}
|
||||
ORDER BY msg.ts ASC, msg.id ASC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
const blockParams: unknown[] = []
|
||||
if (params.timeFilter) {
|
||||
blockParams.push(params.timeFilter.startTs, params.timeFilter.endTs)
|
||||
}
|
||||
blockParams.push(range.end - range.start + 1, range.start)
|
||||
|
||||
const messages = db.prepare(blockSql).all(...blockParams) as Array<{
|
||||
id: number
|
||||
ts: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
}>
|
||||
|
||||
if (messages.length === 0) continue
|
||||
|
||||
const hitIndexSet = new Set(range.hitIndexes.map((idx) => idx - range.start))
|
||||
|
||||
const startTime = new Date(messages[0].ts * 1000).toLocaleString()
|
||||
const endTime = new Date(messages[messages.length - 1].ts * 1000).toLocaleString()
|
||||
writeStream.write(`### 对话块 ${blockIndex} (${startTime} ~ ${endTime})\n\n`)
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i]
|
||||
const time = new Date(msg.ts * 1000).toLocaleTimeString()
|
||||
const hitMark = hitIndexSet.has(i) ? ' ⭐' : ''
|
||||
const content = msg.content || '[非文本消息]'
|
||||
writeStream.write(`${time} ${msg.senderName}${hitMark}: ${content}\n`)
|
||||
totalMessages++
|
||||
totalChars += (msg.content || '').length
|
||||
}
|
||||
writeStream.write('\n')
|
||||
}
|
||||
} else {
|
||||
// 会话筛选模式
|
||||
if (!params.chatSessionIds || params.chatSessionIds.length === 0) {
|
||||
writeStream.write(`## 统计信息\n\n`)
|
||||
writeStream.write(`- 未选择会话\n`)
|
||||
writeStream.end()
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'done',
|
||||
currentBlock: 0,
|
||||
totalBlocks: 0,
|
||||
percentage: 100,
|
||||
message: '导出完成(未选择会话)',
|
||||
})
|
||||
}
|
||||
return { success: true, filePath }
|
||||
}
|
||||
|
||||
// 发送准备阶段进度
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'preparing',
|
||||
currentBlock: 0,
|
||||
totalBlocks: params.chatSessionIds.length,
|
||||
percentage: 10,
|
||||
message: `正在准备导出 ${params.chatSessionIds.length} 个会话...`,
|
||||
})
|
||||
}
|
||||
|
||||
// 获取会话信息
|
||||
const sessionsSql = `
|
||||
SELECT id, start_ts as startTs, end_ts as endTs
|
||||
FROM chat_session
|
||||
WHERE id IN (${params.chatSessionIds.map(() => '?').join(',')})
|
||||
ORDER BY start_ts ASC
|
||||
`
|
||||
const sessions = db.prepare(sessionsSql).all(...params.chatSessionIds) as Array<{
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
}>
|
||||
|
||||
const totalBlocks = sessions.length
|
||||
|
||||
writeStream.write(`## 统计信息\n\n`)
|
||||
writeStream.write(`- 对话块数: ${totalBlocks}\n\n`)
|
||||
|
||||
writeStream.write(`## 对话内容\n\n`)
|
||||
|
||||
const messagesSql = `
|
||||
SELECT
|
||||
msg.id,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
|
||||
msg.content,
|
||||
msg.ts as timestamp
|
||||
FROM message_context mc
|
||||
JOIN message msg ON msg.id = mc.message_id
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
WHERE mc.session_id = ?
|
||||
ORDER BY msg.ts ASC
|
||||
`
|
||||
|
||||
for (const session of sessions) {
|
||||
blockIndex++
|
||||
|
||||
// 发送导出进度(每个会话)
|
||||
if (requestId) {
|
||||
const percentage = Math.round(15 + ((blockIndex - 1) / totalBlocks) * 80)
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'exporting',
|
||||
currentBlock: blockIndex,
|
||||
totalBlocks,
|
||||
percentage,
|
||||
message: `正在导出会话 ${blockIndex}/${totalBlocks}...`,
|
||||
})
|
||||
}
|
||||
|
||||
const messages = db.prepare(messagesSql).all(session.id) as Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
|
||||
if (messages.length === 0) continue
|
||||
|
||||
const startTime = new Date(session.startTs * 1000).toLocaleString()
|
||||
const endTime = new Date(session.endTs * 1000).toLocaleString()
|
||||
writeStream.write(`### 对话块 ${blockIndex} (${startTime} ~ ${endTime})\n\n`)
|
||||
|
||||
for (const msg of messages) {
|
||||
const time = new Date(msg.timestamp * 1000).toLocaleTimeString()
|
||||
const content = msg.content || '[非文本消息]'
|
||||
writeStream.write(`${time} ${msg.senderName}: ${content}\n`)
|
||||
totalMessages++
|
||||
totalChars += (msg.content || '').length
|
||||
}
|
||||
writeStream.write('\n')
|
||||
}
|
||||
}
|
||||
|
||||
writeStream.end()
|
||||
|
||||
// 发送完成进度
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'done',
|
||||
currentBlock: blockIndex,
|
||||
totalBlocks: blockIndex,
|
||||
percentage: 100,
|
||||
message: `导出完成,共 ${blockIndex} 个对话块`,
|
||||
})
|
||||
}
|
||||
|
||||
return { success: true, filePath }
|
||||
} catch (error) {
|
||||
console.error('exportFilterResultToFile error:', error)
|
||||
// 发送错误进度
|
||||
if (requestId) {
|
||||
sendExportProgress(requestId, {
|
||||
stage: 'error',
|
||||
currentBlock: 0,
|
||||
totalBlocks: 0,
|
||||
percentage: 0,
|
||||
message: `导出失败: ${String(error)}`,
|
||||
})
|
||||
}
|
||||
return { success: false, error: String(error) }
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
447
electron/main/worker/query/session/filter.ts
Normal file
447
electron/main/worker/query/session/filter.ts
Normal file
@@ -0,0 +1,447 @@
|
||||
/**
|
||||
* 自定义筛选模块
|
||||
* 提供按条件筛选消息和获取多会话消息等功能
|
||||
*/
|
||||
|
||||
import { openReadonlyDatabase } from './core'
|
||||
import type {
|
||||
FilterMessage,
|
||||
ContextBlock,
|
||||
FilterResultWithPagination,
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
* 按条件筛选消息并扩充上下文(支持分页)
|
||||
*
|
||||
* 两阶段查询架构:
|
||||
* 1. 第一阶段:轻量级查询获取消息 ID、序号和匹配信息(不加载完整内容)
|
||||
* 2. 第二阶段:计算上下文范围、合并、分页后只获取当前页的完整消息
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param keywords 关键词列表(可选,OR 逻辑)
|
||||
* @param timeFilter 时间过滤器(可选)
|
||||
* @param senderIds 发送者ID列表(可选)
|
||||
* @param contextSize 上下文扩展数量(前后各多少条)
|
||||
* @param page 页码(从 1 开始,默认 1)
|
||||
* @param pageSize 每页块数(默认 50)
|
||||
* @returns 筛选结果(带分页信息)
|
||||
*/
|
||||
export function filterMessagesWithContext(
|
||||
sessionId: string,
|
||||
keywords?: string[],
|
||||
timeFilter?: { startTs: number; endTs: number },
|
||||
senderIds?: number[],
|
||||
contextSize: number = 10,
|
||||
page: number = 1,
|
||||
pageSize: number = 50
|
||||
): FilterResultWithPagination {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// ==================== 第一阶段:轻量级查询 ====================
|
||||
// 只获取消息的 ID、时间戳、发送者ID、内容(用于匹配)
|
||||
// 使用 ROW_NUMBER() 计算全局序号,避免一次性加载所有完整数据
|
||||
|
||||
const lightweightSql = `
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
sender_id as senderId,
|
||||
content
|
||||
FROM message
|
||||
${timeFilter ? 'WHERE ts >= ? AND ts <= ?' : ''}
|
||||
ORDER BY ts ASC, id ASC
|
||||
`
|
||||
|
||||
const params: unknown[] = []
|
||||
if (timeFilter) {
|
||||
params.push(timeFilter.startTs, timeFilter.endTs)
|
||||
}
|
||||
|
||||
// 使用 iterate() 流式处理,避免一次性加载所有数据到内存
|
||||
const stmt = db.prepare(lightweightSql)
|
||||
const hitIndexes: number[] = []
|
||||
let totalMessageCount = 0
|
||||
let estimatedTotalChars = 0
|
||||
|
||||
// 流式遍历消息,标记命中的索引
|
||||
for (const row of stmt.iterate(...params) as Iterable<{
|
||||
id: number
|
||||
ts: number
|
||||
senderId: number
|
||||
content: string | null
|
||||
}>) {
|
||||
let isHit = true
|
||||
|
||||
// 关键词匹配(OR 逻辑)
|
||||
if (keywords && keywords.length > 0) {
|
||||
const content = (row.content || '').toLowerCase()
|
||||
isHit = keywords.some((kw) => content.includes(kw.toLowerCase()))
|
||||
}
|
||||
|
||||
// 发送者匹配
|
||||
if (isHit && senderIds && senderIds.length > 0) {
|
||||
isHit = senderIds.includes(row.senderId)
|
||||
}
|
||||
|
||||
if (isHit) {
|
||||
hitIndexes.push(totalMessageCount)
|
||||
}
|
||||
|
||||
totalMessageCount++
|
||||
}
|
||||
|
||||
if (hitIndexes.length === 0) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 计算上下文范围并合并 ====================
|
||||
const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = []
|
||||
|
||||
for (const hitIndex of hitIndexes) {
|
||||
const start = Math.max(0, hitIndex - contextSize)
|
||||
const end = Math.min(totalMessageCount - 1, hitIndex + contextSize)
|
||||
|
||||
// 检查是否能与前一个范围合并
|
||||
if (ranges.length > 0) {
|
||||
const lastRange = ranges[ranges.length - 1]
|
||||
if (start <= lastRange.end + 1) {
|
||||
lastRange.end = Math.max(lastRange.end, end)
|
||||
lastRange.hitIndexes.push(hitIndex)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ranges.push({ start, end, hitIndexes: [hitIndex] })
|
||||
}
|
||||
|
||||
const totalBlocks = ranges.length
|
||||
const totalHits = hitIndexes.length
|
||||
|
||||
// ==================== 分页处理 ====================
|
||||
const startIndex = (page - 1) * pageSize
|
||||
const endIndex = Math.min(startIndex + pageSize, totalBlocks)
|
||||
const pageRanges = ranges.slice(startIndex, endIndex)
|
||||
const hasMore = endIndex < totalBlocks
|
||||
|
||||
if (pageRanges.length === 0) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: totalHits, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks, totalHits, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 第二阶段:获取当前页的完整消息 ====================
|
||||
// 只为当前页的范围获取完整消息数据
|
||||
|
||||
const blocks: ContextBlock[] = []
|
||||
let totalMessages = 0
|
||||
let totalChars = 0
|
||||
|
||||
for (const range of pageRanges) {
|
||||
// 使用 LIMIT OFFSET 获取指定范围的消息
|
||||
const blockSql = `
|
||||
SELECT
|
||||
msg.id,
|
||||
msg.ts,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
|
||||
m.platform_id as senderPlatformId,
|
||||
COALESCE(m.aliases, '[]') as senderAliasesJson,
|
||||
m.avatar as senderAvatar,
|
||||
msg.content,
|
||||
msg.type,
|
||||
msg.reply_to_message_id as replyToMessageId,
|
||||
reply_msg.content as replyToContent,
|
||||
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName
|
||||
FROM message msg
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
|
||||
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
|
||||
${timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''}
|
||||
ORDER BY msg.ts ASC, msg.id ASC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
|
||||
const blockParams: unknown[] = []
|
||||
if (timeFilter) {
|
||||
blockParams.push(timeFilter.startTs, timeFilter.endTs)
|
||||
}
|
||||
blockParams.push(range.end - range.start + 1, range.start)
|
||||
|
||||
const messages = db.prepare(blockSql).all(...blockParams) as Array<{
|
||||
id: number
|
||||
ts: number
|
||||
senderName: string
|
||||
senderPlatformId: string
|
||||
senderAliasesJson: string
|
||||
senderAvatar: string | null
|
||||
content: string | null
|
||||
type: number
|
||||
replyToMessageId: string | null
|
||||
replyToContent: string | null
|
||||
replyToSenderName: string | null
|
||||
}>
|
||||
|
||||
// 构建 hitIndexSet(相对于 range.start 的偏移)
|
||||
const hitIndexSet = new Set(range.hitIndexes.map((idx) => idx - range.start))
|
||||
|
||||
const blockMessages: FilterMessage[] = []
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const msg = messages[i]
|
||||
const isHit = hitIndexSet.has(i)
|
||||
|
||||
// 解析别名 JSON
|
||||
let senderAliases: string[] = []
|
||||
try {
|
||||
senderAliases = JSON.parse(msg.senderAliasesJson || '[]')
|
||||
} catch {
|
||||
senderAliases = []
|
||||
}
|
||||
|
||||
blockMessages.push({
|
||||
id: msg.id,
|
||||
senderName: msg.senderName,
|
||||
senderPlatformId: msg.senderPlatformId,
|
||||
senderAliases,
|
||||
senderAvatar: msg.senderAvatar,
|
||||
content: msg.content || '',
|
||||
timestamp: msg.ts,
|
||||
type: msg.type,
|
||||
replyToMessageId: msg.replyToMessageId,
|
||||
replyToContent: msg.replyToContent,
|
||||
replyToSenderName: msg.replyToSenderName,
|
||||
isHit,
|
||||
})
|
||||
totalChars += (msg.content || '').length
|
||||
}
|
||||
|
||||
if (blockMessages.length > 0) {
|
||||
blocks.push({
|
||||
startTs: blockMessages[0].timestamp,
|
||||
endTs: blockMessages[blockMessages.length - 1].timestamp,
|
||||
messages: blockMessages,
|
||||
hitCount: range.hitIndexes.length,
|
||||
})
|
||||
totalMessages += blockMessages.length
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是第一页,需要估算总字符数(用于统计显示)
|
||||
// 由于我们不再一次性加载所有数据,这里使用采样估算
|
||||
if (page === 1 && totalBlocks > pageSize) {
|
||||
// 估算:当前页的平均字符数 × 总块数
|
||||
const avgCharsPerBlock = totalChars / blocks.length
|
||||
estimatedTotalChars = Math.round(avgCharsPerBlock * totalBlocks)
|
||||
} else if (page === 1) {
|
||||
estimatedTotalChars = totalChars
|
||||
}
|
||||
|
||||
return {
|
||||
blocks,
|
||||
stats: {
|
||||
totalMessages: page === 1 ? totalMessages : 0, // 只有第一页返回准确的消息数
|
||||
hitMessages: totalHits,
|
||||
totalChars: page === 1 ? (totalBlocks > pageSize ? estimatedTotalChars : totalChars) : 0,
|
||||
},
|
||||
pagination: {
|
||||
page,
|
||||
pageSize,
|
||||
totalBlocks,
|
||||
totalHits,
|
||||
hasMore,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('filterMessagesWithContext error:', error)
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取多个会话的完整消息(用于会话筛选模式,支持分页)
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param chatSessionIds 要获取的会话ID列表
|
||||
* @param page 页码(从 1 开始,默认 1)
|
||||
* @param pageSize 每页块数(默认 50)
|
||||
* @returns 合并后的上下文块和统计(带分页信息)
|
||||
*/
|
||||
export function getMultipleSessionsMessages(
|
||||
sessionId: string,
|
||||
chatSessionIds: number[],
|
||||
page: number = 1,
|
||||
pageSize: number = 50
|
||||
): FilterResultWithPagination {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
if (chatSessionIds.length === 0) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
// 先获取会话信息,按时间排序
|
||||
const sessionsSql = `
|
||||
SELECT id, start_ts as startTs, end_ts as endTs, message_count as messageCount
|
||||
FROM chat_session
|
||||
WHERE id IN (${chatSessionIds.map(() => '?').join(',')})
|
||||
ORDER BY start_ts ASC
|
||||
`
|
||||
const allSessions = db.prepare(sessionsSql).all(...chatSessionIds) as Array<{
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
}>
|
||||
|
||||
const totalBlocks = allSessions.length
|
||||
|
||||
// 分页处理
|
||||
const startIndex = (page - 1) * pageSize
|
||||
const endIndex = Math.min(startIndex + pageSize, totalBlocks)
|
||||
const pageSessions = allSessions.slice(startIndex, endIndex)
|
||||
const hasMore = endIndex < totalBlocks
|
||||
|
||||
if (pageSessions.length === 0) {
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks, totalHits: 0, hasMore: false },
|
||||
}
|
||||
}
|
||||
|
||||
const blocks: ContextBlock[] = []
|
||||
let totalMessages = 0
|
||||
let totalChars = 0
|
||||
|
||||
// 为当前页的会话获取消息(完整信息)
|
||||
const messagesSql = `
|
||||
SELECT
|
||||
msg.id,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
|
||||
m.platform_id as senderPlatformId,
|
||||
COALESCE(m.aliases, '[]') as senderAliasesJson,
|
||||
m.avatar as senderAvatar,
|
||||
msg.content,
|
||||
msg.type,
|
||||
msg.reply_to_message_id as replyToMessageId,
|
||||
reply_msg.content as replyToContent,
|
||||
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName,
|
||||
msg.ts as timestamp
|
||||
FROM message_context mc
|
||||
JOIN message msg ON msg.id = mc.message_id
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
|
||||
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
|
||||
WHERE mc.session_id = ?
|
||||
ORDER BY msg.ts ASC
|
||||
`
|
||||
|
||||
for (const session of pageSessions) {
|
||||
const messages = db.prepare(messagesSql).all(session.id) as Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
senderPlatformId: string
|
||||
senderAliasesJson: string
|
||||
senderAvatar: string | null
|
||||
content: string | null
|
||||
type: number
|
||||
replyToMessageId: string | null
|
||||
replyToContent: string | null
|
||||
replyToSenderName: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
|
||||
const blockMessages: FilterMessage[] = messages.map((msg) => {
|
||||
// 解析别名 JSON
|
||||
let senderAliases: string[] = []
|
||||
try {
|
||||
senderAliases = JSON.parse(msg.senderAliasesJson || '[]')
|
||||
} catch {
|
||||
senderAliases = []
|
||||
}
|
||||
|
||||
return {
|
||||
id: msg.id,
|
||||
senderName: msg.senderName,
|
||||
senderPlatformId: msg.senderPlatformId,
|
||||
senderAliases,
|
||||
senderAvatar: msg.senderAvatar,
|
||||
content: msg.content || '',
|
||||
timestamp: msg.timestamp,
|
||||
type: msg.type,
|
||||
replyToMessageId: msg.replyToMessageId,
|
||||
replyToContent: msg.replyToContent,
|
||||
replyToSenderName: msg.replyToSenderName,
|
||||
isHit: false, // 会话模式下没有命中高亮
|
||||
}
|
||||
})
|
||||
|
||||
for (const msg of messages) {
|
||||
totalChars += (msg.content || '').length
|
||||
}
|
||||
|
||||
blocks.push({
|
||||
startTs: session.startTs,
|
||||
endTs: session.endTs,
|
||||
messages: blockMessages,
|
||||
hitCount: 0,
|
||||
})
|
||||
|
||||
totalMessages += messages.length
|
||||
}
|
||||
|
||||
return {
|
||||
blocks,
|
||||
stats: {
|
||||
totalMessages: page === 1 ? totalMessages : 0, // 只有第一页返回准确的消息数
|
||||
hitMessages: 0, // 会话模式没有命中概念
|
||||
totalChars: page === 1 ? totalChars : 0,
|
||||
},
|
||||
pagination: {
|
||||
page,
|
||||
pageSize,
|
||||
totalBlocks,
|
||||
totalHits: 0,
|
||||
hasMore,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('getMultipleSessionsMessages error:', error)
|
||||
return {
|
||||
blocks: [],
|
||||
stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 },
|
||||
pagination: { page, pageSize, totalBlocks: 0, totalHits: 0, hasMore: false },
|
||||
}
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
41
electron/main/worker/query/session/index.ts
Normal file
41
electron/main/worker/query/session/index.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
/**
|
||||
* 会话模块统一导出
|
||||
* 提供会话索引管理、AI 工具查询、自定义筛选和导出等功能
|
||||
*/
|
||||
|
||||
// 类型定义
|
||||
export type {
|
||||
ChatSessionItem,
|
||||
SessionSearchResultItem,
|
||||
SessionMessagesResult,
|
||||
FilterMessage,
|
||||
ContextBlock,
|
||||
FilterResult,
|
||||
PaginationInfo,
|
||||
FilterResultWithPagination,
|
||||
ExportFilterParams,
|
||||
ExportProgress,
|
||||
} from './types'
|
||||
|
||||
export { DEFAULT_SESSION_GAP_THRESHOLD } from './types'
|
||||
|
||||
// 会话索引管理
|
||||
export {
|
||||
generateSessions,
|
||||
clearSessions,
|
||||
hasSessionIndex,
|
||||
getSessionStats,
|
||||
updateSessionGapThreshold,
|
||||
getSessions,
|
||||
saveSessionSummary,
|
||||
getSessionSummary,
|
||||
} from './sessionIndex'
|
||||
|
||||
// AI 工具专用查询
|
||||
export { searchSessions, getSessionMessages } from './aiTools'
|
||||
|
||||
// 自定义筛选
|
||||
export { filterMessagesWithContext, getMultipleSessionsMessages } from './filter'
|
||||
|
||||
// 导出功能
|
||||
export { exportFilterResultToFile } from './export'
|
||||
349
electron/main/worker/query/session/sessionIndex.ts
Normal file
349
electron/main/worker/query/session/sessionIndex.ts
Normal file
@@ -0,0 +1,349 @@
|
||||
/**
|
||||
* 会话索引管理模块
|
||||
* 提供会话生成、查询、管理等功能
|
||||
*/
|
||||
|
||||
import type Database from 'better-sqlite3'
|
||||
import { openWritableDatabase, openReadonlyDatabase, closeDatabase } from './core'
|
||||
import { DEFAULT_SESSION_GAP_THRESHOLD, type ChatSessionItem } from './types'
|
||||
|
||||
/**
|
||||
* 内部清空会话数据函数
|
||||
*/
|
||||
function clearSessionsInternal(db: Database.Database): void {
|
||||
db.exec('DELETE FROM message_context')
|
||||
db.exec('DELETE FROM chat_session')
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成会话索引
|
||||
* 使用 Gap-based 算法,根据消息时间间隔自动切分会话
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param gapThreshold 时间间隔阈值(秒),默认 1800(30分钟)
|
||||
* @param onProgress 进度回调
|
||||
* @returns 生成的会话数量
|
||||
*/
|
||||
export function generateSessions(
|
||||
sessionId: string,
|
||||
gapThreshold: number = DEFAULT_SESSION_GAP_THRESHOLD,
|
||||
onProgress?: (current: number, total: number) => void
|
||||
): number {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取消息总数
|
||||
const countResult = db.prepare('SELECT COUNT(*) as count FROM message').get() as { count: number }
|
||||
const totalMessages = countResult.count
|
||||
|
||||
if (totalMessages === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 清空已有的会话数据
|
||||
clearSessionsInternal(db)
|
||||
|
||||
// 使用窗口函数计算会话边界
|
||||
// 步骤1:为每条消息计算与前一条的时间差,标记新会话起点
|
||||
const sessionMarkSQL = `
|
||||
WITH message_ordered AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
LAG(ts) OVER (ORDER BY ts, id) AS prev_ts
|
||||
FROM message
|
||||
),
|
||||
session_marks AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
CASE
|
||||
WHEN prev_ts IS NULL OR (ts - prev_ts) > ? THEN 1
|
||||
ELSE 0
|
||||
END AS is_new_session
|
||||
FROM message_ordered
|
||||
),
|
||||
session_ids AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
SUM(is_new_session) OVER (ORDER BY ts, id) AS session_num
|
||||
FROM session_marks
|
||||
)
|
||||
SELECT id, ts, session_num FROM session_ids
|
||||
`
|
||||
|
||||
const messages = db.prepare(sessionMarkSQL).all(gapThreshold) as Array<{
|
||||
id: number
|
||||
ts: number
|
||||
session_num: number
|
||||
}>
|
||||
|
||||
if (messages.length === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 步骤2:计算每个会话的统计信息
|
||||
const sessionMap = new Map<number, { startTs: number; endTs: number; messageIds: number[] }>()
|
||||
|
||||
for (const msg of messages) {
|
||||
const session = sessionMap.get(msg.session_num)
|
||||
if (!session) {
|
||||
sessionMap.set(msg.session_num, {
|
||||
startTs: msg.ts,
|
||||
endTs: msg.ts,
|
||||
messageIds: [msg.id],
|
||||
})
|
||||
} else {
|
||||
session.endTs = msg.ts
|
||||
session.messageIds.push(msg.id)
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤3:批量写入 chat_session 和 message_context 表
|
||||
const insertSession = db.prepare(`
|
||||
INSERT INTO chat_session (start_ts, end_ts, message_count, is_manual, summary)
|
||||
VALUES (?, ?, ?, 0, NULL)
|
||||
`)
|
||||
|
||||
const insertContext = db.prepare(`
|
||||
INSERT INTO message_context (message_id, session_id, topic_id)
|
||||
VALUES (?, ?, NULL)
|
||||
`)
|
||||
|
||||
// 开始事务
|
||||
const transaction = db.transaction(() => {
|
||||
let processedCount = 0
|
||||
const totalSessions = sessionMap.size
|
||||
|
||||
for (const [, sessionData] of sessionMap) {
|
||||
// 插入会话记录
|
||||
const result = insertSession.run(sessionData.startTs, sessionData.endTs, sessionData.messageIds.length)
|
||||
const newSessionId = result.lastInsertRowid as number
|
||||
|
||||
// 批量插入消息上下文
|
||||
for (const messageId of sessionData.messageIds) {
|
||||
insertContext.run(messageId, newSessionId)
|
||||
}
|
||||
|
||||
processedCount++
|
||||
if (onProgress && processedCount % 100 === 0) {
|
||||
onProgress(processedCount, totalSessions)
|
||||
}
|
||||
}
|
||||
|
||||
return totalSessions
|
||||
})
|
||||
|
||||
const sessionCount = transaction()
|
||||
|
||||
// 最终进度回调
|
||||
if (onProgress) {
|
||||
onProgress(sessionCount, sessionCount)
|
||||
}
|
||||
|
||||
return sessionCount
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空会话索引数据
|
||||
* @param sessionId 数据库会话ID
|
||||
*/
|
||||
export function clearSessions(sessionId: string): void {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
clearSessionsInternal(db)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否已生成会话索引
|
||||
* @param sessionId 数据库会话ID
|
||||
* @returns 是否有会话索引
|
||||
*/
|
||||
export function hasSessionIndex(sessionId: string): boolean {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
// 检查 chat_session 表是否存在且有数据
|
||||
const result = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number }
|
||||
return result.count > 0
|
||||
} catch {
|
||||
// 表可能不存在
|
||||
return false
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话索引统计信息
|
||||
* @param sessionId 数据库会话ID
|
||||
*/
|
||||
export function getSessionStats(sessionId: string): {
|
||||
sessionCount: number
|
||||
hasIndex: boolean
|
||||
gapThreshold: number
|
||||
} {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return { sessionCount: 0, hasIndex: false, gapThreshold: DEFAULT_SESSION_GAP_THRESHOLD }
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取会话数量
|
||||
let sessionCount = 0
|
||||
try {
|
||||
const countResult = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number }
|
||||
sessionCount = countResult.count
|
||||
} catch {
|
||||
// 表可能不存在
|
||||
}
|
||||
|
||||
// 获取配置的阈值
|
||||
let gapThreshold = DEFAULT_SESSION_GAP_THRESHOLD
|
||||
try {
|
||||
const metaResult = db.prepare('SELECT session_gap_threshold FROM meta LIMIT 1').get() as
|
||||
| {
|
||||
session_gap_threshold: number | null
|
||||
}
|
||||
| undefined
|
||||
if (metaResult?.session_gap_threshold) {
|
||||
gapThreshold = metaResult.session_gap_threshold
|
||||
}
|
||||
} catch {
|
||||
// 字段可能不存在
|
||||
}
|
||||
|
||||
return {
|
||||
sessionCount,
|
||||
hasIndex: sessionCount > 0,
|
||||
gapThreshold,
|
||||
}
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新单个聊天的会话切分阈值
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param gapThreshold 阈值(秒),null 表示使用全局配置
|
||||
*/
|
||||
export function updateSessionGapThreshold(sessionId: string, gapThreshold: number | null): void {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
db.prepare('UPDATE meta SET session_gap_threshold = ?').run(gapThreshold)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话列表(用于时间线导航)
|
||||
* @param sessionId 数据库会话ID
|
||||
* @returns 会话列表,按时间排序
|
||||
*/
|
||||
export function getSessions(sessionId: string): ChatSessionItem[] {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return []
|
||||
}
|
||||
|
||||
try {
|
||||
// 查询会话列表,同时获取每个会话的首条消息 ID 和摘要
|
||||
const sql = `
|
||||
SELECT
|
||||
cs.id,
|
||||
cs.start_ts as startTs,
|
||||
cs.end_ts as endTs,
|
||||
cs.message_count as messageCount,
|
||||
cs.summary,
|
||||
(SELECT mc.message_id FROM message_context mc WHERE mc.session_id = cs.id ORDER BY mc.message_id LIMIT 1) as firstMessageId
|
||||
FROM chat_session cs
|
||||
ORDER BY cs.start_ts ASC
|
||||
`
|
||||
const sessions = db.prepare(sql).all() as ChatSessionItem[]
|
||||
return sessions
|
||||
} catch {
|
||||
return []
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 会话摘要相关函数 ====================
|
||||
|
||||
/**
|
||||
* 保存会话摘要
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param chatSessionId 会话索引中的会话ID
|
||||
* @param summary 摘要内容
|
||||
*/
|
||||
export function saveSessionSummary(sessionId: string, chatSessionId: number, summary: string): void {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
db.prepare('UPDATE chat_session SET summary = ? WHERE id = ?').run(summary, chatSessionId)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话摘要
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param chatSessionId 会话索引中的会话ID
|
||||
* @returns 摘要内容
|
||||
*/
|
||||
export function getSessionSummary(sessionId: string, chatSessionId: number): string | null {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const result = db.prepare('SELECT summary FROM chat_session WHERE id = ?').get(chatSessionId) as
|
||||
| { summary: string | null }
|
||||
| undefined
|
||||
return result?.summary || null
|
||||
} catch {
|
||||
return null
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
163
electron/main/worker/query/session/types.ts
Normal file
163
electron/main/worker/query/session/types.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* 会话模块类型定义
|
||||
*/
|
||||
|
||||
/** 默认会话切分阈值:30分钟(秒) */
|
||||
export const DEFAULT_SESSION_GAP_THRESHOLD = 1800
|
||||
|
||||
/**
|
||||
* 会话列表项类型
|
||||
*/
|
||||
export interface ChatSessionItem {
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
firstMessageId: number
|
||||
/** 会话摘要(如果有) */
|
||||
summary?: string | null
|
||||
}
|
||||
|
||||
/**
|
||||
* 会话搜索结果项类型(用于 AI 工具)
|
||||
*/
|
||||
export interface SessionSearchResultItem {
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
/** 是否为完整会话(消息数 <= 预览条数) */
|
||||
isComplete: boolean
|
||||
/** 预览消息列表 */
|
||||
previewMessages: Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
}
|
||||
|
||||
/**
|
||||
* 会话消息结果类型(用于 AI 工具)
|
||||
*/
|
||||
export interface SessionMessagesResult {
|
||||
sessionId: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
returnedCount: number
|
||||
/** 参与者列表 */
|
||||
participants: string[]
|
||||
/** 消息列表 */
|
||||
messages: Array<{
|
||||
id: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
}>
|
||||
}
|
||||
|
||||
/**
|
||||
* 自定义筛选消息类型(完整信息,兼容 MessageList 组件)
|
||||
*/
|
||||
export interface FilterMessage {
|
||||
id: number
|
||||
senderName: string
|
||||
senderPlatformId: string
|
||||
senderAliases: string[]
|
||||
senderAvatar: string | null
|
||||
content: string
|
||||
timestamp: number
|
||||
type: number
|
||||
replyToMessageId: string | null
|
||||
replyToContent: string | null
|
||||
replyToSenderName: string | null
|
||||
/** 是否为命中的消息(关键词匹配) */
|
||||
isHit: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 上下文块类型(用于自定义筛选)
|
||||
*/
|
||||
export interface ContextBlock {
|
||||
/** 块的时间范围 */
|
||||
startTs: number
|
||||
endTs: number
|
||||
/** 消息列表 */
|
||||
messages: FilterMessage[]
|
||||
/** 命中的消息数量 */
|
||||
hitCount: number
|
||||
}
|
||||
|
||||
/**
|
||||
* 筛选结果类型
|
||||
*/
|
||||
export interface FilterResult {
|
||||
/** 上下文块列表 */
|
||||
blocks: ContextBlock[]
|
||||
/** 统计信息 */
|
||||
stats: {
|
||||
/** 总消息数 */
|
||||
totalMessages: number
|
||||
/** 命中的消息数 */
|
||||
hitMessages: number
|
||||
/** 总字符数 */
|
||||
totalChars: number
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 分页信息类型
|
||||
*/
|
||||
export interface PaginationInfo {
|
||||
/** 当前页码(从 1 开始) */
|
||||
page: number
|
||||
/** 每页块数 */
|
||||
pageSize: number
|
||||
/** 总块数 */
|
||||
totalBlocks: number
|
||||
/** 总命中数 */
|
||||
totalHits: number
|
||||
/** 是否还有更多 */
|
||||
hasMore: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 带分页的筛选结果类型
|
||||
*/
|
||||
export interface FilterResultWithPagination extends FilterResult {
|
||||
pagination: PaginationInfo
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出筛选结果参数
|
||||
*/
|
||||
export interface ExportFilterParams {
|
||||
sessionId: string
|
||||
sessionName: string
|
||||
outputDir: string
|
||||
filterMode: 'condition' | 'session'
|
||||
// 条件筛选参数
|
||||
keywords?: string[]
|
||||
timeFilter?: { startTs: number; endTs: number }
|
||||
senderIds?: number[]
|
||||
contextSize?: number
|
||||
// 会话筛选参数
|
||||
chatSessionIds?: number[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出进度类型
|
||||
*/
|
||||
export interface ExportProgress {
|
||||
/** 阶段 */
|
||||
stage: 'preparing' | 'exporting' | 'done' | 'error'
|
||||
/** 当前处理的块索引(从 1 开始) */
|
||||
currentBlock: number
|
||||
/** 总块数 */
|
||||
totalBlocks: number
|
||||
/** 百分比(0-100) */
|
||||
percentage: number
|
||||
/** 状态消息 */
|
||||
message: string
|
||||
}
|
||||
Reference in New Issue
Block a user