feat: 搜索工具自动携带上下文消息

This commit is contained in:
digua
2026-04-07 22:06:51 +08:00
committed by digua
parent e86df09371
commit d49a094164
14 changed files with 239 additions and 4 deletions

View File

@@ -526,6 +526,115 @@ export function getMessageContext(
return rows.map(sanitizeMessageRow)
}
/**
* 获取搜索结果的上下文消息(会话感知 + 区间合并去重)
* 用于 search_messages / deep_search_messages 自动扩展上下文。
* 当存在会话索引时,上下文不跨会话边界;否则按 message.id 顺序取前后 N 条。
*
* @param sessionId 数据库会话 ID
* @param messageIds 搜索命中的消息 ID 列表
* @param contextBefore 每条命中消息向前取多少条上下文
* @param contextAfter 每条命中消息向后取多少条上下文
*/
export function getSearchMessageContext(
sessionId: string,
messageIds: number[],
contextBefore: number = 2,
contextAfter: number = 2
): MessageResult[] {
ensureAvatarColumn(sessionId)
const db = openDatabase(sessionId)
if (!db) return []
if (messageIds.length === 0) return []
const contextIds = new Set<number>()
const hasSessionData =
(db.prepare('SELECT 1 FROM message_context LIMIT 1').get() as { 1: number } | undefined) !== undefined
for (const messageId of messageIds) {
contextIds.add(messageId)
if (hasSessionData) {
const sessionRow = db
.prepare('SELECT session_id FROM message_context WHERE message_id = ?')
.get(messageId) as { session_id: number } | undefined
if (sessionRow) {
if (contextBefore > 0) {
const rows = db
.prepare(
`SELECT mc.message_id as id
FROM message_context mc
WHERE mc.session_id = ? AND mc.message_id < ?
ORDER BY mc.message_id DESC
LIMIT ?`
)
.all(sessionRow.session_id, messageId, contextBefore) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
if (contextAfter > 0) {
const rows = db
.prepare(
`SELECT mc.message_id as id
FROM message_context mc
WHERE mc.session_id = ? AND mc.message_id > ?
ORDER BY mc.message_id ASC
LIMIT ?`
)
.all(sessionRow.session_id, messageId, contextAfter) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
continue
}
}
// Fallback: no session data or message not indexed — use simple id-based context
if (contextBefore > 0) {
const rows = db
.prepare('SELECT id FROM message WHERE id < ? ORDER BY id DESC LIMIT ?')
.all(messageId, contextBefore) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
if (contextAfter > 0) {
const rows = db
.prepare('SELECT id FROM message WHERE id > ? ORDER BY id ASC LIMIT ?')
.all(messageId, contextAfter) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
}
if (contextIds.size === 0) return []
const idList = Array.from(contextIds)
const placeholders = idList.map(() => '?').join(', ')
const sql = `
SELECT
msg.id,
m.id as senderId,
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
m.platform_id as senderPlatformId,
m.aliases,
m.avatar,
msg.content,
msg.ts as timestamp,
msg.type,
msg.reply_to_message_id,
reply_msg.content as replyToContent,
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName
FROM message msg
JOIN member m ON msg.sender_id = m.id
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
WHERE msg.id IN (${placeholders})
ORDER BY msg.ts ASC, msg.id ASC
`
const rows = db.prepare(sql).all(...idList) as DbMessageRow[]
return rows.map(sanitizeMessageRow)
}
/**
* 获取指定消息之前的 N 条消息(用于向上无限滚动)
* @param sessionId 会话 ID