feat: 新增自定义筛选

This commit is contained in:
digua
2026-01-11 23:40:29 +08:00
committed by digua
parent 6a01177ca5
commit 9b48f465f8
20 changed files with 2719 additions and 17 deletions

View File

@@ -65,5 +65,15 @@ export {
searchSessions,
getSessionMessages,
DEFAULT_SESSION_GAP_THRESHOLD,
// 自定义筛选
filterMessagesWithContext,
getMultipleSessionsMessages,
} from './session'
export type {
ChatSessionItem,
SessionSearchResultItem,
SessionMessagesResult,
ContextBlock,
FilterResult,
FilterMessage,
} from './session'
export type { ChatSessionItem, SessionSearchResultItem, SessionMessagesResult } from './session'

View File

@@ -578,3 +578,372 @@ export function getSessionMessages(
db.close()
}
}
// ==================== 自定义筛选专用函数 ====================
/**
* 自定义筛选消息类型(完整信息,兼容 MessageList 组件)
*/
export interface FilterMessage {
id: number
senderName: string
senderPlatformId: string
senderAliases: string[]
senderAvatar: string | null
content: string
timestamp: number
type: number
replyToMessageId: string | null
replyToContent: string | null
replyToSenderName: string | null
/** 是否为命中的消息(关键词匹配) */
isHit: boolean
}
/**
* 上下文块类型(用于自定义筛选)
*/
export interface ContextBlock {
/** 块的时间范围 */
startTs: number
endTs: number
/** 消息列表 */
messages: FilterMessage[]
/** 命中的消息数量 */
hitCount: number
}
/**
* 筛选结果类型
*/
export interface FilterResult {
/** 上下文块列表 */
blocks: ContextBlock[]
/** 统计信息 */
stats: {
/** 总消息数 */
totalMessages: number
/** 命中的消息数 */
hitMessages: number
/** 总字符数 */
totalChars: number
}
}
/**
* 按条件筛选消息并扩充上下文
*
* 核心算法:
* 1. 先搜索匹配条件的消息获取消息ID列表
* 2. 为每个命中消息向前后各扩展 contextSize 条消息
* 3. 合并重叠/相邻的消息范围
* 4. 按合并后的范围分块返回消息
*
* @param sessionId 数据库会话ID
* @param keywords 关键词列表可选OR 逻辑)
* @param timeFilter 时间过滤器(可选)
* @param senderIds 发送者ID列表可选
* @param contextSize 上下文扩展数量(前后各多少条)
* @returns 筛选结果
*/
export function filterMessagesWithContext(
sessionId: string,
keywords?: string[],
timeFilter?: { startTs: number; endTs: number },
senderIds?: number[],
contextSize: number = 10
): FilterResult {
const db = openReadonlyDatabase(sessionId)
if (!db) {
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
}
try {
// 1. 构建基础消息查询(完整信息),按时间排序
// 使用 LEFT JOIN 获取回复消息的信息
const allMessagesSql = `
SELECT
msg.id,
msg.ts,
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
m.platform_id as senderPlatformId,
COALESCE(m.aliases, '[]') as senderAliasesJson,
m.avatar as senderAvatar,
msg.content,
msg.type,
msg.reply_to_message_id as replyToMessageId,
reply_msg.content as replyToContent,
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName,
msg.sender_id as senderId
FROM message msg
JOIN member m ON msg.sender_id = m.id
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
${timeFilter ? 'WHERE msg.ts >= ? AND msg.ts <= ?' : ''}
ORDER BY msg.ts ASC, msg.id ASC
`
const params: unknown[] = []
if (timeFilter) {
params.push(timeFilter.startTs, timeFilter.endTs)
}
const allMessages = db.prepare(allMessagesSql).all(...params) as Array<{
id: number
ts: number
senderName: string
senderPlatformId: string
senderAliasesJson: string
senderAvatar: string | null
content: string | null
type: number
replyToMessageId: string | null
replyToContent: string | null
replyToSenderName: string | null
senderId: number
}>
if (allMessages.length === 0) {
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
}
// 2. 标记命中的消息
const hitIndexes: number[] = []
for (let i = 0; i < allMessages.length; i++) {
const msg = allMessages[i]
let isHit = true
// 关键词匹配OR 逻辑)
if (keywords && keywords.length > 0) {
const content = (msg.content || '').toLowerCase()
isHit = keywords.some((kw) => content.includes(kw.toLowerCase()))
}
// 发送者匹配
if (isHit && senderIds && senderIds.length > 0) {
isHit = senderIds.includes(msg.senderId)
}
if (isHit) {
hitIndexes.push(i)
}
}
if (hitIndexes.length === 0) {
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
}
// 3. 扩展上下文并合并重叠范围
const ranges: Array<{ start: number; end: number; hitIndexes: number[] }> = []
for (const hitIndex of hitIndexes) {
const start = Math.max(0, hitIndex - contextSize)
const end = Math.min(allMessages.length - 1, hitIndex + contextSize)
// 检查是否能与前一个范围合并
if (ranges.length > 0) {
const lastRange = ranges[ranges.length - 1]
// 如果当前范围的 start <= 上一个范围的 end + 1则合并
if (start <= lastRange.end + 1) {
lastRange.end = Math.max(lastRange.end, end)
lastRange.hitIndexes.push(hitIndex)
continue
}
}
ranges.push({ start, end, hitIndexes: [hitIndex] })
}
// 4. 按范围构建上下文块
const blocks: ContextBlock[] = []
let totalMessages = 0
let totalChars = 0
for (const range of ranges) {
const hitIndexSet = new Set(range.hitIndexes)
const blockMessages: FilterMessage[] = []
for (let i = range.start; i <= range.end; i++) {
const msg = allMessages[i]
const isHit = hitIndexSet.has(i)
// 解析别名 JSON
let senderAliases: string[] = []
try {
senderAliases = JSON.parse(msg.senderAliasesJson || '[]')
} catch {
senderAliases = []
}
blockMessages.push({
id: msg.id,
senderName: msg.senderName,
senderPlatformId: msg.senderPlatformId,
senderAliases,
senderAvatar: msg.senderAvatar,
content: msg.content || '',
timestamp: msg.ts,
type: msg.type,
replyToMessageId: msg.replyToMessageId,
replyToContent: msg.replyToContent,
replyToSenderName: msg.replyToSenderName,
isHit,
})
totalChars += (msg.content || '').length
}
blocks.push({
startTs: allMessages[range.start].ts,
endTs: allMessages[range.end].ts,
messages: blockMessages,
hitCount: range.hitIndexes.length,
})
totalMessages += blockMessages.length
}
return {
blocks,
stats: {
totalMessages,
hitMessages: hitIndexes.length,
totalChars,
},
}
} catch (error) {
console.error('filterMessagesWithContext error:', error)
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
} finally {
db.close()
}
}
/**
* 获取多个会话的完整消息(用于会话筛选模式)
*
* @param sessionId 数据库会话ID
* @param chatSessionIds 要获取的会话ID列表
* @returns 合并后的上下文块和统计
*/
export function getMultipleSessionsMessages(sessionId: string, chatSessionIds: number[]): FilterResult {
const db = openReadonlyDatabase(sessionId)
if (!db) {
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
}
try {
if (chatSessionIds.length === 0) {
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
}
const blocks: ContextBlock[] = []
let totalMessages = 0
let totalChars = 0
// 先获取会话信息,按时间排序
const sessionsSql = `
SELECT id, start_ts as startTs, end_ts as endTs, message_count as messageCount
FROM chat_session
WHERE id IN (${chatSessionIds.map(() => '?').join(',')})
ORDER BY start_ts ASC
`
const sessions = db.prepare(sessionsSql).all(...chatSessionIds) as Array<{
id: number
startTs: number
endTs: number
messageCount: number
}>
// 为每个会话获取消息(完整信息)
// 使用 LEFT JOIN 获取回复消息的信息
const messagesSql = `
SELECT
msg.id,
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
m.platform_id as senderPlatformId,
COALESCE(m.aliases, '[]') as senderAliasesJson,
m.avatar as senderAvatar,
msg.content,
msg.type,
msg.reply_to_message_id as replyToMessageId,
reply_msg.content as replyToContent,
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName,
msg.ts as timestamp
FROM message_context mc
JOIN message msg ON msg.id = mc.message_id
JOIN member m ON msg.sender_id = m.id
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
WHERE mc.session_id = ?
ORDER BY msg.ts ASC
`
for (const session of sessions) {
const messages = db.prepare(messagesSql).all(session.id) as Array<{
id: number
senderName: string
senderPlatformId: string
senderAliasesJson: string
senderAvatar: string | null
content: string | null
type: number
replyToMessageId: string | null
replyToContent: string | null
replyToSenderName: string | null
timestamp: number
}>
const blockMessages: FilterMessage[] = messages.map((msg) => {
// 解析别名 JSON
let senderAliases: string[] = []
try {
senderAliases = JSON.parse(msg.senderAliasesJson || '[]')
} catch {
senderAliases = []
}
return {
id: msg.id,
senderName: msg.senderName,
senderPlatformId: msg.senderPlatformId,
senderAliases,
senderAvatar: msg.senderAvatar,
content: msg.content || '',
timestamp: msg.timestamp,
type: msg.type,
replyToMessageId: msg.replyToMessageId,
replyToContent: msg.replyToContent,
replyToSenderName: msg.replyToSenderName,
isHit: false, // 会话模式下没有命中高亮
}
})
for (const msg of messages) {
totalChars += (msg.content || '').length
}
blocks.push({
startTs: session.startTs,
endTs: session.endTs,
messages: blockMessages,
hitCount: 0,
})
totalMessages += messages.length
}
return {
blocks,
stats: {
totalMessages,
hitMessages: 0, // 会话模式没有命中概念
totalChars,
},
}
} catch (error) {
console.error('getMultipleSessionsMessages error:', error)
return { blocks: [], stats: { totalMessages: 0, hitMessages: 0, totalChars: 0 } }
} finally {
db.close()
}
}