mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-04-23 01:39:37 +08:00
feat: 导入聊天记录支持生成会话索引
This commit is contained in:
@@ -53,3 +53,15 @@ export type { MessageResult, PaginatedMessages, MessagesWithTotal } from './mess
|
||||
// SQL 实验室
|
||||
export { executeRawSQL, getSchema } from './sql'
|
||||
export type { SQLResult, TableSchema } from './sql'
|
||||
|
||||
// 会话索引
|
||||
export {
|
||||
generateSessions,
|
||||
clearSessions,
|
||||
hasSessionIndex,
|
||||
getSessionStats,
|
||||
updateSessionGapThreshold,
|
||||
getSessions,
|
||||
DEFAULT_SESSION_GAP_THRESHOLD,
|
||||
} from './session'
|
||||
export type { ChatSessionItem } from './session'
|
||||
|
||||
342
electron/main/worker/query/session.ts
Normal file
342
electron/main/worker/query/session.ts
Normal file
@@ -0,0 +1,342 @@
|
||||
/**
|
||||
* 会话索引模块
|
||||
* 提供基于时间间隔的会话切分算法
|
||||
*/
|
||||
|
||||
import Database from 'better-sqlite3'
|
||||
import { getDbPath, closeDatabase } from '../core'
|
||||
|
||||
/** 默认会话切分阈值:30分钟(秒) */
|
||||
export const DEFAULT_SESSION_GAP_THRESHOLD = 1800
|
||||
|
||||
/**
|
||||
* 打开数据库(可写模式,不使用缓存)
|
||||
* 会话索引需要写入数据
|
||||
*/
|
||||
function openWritableDatabase(sessionId: string): Database.Database | null {
|
||||
const dbPath = getDbPath(sessionId)
|
||||
try {
|
||||
const db = new Database(dbPath)
|
||||
db.pragma('journal_mode = WAL')
|
||||
return db
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 打开数据库(只读模式,不使用缓存)
|
||||
*/
|
||||
function openReadonlyDatabase(sessionId: string): Database.Database | null {
|
||||
const dbPath = getDbPath(sessionId)
|
||||
try {
|
||||
const db = new Database(dbPath, { readonly: true })
|
||||
db.pragma('journal_mode = WAL')
|
||||
return db
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成会话索引
|
||||
* 使用 Gap-based 算法,根据消息时间间隔自动切分会话
|
||||
*
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param gapThreshold 时间间隔阈值(秒),默认 1800(30分钟)
|
||||
* @param onProgress 进度回调
|
||||
* @returns 生成的会话数量
|
||||
*/
|
||||
export function generateSessions(
|
||||
sessionId: string,
|
||||
gapThreshold: number = DEFAULT_SESSION_GAP_THRESHOLD,
|
||||
onProgress?: (current: number, total: number) => void
|
||||
): number {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取消息总数
|
||||
const countResult = db.prepare('SELECT COUNT(*) as count FROM message').get() as { count: number }
|
||||
const totalMessages = countResult.count
|
||||
|
||||
if (totalMessages === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 清空已有的会话数据
|
||||
clearSessionsInternal(db)
|
||||
|
||||
// 使用窗口函数计算会话边界
|
||||
// 步骤1:为每条消息计算与前一条的时间差,标记新会话起点
|
||||
const sessionMarkSQL = `
|
||||
WITH message_ordered AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
LAG(ts) OVER (ORDER BY ts, id) AS prev_ts
|
||||
FROM message
|
||||
),
|
||||
session_marks AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
CASE
|
||||
WHEN prev_ts IS NULL OR (ts - prev_ts) > ? THEN 1
|
||||
ELSE 0
|
||||
END AS is_new_session
|
||||
FROM message_ordered
|
||||
),
|
||||
session_ids AS (
|
||||
SELECT
|
||||
id,
|
||||
ts,
|
||||
SUM(is_new_session) OVER (ORDER BY ts, id) AS session_num
|
||||
FROM session_marks
|
||||
)
|
||||
SELECT id, ts, session_num FROM session_ids
|
||||
`
|
||||
|
||||
const messages = db.prepare(sessionMarkSQL).all(gapThreshold) as Array<{
|
||||
id: number
|
||||
ts: number
|
||||
session_num: number
|
||||
}>
|
||||
|
||||
if (messages.length === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 步骤2:计算每个会话的统计信息
|
||||
const sessionMap = new Map<number, { startTs: number; endTs: number; messageIds: number[] }>()
|
||||
|
||||
for (const msg of messages) {
|
||||
const session = sessionMap.get(msg.session_num)
|
||||
if (!session) {
|
||||
sessionMap.set(msg.session_num, {
|
||||
startTs: msg.ts,
|
||||
endTs: msg.ts,
|
||||
messageIds: [msg.id],
|
||||
})
|
||||
} else {
|
||||
session.endTs = msg.ts
|
||||
session.messageIds.push(msg.id)
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤3:批量写入 chat_session 和 message_context 表
|
||||
const insertSession = db.prepare(`
|
||||
INSERT INTO chat_session (start_ts, end_ts, message_count, is_manual, summary)
|
||||
VALUES (?, ?, ?, 0, NULL)
|
||||
`)
|
||||
|
||||
const insertContext = db.prepare(`
|
||||
INSERT INTO message_context (message_id, session_id, topic_id)
|
||||
VALUES (?, ?, NULL)
|
||||
`)
|
||||
|
||||
// 开始事务
|
||||
const transaction = db.transaction(() => {
|
||||
let processedCount = 0
|
||||
const totalSessions = sessionMap.size
|
||||
|
||||
for (const [, sessionData] of sessionMap) {
|
||||
// 插入会话记录
|
||||
const result = insertSession.run(sessionData.startTs, sessionData.endTs, sessionData.messageIds.length)
|
||||
const newSessionId = result.lastInsertRowid as number
|
||||
|
||||
// 批量插入消息上下文
|
||||
for (const messageId of sessionData.messageIds) {
|
||||
insertContext.run(messageId, newSessionId)
|
||||
}
|
||||
|
||||
processedCount++
|
||||
if (onProgress && processedCount % 100 === 0) {
|
||||
onProgress(processedCount, totalSessions)
|
||||
}
|
||||
}
|
||||
|
||||
return totalSessions
|
||||
})
|
||||
|
||||
const sessionCount = transaction()
|
||||
|
||||
// 最终进度回调
|
||||
if (onProgress) {
|
||||
onProgress(sessionCount, sessionCount)
|
||||
}
|
||||
|
||||
return sessionCount
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空会话索引数据
|
||||
* @param sessionId 数据库会话ID
|
||||
*/
|
||||
export function clearSessions(sessionId: string): void {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
clearSessionsInternal(db)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部清空会话数据函数
|
||||
*/
|
||||
function clearSessionsInternal(db: Database.Database): void {
|
||||
db.exec('DELETE FROM message_context')
|
||||
db.exec('DELETE FROM chat_session')
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否已生成会话索引
|
||||
* @param sessionId 数据库会话ID
|
||||
* @returns 是否有会话索引
|
||||
*/
|
||||
export function hasSessionIndex(sessionId: string): boolean {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
// 检查 chat_session 表是否存在且有数据
|
||||
const result = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number }
|
||||
return result.count > 0
|
||||
} catch {
|
||||
// 表可能不存在
|
||||
return false
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话索引统计信息
|
||||
* @param sessionId 数据库会话ID
|
||||
*/
|
||||
export function getSessionStats(sessionId: string): {
|
||||
sessionCount: number
|
||||
hasIndex: boolean
|
||||
gapThreshold: number
|
||||
} {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return { sessionCount: 0, hasIndex: false, gapThreshold: DEFAULT_SESSION_GAP_THRESHOLD }
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取会话数量
|
||||
let sessionCount = 0
|
||||
try {
|
||||
const countResult = db.prepare('SELECT COUNT(*) as count FROM chat_session').get() as { count: number }
|
||||
sessionCount = countResult.count
|
||||
} catch {
|
||||
// 表可能不存在
|
||||
}
|
||||
|
||||
// 获取配置的阈值
|
||||
let gapThreshold = DEFAULT_SESSION_GAP_THRESHOLD
|
||||
try {
|
||||
const metaResult = db.prepare('SELECT session_gap_threshold FROM meta LIMIT 1').get() as
|
||||
| {
|
||||
session_gap_threshold: number | null
|
||||
}
|
||||
| undefined
|
||||
if (metaResult?.session_gap_threshold) {
|
||||
gapThreshold = metaResult.session_gap_threshold
|
||||
}
|
||||
} catch {
|
||||
// 字段可能不存在
|
||||
}
|
||||
|
||||
return {
|
||||
sessionCount,
|
||||
hasIndex: sessionCount > 0,
|
||||
gapThreshold,
|
||||
}
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新单个聊天的会话切分阈值
|
||||
* @param sessionId 数据库会话ID
|
||||
* @param gapThreshold 阈值(秒),null 表示使用全局配置
|
||||
*/
|
||||
export function updateSessionGapThreshold(sessionId: string, gapThreshold: number | null): void {
|
||||
// 先关闭缓存的只读连接
|
||||
closeDatabase(sessionId)
|
||||
|
||||
const db = openWritableDatabase(sessionId)
|
||||
if (!db) {
|
||||
throw new Error(`无法打开数据库: ${sessionId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
db.prepare('UPDATE meta SET session_gap_threshold = ?').run(gapThreshold)
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 会话列表项类型
|
||||
*/
|
||||
export interface ChatSessionItem {
|
||||
id: number
|
||||
startTs: number
|
||||
endTs: number
|
||||
messageCount: number
|
||||
firstMessageId: number
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话列表(用于时间线导航)
|
||||
* @param sessionId 数据库会话ID
|
||||
* @returns 会话列表,按时间排序
|
||||
*/
|
||||
export function getSessions(sessionId: string): ChatSessionItem[] {
|
||||
const db = openReadonlyDatabase(sessionId)
|
||||
if (!db) {
|
||||
return []
|
||||
}
|
||||
|
||||
try {
|
||||
// 查询会话列表,同时获取每个会话的首条消息 ID
|
||||
const sql = `
|
||||
SELECT
|
||||
cs.id,
|
||||
cs.start_ts as startTs,
|
||||
cs.end_ts as endTs,
|
||||
cs.message_count as messageCount,
|
||||
(SELECT mc.message_id FROM message_context mc WHERE mc.session_id = cs.id ORDER BY mc.message_id LIMIT 1) as firstMessageId
|
||||
FROM chat_session cs
|
||||
ORDER BY cs.start_ts ASC
|
||||
`
|
||||
const sessions = db.prepare(sql).all() as ChatSessionItem[]
|
||||
return sessions
|
||||
} catch {
|
||||
return []
|
||||
} finally {
|
||||
db.close()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user