mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-18 12:28:56 +08:00
feat: 搜索工具自动携带上下文消息
This commit is contained in:
@@ -33,11 +33,27 @@ export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
params.sender_id
|
||||
)
|
||||
|
||||
const contextBefore = context.searchContextBefore ?? 2
|
||||
const contextAfter = context.searchContextAfter ?? 2
|
||||
let finalMessages = result.messages
|
||||
|
||||
if ((contextBefore > 0 || contextAfter > 0) && result.messages.length > 0) {
|
||||
const hitIds = result.messages.map((m) => m.id).filter((id): id is number => id != null)
|
||||
if (hitIds.length > 0) {
|
||||
finalMessages = await workerManager.getSearchMessageContext(
|
||||
sessionId,
|
||||
hitIds,
|
||||
contextBefore,
|
||||
contextAfter
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const data = {
|
||||
total: result.total,
|
||||
returned: result.messages.length,
|
||||
returned: finalMessages.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
rawMessages: result.messages,
|
||||
rawMessages: finalMessages,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -34,11 +34,27 @@ export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
params.sender_id
|
||||
)
|
||||
|
||||
const contextBefore = context.searchContextBefore ?? 2
|
||||
const contextAfter = context.searchContextAfter ?? 2
|
||||
let finalMessages = result.messages
|
||||
|
||||
if ((contextBefore > 0 || contextAfter > 0) && result.messages.length > 0) {
|
||||
const hitIds = result.messages.map((m) => m.id).filter((id): id is number => id != null)
|
||||
if (hitIds.length > 0) {
|
||||
finalMessages = await workerManager.getSearchMessageContext(
|
||||
sessionId,
|
||||
hitIds,
|
||||
contextBefore,
|
||||
contextAfter
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const data = {
|
||||
total: result.total,
|
||||
returned: result.messages.length,
|
||||
returned: finalMessages.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
rawMessages: result.messages,
|
||||
rawMessages: finalMessages,
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -53,4 +53,8 @@ export interface ToolContext {
|
||||
locale?: string
|
||||
/** 聊天记录预处理配置(全局) */
|
||||
preprocessConfig?: PreprocessConfig
|
||||
/** 搜索结果上下文:向前取多少条(默认 3) */
|
||||
searchContextBefore?: number
|
||||
/** 搜索结果上下文:向后取多少条(默认 3) */
|
||||
searchContextAfter?: number
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ import {
|
||||
searchMessages,
|
||||
deepSearchMessages,
|
||||
getMessageContext,
|
||||
getSearchMessageContext,
|
||||
getRecentMessages,
|
||||
getAllRecentMessages,
|
||||
getConversationBetween,
|
||||
@@ -124,6 +125,8 @@ const syncHandlers: Record<string, (payload: any) => any> = {
|
||||
// AI 查询
|
||||
searchMessages: (p) => searchMessages(p.sessionId, p.keywords, p.filter, p.limit, p.offset, p.senderId),
|
||||
getMessageContext: (p) => getMessageContext(p.sessionId, p.messageIds, p.contextSize),
|
||||
getSearchMessageContext: (p) =>
|
||||
getSearchMessageContext(p.sessionId, p.messageIds, p.contextBefore, p.contextAfter),
|
||||
getRecentMessages: (p) => getRecentMessages(p.sessionId, p.filter, p.limit),
|
||||
getAllRecentMessages: (p) => getAllRecentMessages(p.sessionId, p.filter, p.limit),
|
||||
getConversationBetween: (p) => getConversationBetween(p.sessionId, p.memberId1, p.memberId2, p.filter, p.limit),
|
||||
|
||||
@@ -46,6 +46,7 @@ export {
|
||||
searchMessages,
|
||||
deepSearchMessages,
|
||||
getMessageContext,
|
||||
getSearchMessageContext,
|
||||
getRecentMessages,
|
||||
getAllRecentMessages,
|
||||
getConversationBetween,
|
||||
|
||||
@@ -526,6 +526,115 @@ export function getMessageContext(
|
||||
return rows.map(sanitizeMessageRow)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取搜索结果的上下文消息(会话感知 + 区间合并去重)
|
||||
* 用于 search_messages / deep_search_messages 自动扩展上下文。
|
||||
* 当存在会话索引时,上下文不跨会话边界;否则按 message.id 顺序取前后 N 条。
|
||||
*
|
||||
* @param sessionId 数据库会话 ID
|
||||
* @param messageIds 搜索命中的消息 ID 列表
|
||||
* @param contextBefore 每条命中消息向前取多少条上下文
|
||||
* @param contextAfter 每条命中消息向后取多少条上下文
|
||||
*/
|
||||
export function getSearchMessageContext(
|
||||
sessionId: string,
|
||||
messageIds: number[],
|
||||
contextBefore: number = 2,
|
||||
contextAfter: number = 2
|
||||
): MessageResult[] {
|
||||
ensureAvatarColumn(sessionId)
|
||||
const db = openDatabase(sessionId)
|
||||
if (!db) return []
|
||||
if (messageIds.length === 0) return []
|
||||
|
||||
const contextIds = new Set<number>()
|
||||
|
||||
const hasSessionData =
|
||||
(db.prepare('SELECT 1 FROM message_context LIMIT 1').get() as { 1: number } | undefined) !== undefined
|
||||
|
||||
for (const messageId of messageIds) {
|
||||
contextIds.add(messageId)
|
||||
|
||||
if (hasSessionData) {
|
||||
const sessionRow = db
|
||||
.prepare('SELECT session_id FROM message_context WHERE message_id = ?')
|
||||
.get(messageId) as { session_id: number } | undefined
|
||||
|
||||
if (sessionRow) {
|
||||
if (contextBefore > 0) {
|
||||
const rows = db
|
||||
.prepare(
|
||||
`SELECT mc.message_id as id
|
||||
FROM message_context mc
|
||||
WHERE mc.session_id = ? AND mc.message_id < ?
|
||||
ORDER BY mc.message_id DESC
|
||||
LIMIT ?`
|
||||
)
|
||||
.all(sessionRow.session_id, messageId, contextBefore) as { id: number }[]
|
||||
rows.forEach((r) => contextIds.add(r.id))
|
||||
}
|
||||
if (contextAfter > 0) {
|
||||
const rows = db
|
||||
.prepare(
|
||||
`SELECT mc.message_id as id
|
||||
FROM message_context mc
|
||||
WHERE mc.session_id = ? AND mc.message_id > ?
|
||||
ORDER BY mc.message_id ASC
|
||||
LIMIT ?`
|
||||
)
|
||||
.all(sessionRow.session_id, messageId, contextAfter) as { id: number }[]
|
||||
rows.forEach((r) => contextIds.add(r.id))
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: no session data or message not indexed — use simple id-based context
|
||||
if (contextBefore > 0) {
|
||||
const rows = db
|
||||
.prepare('SELECT id FROM message WHERE id < ? ORDER BY id DESC LIMIT ?')
|
||||
.all(messageId, contextBefore) as { id: number }[]
|
||||
rows.forEach((r) => contextIds.add(r.id))
|
||||
}
|
||||
if (contextAfter > 0) {
|
||||
const rows = db
|
||||
.prepare('SELECT id FROM message WHERE id > ? ORDER BY id ASC LIMIT ?')
|
||||
.all(messageId, contextAfter) as { id: number }[]
|
||||
rows.forEach((r) => contextIds.add(r.id))
|
||||
}
|
||||
}
|
||||
|
||||
if (contextIds.size === 0) return []
|
||||
|
||||
const idList = Array.from(contextIds)
|
||||
const placeholders = idList.map(() => '?').join(', ')
|
||||
|
||||
const sql = `
|
||||
SELECT
|
||||
msg.id,
|
||||
m.id as senderId,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as senderName,
|
||||
m.platform_id as senderPlatformId,
|
||||
m.aliases,
|
||||
m.avatar,
|
||||
msg.content,
|
||||
msg.ts as timestamp,
|
||||
msg.type,
|
||||
msg.reply_to_message_id,
|
||||
reply_msg.content as replyToContent,
|
||||
COALESCE(reply_m.group_nickname, reply_m.account_name, reply_m.platform_id) as replyToSenderName
|
||||
FROM message msg
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
LEFT JOIN message reply_msg ON msg.reply_to_message_id = reply_msg.platform_message_id
|
||||
LEFT JOIN member reply_m ON reply_msg.sender_id = reply_m.id
|
||||
WHERE msg.id IN (${placeholders})
|
||||
ORDER BY msg.ts ASC, msg.id ASC
|
||||
`
|
||||
|
||||
const rows = db.prepare(sql).all(...idList) as DbMessageRow[]
|
||||
return rows.map(sanitizeMessageRow)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定消息之前的 N 条消息(用于向上无限滚动)
|
||||
* @param sessionId 会话 ID
|
||||
|
||||
@@ -494,6 +494,18 @@ export async function getMessageContext(
|
||||
return sendToWorker('getMessageContext', { sessionId, messageIds, contextSize })
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取搜索结果的上下文消息(会话感知 + 区间合并去重)
|
||||
*/
|
||||
export async function getSearchMessageContext(
|
||||
sessionId: string,
|
||||
messageIds: number[],
|
||||
contextBefore?: number,
|
||||
contextAfter?: number
|
||||
): Promise<SearchMessageResult[]> {
|
||||
return sendToWorker('getSearchMessageContext', { sessionId, messageIds, contextBefore, contextAfter })
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取最近消息(用于概览性问题)
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user