diff --git a/electron/main/ai/tools/registry.ts b/electron/main/ai/tools/registry.ts index e76497e..1742525 100644 --- a/electron/main/ai/tools/registry.ts +++ b/electron/main/ai/tools/registry.ts @@ -751,6 +751,171 @@ async function getMessageContextExecutor( } } +// ==================== 会话相关工具 ==================== + +/** + * 搜索会话工具 + * 根据关键词和时间范围搜索会话 + */ +const searchSessionsTool: ToolDefinition = { + type: 'function', + function: { + name: 'search_sessions', + description: + '搜索聊天会话(对话段落)。会话是根据消息时间间隔自动切分的对话单元。适用于查找特定话题的讨论、了解某个时间段内发生了几次对话等场景。返回匹配的会话列表及每个会话的前5条消息预览。', + parameters: { + type: 'object', + properties: { + keywords: { + type: 'array', + description: '可选的搜索关键词列表,只返回包含这些关键词的会话(OR 逻辑匹配)', + items: { type: 'string' }, + }, + limit: { + type: 'number', + description: '返回会话数量限制,默认 20', + }, + year: { + type: 'number', + description: '筛选指定年份的会话,如 2024', + }, + month: { + type: 'number', + description: '筛选指定月份的会话(1-12),需要配合 year 使用', + }, + day: { + type: 'number', + description: '筛选指定日期的会话(1-31),需要配合 year 和 month 使用', + }, + start_time: { + type: 'string', + description: '开始时间,格式 "YYYY-MM-DD HH:mm",如 "2024-03-15 14:00"', + }, + end_time: { + type: 'string', + description: '结束时间,格式 "YYYY-MM-DD HH:mm",如 "2024-03-15 18:30"', + }, + }, + }, + }, +} + +async function searchSessionsExecutor( + params: { + keywords?: string[] + limit?: number + year?: number + month?: number + day?: number + start_time?: string + end_time?: string + }, + context: ToolContext +): Promise { + const { sessionId, timeFilter: contextTimeFilter, locale } = context + const limit = params.limit || 20 + + // 使用扩展的时间参数解析 + const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter) + + const sessions = await workerManager.searchSessions( + sessionId, + params.keywords, + effectiveTimeFilter, + limit, + 5 // 预览5条消息 + ) + + if (sessions.length === 0) { + return { + total: 0, + message: isChineseLocale(locale) ? '未找到匹配的会话' : 'No matching sessions found', + } + } + + const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US' + const msgSuffix = isChineseLocale(locale) ? '条消息' : ' messages' + const completeLabel = isChineseLocale(locale) ? '完整会话' : 'complete' + + return { + total: sessions.length, + timeRange: formatTimeRange(effectiveTimeFilter, locale), + sessions: sessions.map((s) => { + const startTime = new Date(s.startTs * 1000).toLocaleString(localeStr) + const endTime = new Date(s.endTs * 1000).toLocaleString(localeStr) + const completeTag = s.isComplete ? ` [${completeLabel}]` : '' + + return { + sessionId: s.id, + time: `${startTime} ~ ${endTime}`, + messageCount: `${s.messageCount}${msgSuffix}${completeTag}`, + preview: s.previewMessages.map((m) => formatMessageCompact(m, locale)), + } + }), + } +} + +/** + * 获取会话消息工具 + * 获取指定会话的完整消息列表 + */ +const getSessionMessagesTool: ToolDefinition = { + type: 'function', + function: { + name: 'get_session_messages', + description: + '获取指定会话的完整消息列表。用于在 search_sessions 找到相关会话后,获取该会话的完整上下文。返回会话的所有消息及参与者信息。', + parameters: { + type: 'object', + properties: { + session_id: { + type: 'number', + description: '会话 ID,可以从 search_sessions 的返回结果中获取', + }, + limit: { + type: 'number', + description: '返回消息数量限制,默认 500。对于超长会话可以限制返回数量以节省 token', + }, + }, + required: ['session_id'], + }, + }, +} + +async function getSessionMessagesExecutor( + params: { + session_id: number + limit?: number + }, + context: ToolContext +): Promise { + const { sessionId, maxMessagesLimit, locale } = context + // 用户配置优先 + const limit = maxMessagesLimit || params.limit || 500 + + const result = await workerManager.getSessionMessages(sessionId, params.session_id, limit) + + if (!result) { + return { + error: isChineseLocale(locale) ? '未找到指定的会话' : 'Session not found', + sessionId: params.session_id, + } + } + + const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US' + const startTime = new Date(result.startTs * 1000).toLocaleString(localeStr) + const endTime = new Date(result.endTs * 1000).toLocaleString(localeStr) + + return { + sessionId: result.sessionId, + time: `${startTime} ~ ${endTime}`, + messageCount: result.messageCount, + returnedCount: result.returnedCount, + participants: result.participants, + messages: result.messages.map((m) => formatMessageCompact(m, locale)), + } +} + // ==================== 注册工具 ==================== registerTool(searchMessagesTool, searchMessagesExecutor) @@ -761,3 +926,5 @@ registerTool(getGroupMembersTool, getGroupMembersExecutor) registerTool(getMemberNameHistoryTool, getMemberNameHistoryExecutor) registerTool(getConversationBetweenTool, getConversationBetweenExecutor) registerTool(getMessageContextTool, getMessageContextExecutor) +registerTool(searchSessionsTool, searchSessionsExecutor) +registerTool(getSessionMessagesTool, getSessionMessagesExecutor) diff --git a/electron/main/worker/dbWorker.ts b/electron/main/worker/dbWorker.ts index 1858709..0815239 100644 --- a/electron/main/worker/dbWorker.ts +++ b/electron/main/worker/dbWorker.ts @@ -54,6 +54,8 @@ import { getSessionStats, updateSessionGapThreshold, getSessions, + searchSessions, + getSessionMessages, } from './query' import { streamImport, streamParseFileInfo } from './import' @@ -130,6 +132,8 @@ const syncHandlers: Record any> = { getSessionStats: (p) => getSessionStats(p.sessionId), updateSessionGapThreshold: (p) => updateSessionGapThreshold(p.sessionId, p.gapThreshold), getSessions: (p) => getSessions(p.sessionId), + searchSessions: (p) => searchSessions(p.sessionId, p.keywords, p.timeFilter, p.limit, p.previewCount), + getSessionMessages: (p) => getSessionMessages(p.sessionId, p.chatSessionId, p.limit), } // 异步消息处理器(流式操作) diff --git a/electron/main/worker/query/index.ts b/electron/main/worker/query/index.ts index 6c6582b..a86c826 100644 --- a/electron/main/worker/query/index.ts +++ b/electron/main/worker/query/index.ts @@ -62,6 +62,8 @@ export { getSessionStats, updateSessionGapThreshold, getSessions, + searchSessions, + getSessionMessages, DEFAULT_SESSION_GAP_THRESHOLD, } from './session' -export type { ChatSessionItem } from './session' +export type { ChatSessionItem, SessionSearchResultItem, SessionMessagesResult } from './session' diff --git a/electron/main/worker/query/session.ts b/electron/main/worker/query/session.ts index 7a0b6f0..50d7ea0 100644 --- a/electron/main/worker/query/session.ts +++ b/electron/main/worker/query/session.ts @@ -340,3 +340,241 @@ export function getSessions(sessionId: string): ChatSessionItem[] { db.close() } } + +// ==================== AI 工具专用查询函数 ==================== + +/** + * 会话搜索结果项类型(用于 AI 工具) + */ +export interface SessionSearchResultItem { + id: number + startTs: number + endTs: number + messageCount: number + /** 是否为完整会话(消息数 <= 预览条数) */ + isComplete: boolean + /** 预览消息列表 */ + previewMessages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 搜索会话(用于 AI 工具) + * 支持按关键词和时间范围筛选会话 + * + * @param sessionId 数据库会话ID + * @param keywords 关键词列表(可选,OR 逻辑匹配) + * @param timeFilter 时间过滤器(可选) + * @param limit 返回数量限制,默认 20 + * @param previewCount 预览消息数量,默认 5 + * @returns 匹配的会话列表 + */ +export function searchSessions( + sessionId: string, + keywords?: string[], + timeFilter?: { startTs: number; endTs: number }, + limit: number = 20, + previewCount: number = 5 +): SessionSearchResultItem[] { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return [] + } + + try { + // 1. 构建会话查询 SQL + let sessionSql = ` + SELECT + cs.id, + cs.start_ts as startTs, + cs.end_ts as endTs, + cs.message_count as messageCount + FROM chat_session cs + WHERE 1=1 + ` + const params: unknown[] = [] + + // 时间范围过滤 + if (timeFilter) { + sessionSql += ` AND cs.start_ts >= ? AND cs.end_ts <= ?` + params.push(timeFilter.startTs, timeFilter.endTs) + } + + // 关键词过滤:只返回包含关键词的会话 + if (keywords && keywords.length > 0) { + const keywordConditions = keywords.map(() => `m.content LIKE ?`).join(' OR ') + sessionSql += ` + AND cs.id IN ( + SELECT DISTINCT mc.session_id + FROM message_context mc + JOIN message m ON m.id = mc.message_id + WHERE (${keywordConditions}) + ) + ` + for (const kw of keywords) { + params.push(`%${kw}%`) + } + } + + sessionSql += ` ORDER BY cs.start_ts DESC LIMIT ?` + params.push(limit) + + const sessions = db.prepare(sessionSql).all(...params) as Array<{ + id: number + startTs: number + endTs: number + messageCount: number + }> + + // 2. 为每个会话获取预览消息 + const previewSql = ` + SELECT + m.id, + COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, + m.content, + m.ts as timestamp + FROM message_context mc + JOIN message m ON m.id = mc.message_id + JOIN member mb ON mb.id = m.sender_id + WHERE mc.session_id = ? + ORDER BY m.ts ASC + LIMIT ? + ` + + const results: SessionSearchResultItem[] = [] + for (const session of sessions) { + const previewMessages = db.prepare(previewSql).all(session.id, previewCount) as Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> + + results.push({ + id: session.id, + startTs: session.startTs, + endTs: session.endTs, + messageCount: session.messageCount, + isComplete: session.messageCount <= previewCount, + previewMessages, + }) + } + + return results + } catch (error) { + console.error('searchSessions error:', error) + return [] + } finally { + db.close() + } +} + +/** + * 会话消息结果类型(用于 AI 工具) + */ +export interface SessionMessagesResult { + sessionId: number + startTs: number + endTs: number + messageCount: number + returnedCount: number + /** 参与者列表 */ + participants: string[] + /** 消息列表 */ + messages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 获取会话的完整消息(用于 AI 工具) + * + * @param sessionId 数据库会话ID + * @param chatSessionId 会话索引中的会话ID + * @param limit 返回数量限制,默认 500 + * @returns 会话的完整消息 + */ +export function getSessionMessages( + sessionId: string, + chatSessionId: number, + limit: number = 500 +): SessionMessagesResult | null { + const db = openReadonlyDatabase(sessionId) + if (!db) { + return null + } + + try { + // 1. 获取会话基本信息 + const sessionSql = ` + SELECT + id, + start_ts as startTs, + end_ts as endTs, + message_count as messageCount + FROM chat_session + WHERE id = ? + ` + const session = db.prepare(sessionSql).get(chatSessionId) as + | { + id: number + startTs: number + endTs: number + messageCount: number + } + | undefined + + if (!session) { + return null + } + + // 2. 获取会话消息 + const messagesSql = ` + SELECT + m.id, + COALESCE(mb.group_nickname, mb.account_name, mb.platform_id) as senderName, + m.content, + m.ts as timestamp + FROM message_context mc + JOIN message m ON m.id = mc.message_id + JOIN member mb ON mb.id = m.sender_id + WHERE mc.session_id = ? + ORDER BY m.ts ASC + LIMIT ? + ` + const messages = db.prepare(messagesSql).all(chatSessionId, limit) as Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> + + // 3. 统计参与者 + const participantsSet = new Set() + for (const msg of messages) { + participantsSet.add(msg.senderName) + } + + return { + sessionId: session.id, + startTs: session.startTs, + endTs: session.endTs, + messageCount: session.messageCount, + returnedCount: messages.length, + participants: Array.from(participantsSet), + messages, + } + } catch (error) { + console.error('getSessionMessages error:', error) + return null + } finally { + db.close() + } +} diff --git a/electron/main/worker/workerManager.ts b/electron/main/worker/workerManager.ts index b9c26d8..0e97ad6 100644 --- a/electron/main/worker/workerManager.ts +++ b/electron/main/worker/workerManager.ts @@ -557,3 +557,64 @@ export interface ChatSessionItem { export async function getSessions(sessionId: string): Promise { return sendToWorker('getSessions', { sessionId }) } + +// ==================== AI 工具专用查询函数 ==================== + +/** + * 会话搜索结果项类型(用于 AI 工具) + */ +export interface SessionSearchResultItem { + id: number + startTs: number + endTs: number + messageCount: number + isComplete: boolean + previewMessages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 搜索会话(用于 AI 工具) + */ +export async function searchSessions( + sessionId: string, + keywords?: string[], + timeFilter?: { startTs: number; endTs: number }, + limit?: number, + previewCount?: number +): Promise { + return sendToWorker('searchSessions', { sessionId, keywords, timeFilter, limit, previewCount }) +} + +/** + * 会话消息结果类型(用于 AI 工具) + */ +export interface SessionMessagesResult { + sessionId: number + startTs: number + endTs: number + messageCount: number + returnedCount: number + participants: string[] + messages: Array<{ + id: number + senderName: string + content: string | null + timestamp: number + }> +} + +/** + * 获取会话的完整消息(用于 AI 工具) + */ +export async function getSessionMessages( + sessionId: string, + chatSessionId: number, + limit?: number +): Promise { + return sendToWorker('getSessionMessages', { sessionId, chatSessionId, limit }) +}