feat: 搜索工具自动携带上下文消息

This commit is contained in:
digua
2026-04-07 22:06:51 +08:00
committed by digua
parent e86df09371
commit d49a094164
14 changed files with 239 additions and 4 deletions
@@ -33,11 +33,27 @@ export function createTool(context: ToolContext): AgentTool<typeof schema> {
params.sender_id
)
const contextBefore = context.searchContextBefore ?? 2
const contextAfter = context.searchContextAfter ?? 2
let finalMessages = result.messages
if ((contextBefore > 0 || contextAfter > 0) && result.messages.length > 0) {
const hitIds = result.messages.map((m) => m.id).filter((id): id is number => id != null)
if (hitIds.length > 0) {
finalMessages = await workerManager.getSearchMessageContext(
sessionId,
hitIds,
contextBefore,
contextAfter
)
}
}
const data = {
total: result.total,
returned: result.messages.length,
returned: finalMessages.length,
timeRange: formatTimeRange(effectiveTimeFilter, locale),
rawMessages: result.messages,
rawMessages: finalMessages,
}
return {
@@ -34,11 +34,27 @@ export function createTool(context: ToolContext): AgentTool<typeof schema> {
params.sender_id
)
const contextBefore = context.searchContextBefore ?? 2
const contextAfter = context.searchContextAfter ?? 2
let finalMessages = result.messages
if ((contextBefore > 0 || contextAfter > 0) && result.messages.length > 0) {
const hitIds = result.messages.map((m) => m.id).filter((id): id is number => id != null)
if (hitIds.length > 0) {
finalMessages = await workerManager.getSearchMessageContext(
sessionId,
hitIds,
contextBefore,
contextAfter
)
}
}
const data = {
total: result.total,
returned: result.messages.length,
returned: finalMessages.length,
timeRange: formatTimeRange(effectiveTimeFilter, locale),
rawMessages: result.messages,
rawMessages: finalMessages,
}
return {
+4
View File
@@ -53,4 +53,8 @@ export interface ToolContext {
locale?: string
/** 聊天记录预处理配置(全局) */
preprocessConfig?: PreprocessConfig
/** 搜索结果上下文:向前取多少条(默认 3) */
searchContextBefore?: number
/** 搜索结果上下文:向后取多少条(默认 3) */
searchContextAfter?: number
}
+3
View File
@@ -34,6 +34,7 @@ import {
searchMessages,
deepSearchMessages,
getMessageContext,
getSearchMessageContext,
getRecentMessages,
getAllRecentMessages,
getConversationBetween,
@@ -124,6 +125,8 @@ const syncHandlers: Record<string, (payload: any) => any> = {
// AI 查询
searchMessages: (p) => searchMessages(p.sessionId, p.keywords, p.filter, p.limit, p.offset, p.senderId),
getMessageContext: (p) => getMessageContext(p.sessionId, p.messageIds, p.contextSize),
getSearchMessageContext: (p) =>
getSearchMessageContext(p.sessionId, p.messageIds, p.contextBefore, p.contextAfter),
getRecentMessages: (p) => getRecentMessages(p.sessionId, p.filter, p.limit),
getAllRecentMessages: (p) => getAllRecentMessages(p.sessionId, p.filter, p.limit),
getConversationBetween: (p) => getConversationBetween(p.sessionId, p.memberId1, p.memberId2, p.filter, p.limit),
+1
View File
@@ -46,6 +46,7 @@ export {
searchMessages,
deepSearchMessages,
getMessageContext,
getSearchMessageContext,
getRecentMessages,
getAllRecentMessages,
getConversationBetween,
+109
View File
@@ -526,6 +526,115 @@ export function getMessageContext(
return rows.map(sanitizeMessageRow)
}
/**
* 获取搜索结果的上下文消息(会话感知 + 区间合并去重)
* 用于 search_messages / deep_search_messages 自动扩展上下文。
* 当存在会话索引时,上下文不跨会话边界;否则按 message.id 顺序取前后 N 条。
*
* @param sessionId 数据库会话 ID
* @param messageIds 搜索命中的消息 ID 列表
* @param contextBefore 每条命中消息向前取多少条上下文
* @param contextAfter 每条命中消息向后取多少条上下文
*/
export function getSearchMessageContext(
sessionId: string,
messageIds: number[],
contextBefore: number = 2,
contextAfter: number = 2
): MessageResult[] {
ensureAvatarColumn(sessionId)
const db = openDatabase(sessionId)
if (!db) return []
if (messageIds.length === 0) return []
const contextIds = new Set<number>()
const hasSessionData =
(db.prepare('SELECT 1 FROM message_context LIMIT 1').get() as { 1: number } | undefined) !== undefined
for (const messageId of messageIds) {
contextIds.add(messageId)
if (hasSessionData) {
const sessionRow = db
.prepare('SELECT session_id FROM message_context WHERE message_id = ?')
.get(messageId) as { session_id: number } | undefined
if (sessionRow) {
if (contextBefore > 0) {
const rows = db
.prepare(
`SELECT mc.message_id as id
FROM message_context mc
WHERE mc.session_id = ? AND mc.message_id < ?
ORDER BY mc.message_id DESC
LIMIT ?`
)
.all(sessionRow.session_id, messageId, contextBefore) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
if (contextAfter > 0) {
const rows = db
.prepare(
`SELECT mc.message_id as id
FROM message_context mc
WHERE mc.session_id = ? AND mc.message_id > ?
ORDER BY mc.message_id ASC
LIMIT ?`
)
.all(sessionRow.session_id, messageId, contextAfter) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
continue
}
}
// Fallback: no session data or message not indexed — use simple id-based context
if (contextBefore > 0) {
const rows = db
.prepare('SELECT id FROM message WHERE id < ? ORDER BY id DESC LIMIT ?')
.all(messageId, contextBefore) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
if (contextAfter > 0) {
const rows = db
.prepare('SELECT id FROM message WHERE id > ? ORDER BY id ASC LIMIT ?')
.all(messageId, contextAfter) as { id: number }[]
rows.forEach((r) => contextIds.add(r.id))
}
}
if (contextIds.size === 0) return []
const idList = Array.from(contextIds)
const placeholders = idList.map(() => '?').join(', ')
const sql = `
SELECT
msg.id,
m.id as senderId,
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
m.platform_id as senderPlatformId,
m.aliases,
m.avatar,
msg.content,
msg.ts as timestamp,
msg.type,
msg.reply_to_message_id,
reply_msg.content as replyToContent,
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName
FROM message msg
JOIN member m ON msg.sender_id = m.id
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
WHERE msg.id IN (${placeholders})
ORDER BY msg.ts ASC, msg.id ASC
`
const rows = db.prepare(sql).all(...idList) as DbMessageRow[]
return rows.map(sanitizeMessageRow)
}
/**
* 获取指定消息之前的 N 条消息(用于向上无限滚动)
* @param sessionId 会话 ID
+12
View File
@@ -494,6 +494,18 @@ export async function getMessageContext(
return sendToWorker('getMessageContext', { sessionId, messageIds, contextSize })
}
/**
* 获取搜索结果的上下文消息(会话感知 + 区间合并去重)
*/
export async function getSearchMessageContext(
sessionId: string,
messageIds: number[],
contextBefore?: number,
contextAfter?: number
): Promise<SearchMessageResult[]> {
return sendToWorker('getSearchMessageContext', { sessionId, messageIds, contextBefore, contextAfter })
}
/**
* 获取最近消息(用于概览性问题)
*/