mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-04-28 15:52:44 +08:00
refactor(tools): modularize tool system with AgentTool + TypeBox + i18n
- Delete monolithic registry.ts (−1185 lines) - Add tools/definitions/ with 12 individual tool files + index.ts, each using AgentTool interface and TypeBox schemas - Add tools/utils/ with shared helpers (format.ts, schemas.ts, time-params.ts) - Rewrite tools/index.ts to provide getAllTools() factory - Clean up tools/types.ts, keep only ToolContext and OwnerInfo - Use i18n keys for tool descriptions, preserve Chinese as comments
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import { timeParamProperties } from '../utils/schemas'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, formatMessageCompact, t } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
member_id_1: Type.Number({ description: 'ai.tools.get_conversation_between.params.member_id_1' }),
|
||||
member_id_2: Type.Number({ description: 'ai.tools.get_conversation_between.params.member_id_2' }),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.get_conversation_between.params.limit' })),
|
||||
...timeParamProperties,
|
||||
})
|
||||
|
||||
/** 获取两个群成员之间的对话记录。适用于回答"A和B之间聊了什么"、"查看两人的对话"等问题。需要先通过 get_group_members 获取成员 ID。支持精确到分钟级别的时间查询。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_conversation_between',
|
||||
label: 'get_conversation_between',
|
||||
description: 'ai.tools.get_conversation_between.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, maxMessagesLimit, locale } = context
|
||||
const limit = maxMessagesLimit || params.limit || 100
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
|
||||
const result = await workerManager.getConversationBetween(
|
||||
sessionId,
|
||||
params.member_id_1,
|
||||
params.member_id_2,
|
||||
effectiveTimeFilter,
|
||||
limit
|
||||
)
|
||||
|
||||
if (result.messages.length === 0) {
|
||||
const data = {
|
||||
error: t('noConversation', locale) as string,
|
||||
member1Id: params.member_id_1,
|
||||
member2Id: params.member_id_2,
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
}
|
||||
|
||||
const data = {
|
||||
total: result.total,
|
||||
returned: result.messages.length,
|
||||
member1: result.member1Name,
|
||||
member2: result.member2Name,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
conversation: result.messages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
57
electron/main/ai/tools/definitions/get-group-members.ts
Normal file
57
electron/main/ai/tools/definitions/get-group-members.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale, t } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
search: Type.Optional(Type.String({ description: 'ai.tools.get_group_members.params.search' })),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.get_group_members.params.limit' })),
|
||||
})
|
||||
|
||||
/** 获取群成员列表,包括成员的基本信息、别名和消息统计。适用于查询"群里有哪些人"、"某人的别名是什么"、"谁的QQ号是xxx"等问题。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_group_members',
|
||||
label: 'get_group_members',
|
||||
description: 'ai.tools.get_group_members.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, locale } = context
|
||||
const members = await workerManager.getMembers(sessionId)
|
||||
|
||||
let filteredMembers = members
|
||||
if (params.search) {
|
||||
const keyword = params.search.toLowerCase()
|
||||
filteredMembers = members.filter((m) => {
|
||||
if (m.groupNickname && m.groupNickname.toLowerCase().includes(keyword)) return true
|
||||
if (m.accountName && m.accountName.toLowerCase().includes(keyword)) return true
|
||||
if (m.platformId.includes(keyword)) return true
|
||||
if (m.aliases.some((alias) => alias.toLowerCase().includes(keyword))) return true
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
if (params.limit && params.limit > 0) {
|
||||
filteredMembers = filteredMembers.slice(0, params.limit)
|
||||
}
|
||||
|
||||
const msgSuffix = isChineseLocale(locale) ? '条' : ''
|
||||
const aliasLabel = t('alias', locale) as string
|
||||
const data = {
|
||||
totalMembers: members.length,
|
||||
returnedMembers: filteredMembers.length,
|
||||
members: filteredMembers.map((m) => {
|
||||
const displayName = m.groupNickname || m.accountName || m.platformId
|
||||
const aliasStr = m.aliases.length > 0 ? `|${aliasLabel}:${m.aliases.join(',')}` : ''
|
||||
return `${m.id}|${m.platformId}|${displayName}|${m.messageCount}${msgSuffix}${aliasStr}`
|
||||
}),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale, t } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
member_id: Type.Number({ description: 'ai.tools.get_member_name_history.params.member_id' }),
|
||||
})
|
||||
|
||||
/** 获取成员的昵称变更历史记录。适用于回答"某人以前叫什么名字"、"某人的昵称变化"、"某人曾用名"等问题。需要先通过 get_group_members 工具获取成员 ID。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_member_name_history',
|
||||
label: 'get_member_name_history',
|
||||
description: 'ai.tools.get_member_name_history.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, locale } = context
|
||||
|
||||
const members = await workerManager.getMembers(sessionId)
|
||||
const member = members.find((m) => m.id === params.member_id)
|
||||
|
||||
if (!member) {
|
||||
const data = {
|
||||
error: t('memberNotFound', locale) as string,
|
||||
member_id: params.member_id,
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
}
|
||||
|
||||
const history = await workerManager.getMemberNameHistory(sessionId, params.member_id)
|
||||
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
const untilNow = t('untilNow', locale) as string
|
||||
const formatHistory = (h: { name: string; startTs: number; endTs: number | null }) => {
|
||||
const start = new Date(h.startTs * 1000).toLocaleDateString(localeStr)
|
||||
const end = h.endTs ? new Date(h.endTs * 1000).toLocaleDateString(localeStr) : untilNow
|
||||
return `${h.name} (${start} ~ ${end})`
|
||||
}
|
||||
|
||||
const accountNames = history.filter((h: { nameType: string }) => h.nameType === 'account_name').map(formatHistory)
|
||||
const groupNicknames = history
|
||||
.filter((h: { nameType: string }) => h.nameType === 'group_nickname')
|
||||
.map(formatHistory)
|
||||
|
||||
const displayName = member.groupNickname || member.accountName || member.platformId
|
||||
const aliasLabel = t('alias', locale) as string
|
||||
const aliasStr = member.aliases.length > 0 ? `|${aliasLabel}:${member.aliases.join(',')}` : ''
|
||||
const noChangeRecord = t('noChangeRecord', locale) as string
|
||||
|
||||
const data = {
|
||||
member: `${member.id}|${member.platformId}|${displayName}${aliasStr}`,
|
||||
accountNameHistory: accountNames.length > 0 ? accountNames : noChangeRecord,
|
||||
groupNicknameHistory: groupNicknames.length > 0 ? groupNicknames : noChangeRecord,
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
39
electron/main/ai/tools/definitions/get-member-stats.ts
Normal file
39
electron/main/ai/tools/definitions/get-member-stats.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
top_n: Type.Optional(Type.Number({ description: 'ai.tools.get_member_stats.params.top_n' })),
|
||||
})
|
||||
|
||||
/** 获取群成员的活跃度统计数据。适用于回答"谁最活跃"、"发言最多的是谁"等问题。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_member_stats',
|
||||
label: 'get_member_stats',
|
||||
description: 'ai.tools.get_member_stats.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter, locale } = context
|
||||
const topN = params.top_n || 10
|
||||
|
||||
const result = await workerManager.getMemberActivity(sessionId, timeFilter)
|
||||
const topMembers = result.slice(0, topN)
|
||||
|
||||
const msgSuffix = isChineseLocale(locale) ? '条' : ''
|
||||
const data = {
|
||||
totalMembers: result.length,
|
||||
topMembers: topMembers.map(
|
||||
(m, index) => `${index + 1}. ${m.name} ${m.messageCount}${msgSuffix}(${m.percentage}%)`
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
49
electron/main/ai/tools/definitions/get-message-context.ts
Normal file
49
electron/main/ai/tools/definitions/get-message-context.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { formatMessageCompact, t } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
message_ids: Type.Array(Type.Number(), { description: 'ai.tools.get_message_context.params.message_ids' }),
|
||||
context_size: Type.Optional(Type.Number({ description: 'ai.tools.get_message_context.params.context_size' })),
|
||||
})
|
||||
|
||||
/** 根据消息 ID 获取前后的上下文消息。适用于需要查看某条消息前后聊天内容的场景,比如"这条消息的前后在聊什么"、"查看某条消息的上下文"等。支持单个或批量消息 ID。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_message_context',
|
||||
label: 'get_message_context',
|
||||
description: 'ai.tools.get_message_context.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, locale } = context
|
||||
const contextSize = params.context_size || 20
|
||||
|
||||
const messages = await workerManager.getMessageContext(sessionId, params.message_ids, contextSize)
|
||||
|
||||
if (messages.length === 0) {
|
||||
const data = {
|
||||
error: t('noMessageContext', locale) as string,
|
||||
messageIds: params.message_ids,
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
}
|
||||
|
||||
const data = {
|
||||
totalMessages: messages.length,
|
||||
contextSize: contextSize,
|
||||
requestedMessageIds: params.message_ids,
|
||||
messages: messages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
41
electron/main/ai/tools/definitions/get-recent-messages.ts
Normal file
41
electron/main/ai/tools/definitions/get-recent-messages.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, formatMessageCompact } from '../utils/format'
|
||||
import { timeParamProperties } from '../utils/schemas'
|
||||
|
||||
const schema = Type.Object({
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.get_recent_messages.params.limit' })),
|
||||
...timeParamProperties,
|
||||
})
|
||||
|
||||
/** 获取指定时间段内的群聊消息。适用于回答"最近大家聊了什么"、"X月群里聊了什么"等概览性问题。支持精确到分钟级别的时间查询。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_recent_messages',
|
||||
label: 'get_recent_messages',
|
||||
description: 'ai.tools.get_recent_messages.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, maxMessagesLimit, locale } = context
|
||||
const limit = maxMessagesLimit || params.limit || 100
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
|
||||
const result = await workerManager.getRecentMessages(sessionId, effectiveTimeFilter, limit)
|
||||
|
||||
const data = {
|
||||
total: result.total,
|
||||
returned: result.messages.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
messages: result.messages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
51
electron/main/ai/tools/definitions/get-session-messages.ts
Normal file
51
electron/main/ai/tools/definitions/get-session-messages.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale, formatMessageCompact } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
session_id: Type.Number({ description: 'ai.tools.get_session_messages.params.session_id' }),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.get_session_messages.params.limit' })),
|
||||
})
|
||||
|
||||
/** 获取指定会话的完整消息列表。用于在 search_sessions 找到相关会话后,获取该会话的完整上下文。返回会话的所有消息及参与者信息。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_session_messages',
|
||||
label: 'get_session_messages',
|
||||
description: 'ai.tools.get_session_messages.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, maxMessagesLimit, locale } = context
|
||||
const limit = maxMessagesLimit || params.limit || 1000
|
||||
|
||||
const result = await workerManager.getSessionMessages(sessionId, params.session_id, limit)
|
||||
|
||||
let data: Record<string, unknown>
|
||||
if (!result) {
|
||||
data = {
|
||||
error: isChineseLocale(locale) ? '未找到指定的会话' : 'Session not found',
|
||||
sessionId: params.session_id,
|
||||
}
|
||||
} else {
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
const startTime = new Date(result.startTs * 1000).toLocaleString(localeStr)
|
||||
const endTime = new Date(result.endTs * 1000).toLocaleString(localeStr)
|
||||
data = {
|
||||
sessionId: result.sessionId,
|
||||
time: `${startTime} ~ ${endTime}`,
|
||||
messageCount: result.messageCount,
|
||||
returnedCount: result.returnedCount,
|
||||
participants: result.participants,
|
||||
messages: result.messages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
77
electron/main/ai/tools/definitions/get-session-summaries.ts
Normal file
77
electron/main/ai/tools/definitions/get-session-summaries.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, isChineseLocale } from '../utils/format'
|
||||
import { timeParamPropertiesNoHour } from '../utils/schemas'
|
||||
|
||||
const schema = Type.Object({
|
||||
keywords: Type.Optional(Type.Array(Type.String(), { description: 'ai.tools.get_session_summaries.params.keywords' })),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.get_session_summaries.params.limit' })),
|
||||
...timeParamPropertiesNoHour,
|
||||
})
|
||||
|
||||
/** 获取会话摘要列表,快速了解群聊历史讨论的主题。适用场景:1. 了解群里最近在聊什么话题 2. 按关键词搜索讨论过的话题 3. 概览性问题如"群里有没有讨论过旅游"。返回的摘要是对每个会话的简短总结,可以帮助快速定位感兴趣的会话,然后用 get_session_messages 获取详情。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_session_summaries',
|
||||
label: 'get_session_summaries',
|
||||
description: 'ai.tools.get_session_summaries.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, locale } = context
|
||||
const limit = params.limit || 20
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
|
||||
const sessions = await workerManager.getSessionSummaries(sessionId, {
|
||||
limit: limit * 2,
|
||||
timeFilter: effectiveTimeFilter,
|
||||
})
|
||||
|
||||
let data: Record<string, unknown>
|
||||
if (!sessions || sessions.length === 0) {
|
||||
data = {
|
||||
message: isChineseLocale(locale)
|
||||
? '未找到带摘要的会话。可能还没有生成摘要,请在会话时间线中点击"批量生成"按钮。'
|
||||
: 'No sessions with summaries found. Summaries may not have been generated yet.',
|
||||
}
|
||||
} else {
|
||||
let filteredSessions = sessions
|
||||
if (params.keywords && params.keywords.length > 0) {
|
||||
const keywords = params.keywords.map((k) => k.toLowerCase())
|
||||
filteredSessions = sessions.filter((s) =>
|
||||
keywords.some((keyword) => s.summary?.toLowerCase().includes(keyword))
|
||||
)
|
||||
}
|
||||
|
||||
filteredSessions = filteredSessions.filter((s) => s.summary)
|
||||
const limitedSessions = filteredSessions.slice(0, limit)
|
||||
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
|
||||
data = {
|
||||
total: filteredSessions.length,
|
||||
returned: limitedSessions.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
sessions: limitedSessions.map((s) => {
|
||||
const startTime = new Date(s.startTs * 1000).toLocaleString(localeStr)
|
||||
const endTime = new Date(s.endTs * 1000).toLocaleString(localeStr)
|
||||
return {
|
||||
sessionId: s.id,
|
||||
time: `${startTime} ~ ${endTime}`,
|
||||
messageCount: s.messageCount,
|
||||
participants: s.participants,
|
||||
summary: s.summary,
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
65
electron/main/ai/tools/definitions/get-time-stats.ts
Normal file
65
electron/main/ai/tools/definitions/get-time-stats.ts
Normal file
@@ -0,0 +1,65 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale, i18nTexts, t } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
type: Type.Union([Type.Literal('hourly'), Type.Literal('weekday'), Type.Literal('daily')], {
|
||||
description: 'ai.tools.get_time_stats.params.type',
|
||||
}),
|
||||
})
|
||||
|
||||
/** 获取群聊的时间分布统计。适用于回答"什么时候最活跃"、"大家一般几点聊天"等问题。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_time_stats',
|
||||
label: 'get_time_stats',
|
||||
description: 'ai.tools.get_time_stats.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter, locale } = context
|
||||
const msgSuffix = isChineseLocale(locale) ? '条' : ''
|
||||
|
||||
let data: Record<string, unknown>
|
||||
switch (params.type) {
|
||||
case 'hourly': {
|
||||
const result = await workerManager.getHourlyActivity(sessionId, timeFilter)
|
||||
const peak = result.reduce((max, curr) => (curr.messageCount > max.messageCount ? curr : max))
|
||||
data = {
|
||||
peakHour: `${peak.hour}:00 (${peak.messageCount}${msgSuffix})`,
|
||||
distribution: result.map((h) => `${h.hour}:00 ${h.messageCount}${msgSuffix}`),
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'weekday': {
|
||||
const weekdayNames = t('weekdays', locale) as string[]
|
||||
const result = await workerManager.getWeekdayActivity(sessionId, timeFilter)
|
||||
const peak = result.reduce((max, curr) => (curr.messageCount > max.messageCount ? curr : max))
|
||||
data = {
|
||||
peakDay: `${weekdayNames[peak.weekday]} (${peak.messageCount}${msgSuffix})`,
|
||||
distribution: result.map((w) => `${weekdayNames[w.weekday]} ${w.messageCount}${msgSuffix}`),
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'daily': {
|
||||
const result = await workerManager.getDailyActivity(sessionId, timeFilter)
|
||||
const recent = result.slice(-30)
|
||||
const total = recent.reduce((sum, d) => sum + d.messageCount, 0)
|
||||
const avg = Math.round(total / recent.length)
|
||||
const summaryFn = i18nTexts.dailySummary[isChineseLocale(locale) ? 'zh' : 'en']
|
||||
data = {
|
||||
summary: summaryFn(recent.length, total, avg),
|
||||
trend: recent.map((d) => `${d.date} ${d.messageCount}${msgSuffix}`),
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
17
electron/main/ai/tools/definitions/index.ts
Normal file
17
electron/main/ai/tools/definitions/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* 工具定义聚合
|
||||
* 收集 definitions/ 下所有工具的 createTool 工厂函数
|
||||
*/
|
||||
|
||||
export { createTool as createSearchMessages } from './search-messages'
|
||||
export { createTool as createGetRecentMessages } from './get-recent-messages'
|
||||
export { createTool as createGetMemberStats } from './get-member-stats'
|
||||
export { createTool as createGetTimeStats } from './get-time-stats'
|
||||
export { createTool as createGetGroupMembers } from './get-group-members'
|
||||
export { createTool as createGetMemberNameHistory } from './get-member-name-history'
|
||||
export { createTool as createGetConversationBetween } from './get-conversation-between'
|
||||
export { createTool as createGetMessageContext } from './get-message-context'
|
||||
export { createTool as createSearchSessions } from './search-sessions'
|
||||
export { createTool as createGetSessionMessages } from './get-session-messages'
|
||||
export { createTool as createGetSessionSummaries } from './get-session-summaries'
|
||||
export { createTool as createSemanticSearchMessages } from './semantic-search-messages'
|
||||
50
electron/main/ai/tools/definitions/search-messages.ts
Normal file
50
electron/main/ai/tools/definitions/search-messages.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, formatMessageCompact } from '../utils/format'
|
||||
import { timeParamProperties } from '../utils/schemas'
|
||||
|
||||
const schema = Type.Object({
|
||||
keywords: Type.Array(Type.String(), { description: 'ai.tools.search_messages.params.keywords' }),
|
||||
sender_id: Type.Optional(Type.Number({ description: 'ai.tools.search_messages.params.sender_id' })),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.search_messages.params.limit' })),
|
||||
...timeParamProperties,
|
||||
})
|
||||
|
||||
/** 根据关键词搜索群聊记录。适用于用户想要查找特定话题、关键词相关的聊天内容。可以指定时间范围和发送者来筛选消息。支持精确到分钟级别的时间查询。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'search_messages',
|
||||
label: 'search_messages',
|
||||
description: 'ai.tools.search_messages.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, maxMessagesLimit, locale } = context
|
||||
const limit = Math.min(maxMessagesLimit || params.limit || 1000, 50000)
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
|
||||
const result = await workerManager.searchMessages(
|
||||
sessionId,
|
||||
params.keywords,
|
||||
effectiveTimeFilter,
|
||||
limit,
|
||||
0,
|
||||
params.sender_id
|
||||
)
|
||||
|
||||
const data = {
|
||||
total: result.total,
|
||||
returned: result.messages.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
messages: result.messages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
67
electron/main/ai/tools/definitions/search-sessions.ts
Normal file
67
electron/main/ai/tools/definitions/search-sessions.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import { timeParamPropertiesNoHour } from '../utils/schemas'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, formatMessageCompact, isChineseLocale } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
keywords: Type.Optional(Type.Array(Type.String(), { description: 'ai.tools.search_sessions.params.keywords' })),
|
||||
limit: Type.Optional(Type.Number({ description: 'ai.tools.search_sessions.params.limit' })),
|
||||
...timeParamPropertiesNoHour,
|
||||
})
|
||||
|
||||
/** 搜索聊天会话(对话段落)。会话是根据消息时间间隔自动切分的对话单元。适用于查找特定话题的讨论、了解某个时间段内发生了几次对话等场景。返回匹配的会话列表及每个会话的前5条消息预览。 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'search_sessions',
|
||||
label: 'search_sessions',
|
||||
description: 'ai.tools.search_sessions.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, locale } = context
|
||||
const limit = params.limit || 20
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
|
||||
const sessions = await workerManager.searchSessions(sessionId, params.keywords, effectiveTimeFilter, limit, 5)
|
||||
|
||||
if (sessions.length === 0) {
|
||||
const data = {
|
||||
total: 0,
|
||||
message: isChineseLocale(locale) ? '未找到匹配的会话' : 'No matching sessions found',
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
}
|
||||
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
const msgSuffix = isChineseLocale(locale) ? '条消息' : ' messages'
|
||||
const completeLabel = isChineseLocale(locale) ? '完整会话' : 'complete'
|
||||
|
||||
const data = {
|
||||
total: sessions.length,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
sessions: sessions.map((s) => {
|
||||
const startTime = new Date(s.startTs * 1000).toLocaleString(localeStr)
|
||||
const endTime = new Date(s.endTs * 1000).toLocaleString(localeStr)
|
||||
const completeTag = s.isComplete ? ` [${completeLabel}]` : ''
|
||||
|
||||
return {
|
||||
sessionId: s.id,
|
||||
time: `${startTime} ~ ${endTime}`,
|
||||
messageCount: `${s.messageCount}${msgSuffix}${completeTag}`,
|
||||
preview: s.previewMessages.map((m) => formatMessageCompact(m, locale)),
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import { executeSemanticPipeline, isEmbeddingEnabled } from '../../rag'
|
||||
import { getDbPath } from '../../../database/core'
|
||||
import { parseExtendedTimeParams } from '../utils/time-params'
|
||||
import { formatTimeRange, isChineseLocale } from '../utils/format'
|
||||
import { timeParamPropertiesNoHour } from '../utils/schemas'
|
||||
|
||||
const schema = Type.Object({
|
||||
query: Type.String({ description: 'ai.tools.semantic_search_messages.params.query' }),
|
||||
top_k: Type.Optional(Type.Number({ description: 'ai.tools.semantic_search_messages.params.top_k' })),
|
||||
candidate_limit: Type.Optional(
|
||||
Type.Number({ description: 'ai.tools.semantic_search_messages.params.candidate_limit' })
|
||||
),
|
||||
...timeParamPropertiesNoHour,
|
||||
})
|
||||
|
||||
/** 使用 Embedding 向量相似度搜索历史对话,理解语义而非关键词匹配。⚠️ 使用场景(优先使用 search_messages 关键词搜索,以下场景再考虑本工具):1. 找"类似的话"或"类似的表达":如"有没有说过类似'我想你了'这样的话" 2. 关键词搜索结果不足:当 search_messages 返回结果太少或不相关时,可用本工具补充 3. 模糊的情感/关系分析:如"对方对我的态度是怎样的"、"我们之间的氛围"。❌ 不适合的场景(请用 search_messages):有明确关键词的搜索(如"旅游"、"生日"、"加班")、查找特定人物的发言、查找特定时间段的消息 */
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'semantic_search_messages',
|
||||
label: 'semantic_search_messages',
|
||||
description: 'ai.tools.semantic_search_messages.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, timeFilter: contextTimeFilter, locale } = context
|
||||
|
||||
let data: Record<string, unknown>
|
||||
if (!isEmbeddingEnabled()) {
|
||||
data = {
|
||||
error: isChineseLocale(locale)
|
||||
? '语义搜索未启用。请在设置中添加并启用 Embedding 配置。'
|
||||
: 'Semantic search is not enabled. Please add and enable an Embedding config in settings.',
|
||||
}
|
||||
} else {
|
||||
const effectiveTimeFilter = parseExtendedTimeParams(params, contextTimeFilter)
|
||||
const dbPath = getDbPath(sessionId)
|
||||
|
||||
const result = await executeSemanticPipeline({
|
||||
userMessage: params.query,
|
||||
dbPath,
|
||||
timeFilter: effectiveTimeFilter,
|
||||
candidateLimit: params.candidate_limit,
|
||||
topK: params.top_k,
|
||||
})
|
||||
|
||||
if (!result.success) {
|
||||
data = {
|
||||
error: result.error || (isChineseLocale(locale) ? '语义搜索失败' : 'Semantic search failed'),
|
||||
}
|
||||
} else if (result.results.length === 0) {
|
||||
data = {
|
||||
message: isChineseLocale(locale) ? '未找到相关的历史对话' : 'No relevant conversations found',
|
||||
rewrittenQuery: result.rewrittenQuery,
|
||||
}
|
||||
} else {
|
||||
data = {
|
||||
total: result.results.length,
|
||||
rewrittenQuery: result.rewrittenQuery,
|
||||
timeRange: formatTimeRange(effectiveTimeFilter, locale),
|
||||
results: result.results.map((r, i) => ({
|
||||
rank: i + 1,
|
||||
score: `${(r.score * 100).toFixed(1)}%`,
|
||||
sessionId: r.metadata?.sessionId,
|
||||
timeRange: r.metadata
|
||||
? formatTimeRange({ startTs: r.metadata.startTs, endTs: r.metadata.endTs }, locale)
|
||||
: undefined,
|
||||
participants: r.metadata?.participants,
|
||||
content: r.content.length > 500 ? r.content.slice(0, 500) + '...' : r.content,
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(data) }],
|
||||
details: data,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,182 +1,88 @@
|
||||
/**
|
||||
* AI Tools 模块入口
|
||||
* 工具注册与管理
|
||||
* 工具创建与管理
|
||||
*/
|
||||
|
||||
import type { ToolDefinition, ToolCall } from '../llm/types'
|
||||
import type { ToolRegistry, RegisteredTool, ToolContext, ToolExecutionResult, ToolExecutor } from './types'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from './types'
|
||||
import {
|
||||
createSearchMessages,
|
||||
createGetRecentMessages,
|
||||
createGetMemberStats,
|
||||
createGetTimeStats,
|
||||
createGetGroupMembers,
|
||||
createGetMemberNameHistory,
|
||||
createGetConversationBetween,
|
||||
createGetMessageContext,
|
||||
createSearchSessions,
|
||||
createGetSessionMessages,
|
||||
createGetSessionSummaries,
|
||||
createSemanticSearchMessages,
|
||||
} from './definitions'
|
||||
import { isEmbeddingEnabled } from '../rag'
|
||||
import { t as i18nT } from '../../i18n'
|
||||
|
||||
// 导出类型
|
||||
export * from './types'
|
||||
|
||||
// 全局工具注册表
|
||||
const toolRegistry: ToolRegistry = new Map()
|
||||
type ToolFactory = (context: ToolContext) => AgentTool<any>
|
||||
|
||||
// 工具是否已初始化
|
||||
let toolsInitialized = false
|
||||
let initPromise: Promise<void> | null = null
|
||||
const coreFactories: ToolFactory[] = [
|
||||
createSearchMessages,
|
||||
createGetRecentMessages,
|
||||
createGetMemberStats,
|
||||
createGetTimeStats,
|
||||
createGetGroupMembers,
|
||||
createGetMemberNameHistory,
|
||||
createGetConversationBetween,
|
||||
createGetMessageContext,
|
||||
createSearchSessions,
|
||||
createGetSessionMessages,
|
||||
createGetSessionSummaries,
|
||||
]
|
||||
|
||||
/**
|
||||
* 注册一个工具
|
||||
* @param definition 工具定义
|
||||
* @param executor 执行函数
|
||||
*/
|
||||
export function registerTool(definition: ToolDefinition, executor: ToolExecutor): void {
|
||||
const name = definition.function.name
|
||||
toolRegistry.set(name, { definition, executor })
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化所有工具(确保工具已注册)
|
||||
* 使用动态 import 避免循环依赖
|
||||
*/
|
||||
export async function ensureToolsInitialized(): Promise<void> {
|
||||
if (toolsInitialized) return
|
||||
if (initPromise) return initPromise
|
||||
|
||||
initPromise = (async () => {
|
||||
// 动态导入 registry 模块
|
||||
await import('./registry')
|
||||
toolsInitialized = true
|
||||
})()
|
||||
|
||||
return initPromise
|
||||
}
|
||||
|
||||
/**
|
||||
* 翻译工具定义的 description 和参数 description
|
||||
* 使用 i18next 查找翻译,如果未找到则保留原始文本(中文)
|
||||
* 翻译 AgentTool 的描述(工具级 + 参数级)
|
||||
*
|
||||
* i18n 键命名规则:
|
||||
* - 工具描述:ai.tools.{toolName}.desc
|
||||
* - 参数描述:ai.tools.{toolName}.params.{paramName}
|
||||
*/
|
||||
function translateToolDefinition(tool: ToolDefinition): ToolDefinition {
|
||||
const name = tool.function.name
|
||||
function translateTool(tool: AgentTool<any>): AgentTool<any> {
|
||||
const name = tool.name
|
||||
|
||||
const descKey = `ai.tools.${name}.desc`
|
||||
const translatedDesc = i18nT(descKey)
|
||||
|
||||
// 深拷贝并翻译参数描述
|
||||
const translatedProperties: typeof tool.function.parameters.properties = {}
|
||||
for (const [paramName, param] of Object.entries(tool.function.parameters.properties)) {
|
||||
const paramKey = `ai.tools.${name}.params.${paramName}`
|
||||
const translatedParamDesc = i18nT(paramKey)
|
||||
translatedProperties[paramName] = {
|
||||
...param,
|
||||
// 如果 i18next 返回的是 key 本身,说明没有找到翻译,保留原始文本
|
||||
description: translatedParamDesc !== paramKey ? translatedParamDesc : param.description,
|
||||
const params = tool.parameters as Record<string, unknown>
|
||||
if (params?.properties && typeof params.properties === 'object') {
|
||||
for (const [paramName, param] of Object.entries(params.properties as Record<string, Record<string, unknown>>)) {
|
||||
const paramKey = `ai.tools.${name}.params.${paramName}`
|
||||
const translated = i18nT(paramKey)
|
||||
if (translated !== paramKey) {
|
||||
param.description = translated
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: tool.type,
|
||||
function: {
|
||||
name: tool.function.name,
|
||||
// 如果 i18next 返回的是 key 本身,说明没有找到翻译,保留原始文本
|
||||
description: translatedDesc !== descKey ? translatedDesc : tool.function.description,
|
||||
parameters: {
|
||||
type: tool.function.parameters.type,
|
||||
properties: translatedProperties,
|
||||
required: tool.function.parameters.required,
|
||||
},
|
||||
},
|
||||
...tool,
|
||||
description: translatedDesc !== descKey ? translatedDesc : tool.description,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已注册的工具定义
|
||||
* 获取所有可用的 AgentTool
|
||||
*
|
||||
* 根据配置动态过滤工具(如:语义搜索工具仅在启用 Embedding 时可用)
|
||||
* 根据当前 locale 动态翻译工具描述(解决"响应式"陷阱:每次调用时实时翻译)
|
||||
* @returns 工具定义数组(用于传给 LLM)
|
||||
* 根据当前 locale 动态翻译工具描述
|
||||
*/
|
||||
export async function getAllToolDefinitions(): Promise<ToolDefinition[]> {
|
||||
await ensureToolsInitialized()
|
||||
export function getAllTools(context: ToolContext): AgentTool<any>[] {
|
||||
const tools: AgentTool<any>[] = coreFactories.map((f) => f(context))
|
||||
|
||||
const allTools = Array.from(toolRegistry.values()).map((reg) => reg.definition)
|
||||
|
||||
// 根据 Embedding 配置决定是否包含语义搜索工具
|
||||
const embeddingEnabled = isEmbeddingEnabled()
|
||||
const filteredTools = embeddingEnabled
|
||||
? allTools
|
||||
: allTools.filter((tool) => tool.function.name !== 'semantic_search_messages')
|
||||
|
||||
// 所有 locale 统一走翻译层,确保 locale 文件同构
|
||||
return filteredTools.map(translateToolDefinition)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定工具
|
||||
* @param name 工具名称
|
||||
*/
|
||||
export async function getTool(name: string): Promise<RegisteredTool | undefined> {
|
||||
await ensureToolsInitialized()
|
||||
return toolRegistry.get(name)
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行单个工具调用
|
||||
* @param toolCall LLM 返回的 tool_call
|
||||
* @param context 执行上下文
|
||||
*/
|
||||
export async function executeToolCall(toolCall: ToolCall, context: ToolContext): Promise<ToolExecutionResult> {
|
||||
await ensureToolsInitialized()
|
||||
const toolName = toolCall.function.name
|
||||
|
||||
// 查找工具
|
||||
const tool = toolRegistry.get(toolName)
|
||||
if (!tool) {
|
||||
return {
|
||||
toolName,
|
||||
success: false,
|
||||
error: i18nT('tools.notRegistered', { toolName }),
|
||||
}
|
||||
if (isEmbeddingEnabled()) {
|
||||
tools.push(createSemanticSearchMessages(context))
|
||||
}
|
||||
|
||||
try {
|
||||
// 解析参数
|
||||
const params = JSON.parse(toolCall.function.arguments || '{}')
|
||||
|
||||
// 执行工具
|
||||
const result = await tool.executor(params, context)
|
||||
|
||||
return {
|
||||
toolName,
|
||||
success: true,
|
||||
result,
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
toolName,
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量执行工具调用
|
||||
* @param toolCalls LLM 返回的 tool_calls 数组
|
||||
* @param context 执行上下文
|
||||
*/
|
||||
export async function executeToolCalls(toolCalls: ToolCall[], context: ToolContext): Promise<ToolExecutionResult[]> {
|
||||
// 并行执行所有工具调用
|
||||
return Promise.all(toolCalls.map((tc) => executeToolCall(tc, context)))
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查工具是否已注册
|
||||
*/
|
||||
export async function hasToolsRegistered(): Promise<boolean> {
|
||||
await ensureToolsInitialized()
|
||||
return toolRegistry.size > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册工具数量
|
||||
*/
|
||||
export async function getRegisteredToolCount(): Promise<number> {
|
||||
await ensureToolsInitialized()
|
||||
return toolRegistry.size
|
||||
return tools.map(translateTool)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,7 @@
|
||||
/**
|
||||
* AI Tools 类型定义
|
||||
* 定义工具的接口和执行上下文
|
||||
*/
|
||||
|
||||
import type { ToolDefinition } from '../llm/types'
|
||||
|
||||
/**
|
||||
* 工具执行上下文
|
||||
* 包含执行工具时需要的所有上下文信息
|
||||
*/
|
||||
/** Owner 信息(当前用户在对话中的身份) */
|
||||
export interface OwnerInfo {
|
||||
/** Owner 的 platformId */
|
||||
@@ -17,6 +10,10 @@ export interface OwnerInfo {
|
||||
displayName: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具执行上下文
|
||||
* 包含执行工具时需要的所有上下文信息
|
||||
*/
|
||||
export interface ToolContext {
|
||||
/** 当前会话 ID(数据库文件名) */
|
||||
sessionId: string
|
||||
@@ -34,41 +31,3 @@ export interface ToolContext {
|
||||
/** 语言环境(用于工具返回结果的国际化) */
|
||||
locale?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具执行函数类型
|
||||
* @param params 从 LLM 解析出的参数对象
|
||||
* @param context 执行上下文
|
||||
* @returns 执行结果(将被序列化为字符串传回 LLM)
|
||||
*/
|
||||
export type ToolExecutor<T = Record<string, unknown>> = (params: T, context: ToolContext) => Promise<unknown>
|
||||
|
||||
/**
|
||||
* 注册的工具
|
||||
* 包含工具定义和执行函数
|
||||
*/
|
||||
export interface RegisteredTool {
|
||||
/** 工具定义(OpenAI 格式) */
|
||||
definition: ToolDefinition
|
||||
/** 执行函数 */
|
||||
executor: ToolExecutor
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具注册表
|
||||
*/
|
||||
export type ToolRegistry = Map<string, RegisteredTool>
|
||||
|
||||
/**
|
||||
* 工具执行结果
|
||||
*/
|
||||
export interface ToolExecutionResult {
|
||||
/** 工具名称 */
|
||||
toolName: string
|
||||
/** 执行是否成功 */
|
||||
success: boolean
|
||||
/** 执行结果(成功时) */
|
||||
result?: unknown
|
||||
/** 错误信息(失败时) */
|
||||
error?: string
|
||||
}
|
||||
|
||||
76
electron/main/ai/tools/utils/format.ts
Normal file
76
electron/main/ai/tools/utils/format.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* 工具结果格式化 & i18n 辅助
|
||||
*/
|
||||
|
||||
export function isChineseLocale(locale?: string): boolean {
|
||||
return locale === 'zh-CN'
|
||||
}
|
||||
|
||||
export const i18nTexts = {
|
||||
allTime: { zh: '全部时间', en: 'All time' },
|
||||
noContent: { zh: '[无内容]', en: '[No content]' },
|
||||
memberNotFound: { zh: '未找到该成员', en: 'Member not found' },
|
||||
untilNow: { zh: '至今', en: 'Present' },
|
||||
noChangeRecord: { zh: '无变更记录', en: 'No change record' },
|
||||
noConversation: { zh: '未找到这两人之间的对话', en: 'No conversation found between these two members' },
|
||||
noMessageContext: { zh: '未找到指定的消息或上下文', en: 'Message or context not found' },
|
||||
messages: { zh: '条', en: '' },
|
||||
alias: { zh: '别名', en: 'Alias' },
|
||||
weekdays: {
|
||||
zh: ['', '周一', '周二', '周三', '周四', '周五', '周六', '周日'],
|
||||
en: ['', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'],
|
||||
},
|
||||
dailySummary: {
|
||||
zh: (days: number, total: number, avg: number) => `最近${days}天共${total}条,日均${avg}条`,
|
||||
en: (days: number, total: number, avg: number) => `Last ${days} days: ${total} messages, avg ${avg}/day`,
|
||||
},
|
||||
}
|
||||
|
||||
export function t(key: keyof typeof i18nTexts, locale?: string): string | string[] {
|
||||
const text = i18nTexts[key]
|
||||
if (typeof text === 'object' && 'zh' in text && 'en' in text) {
|
||||
return isChineseLocale(locale) ? text.zh : text.en
|
||||
}
|
||||
return ''
|
||||
}
|
||||
|
||||
const MAX_MESSAGE_CONTENT_LENGTH = 200
|
||||
|
||||
/**
|
||||
* 格式化消息为简洁文本格式
|
||||
* 输出格式: "2025/3/3 07:25:04 张三: 消息内容"
|
||||
*/
|
||||
export function formatMessageCompact(
|
||||
msg: {
|
||||
id?: number
|
||||
senderName: string
|
||||
content: string | null
|
||||
timestamp: number
|
||||
},
|
||||
locale?: string
|
||||
): string {
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
const time = new Date(msg.timestamp * 1000).toLocaleString(localeStr)
|
||||
let content = msg.content || (t('noContent', locale) as string)
|
||||
|
||||
if (content.length > MAX_MESSAGE_CONTENT_LENGTH) {
|
||||
content = content.slice(0, MAX_MESSAGE_CONTENT_LENGTH) + '...'
|
||||
}
|
||||
|
||||
return `${time} ${msg.senderName}: ${content}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化时间范围用于返回结果
|
||||
*/
|
||||
export function formatTimeRange(
|
||||
timeFilter?: { startTs: number; endTs: number },
|
||||
locale?: string
|
||||
): string | { start: string; end: string } {
|
||||
if (!timeFilter) return t('allTime', locale) as string
|
||||
const localeStr = isChineseLocale(locale) ? 'zh-CN' : 'en-US'
|
||||
return {
|
||||
start: new Date(timeFilter.startTs * 1000).toLocaleString(localeStr),
|
||||
end: new Date(timeFilter.endTs * 1000).toLocaleString(localeStr),
|
||||
}
|
||||
}
|
||||
23
electron/main/ai/tools/utils/schemas.ts
Normal file
23
electron/main/ai/tools/utils/schemas.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* 共享 TypeBox Schema 片段
|
||||
* 多个工具复用的时间参数 schema
|
||||
*/
|
||||
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
|
||||
export const timeParamProperties = {
|
||||
year: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.year' })),
|
||||
month: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.month' })),
|
||||
day: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.day' })),
|
||||
hour: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.hour' })),
|
||||
start_time: Type.Optional(Type.String({ description: 'ai.tools._shared.params.start_time' })),
|
||||
end_time: Type.Optional(Type.String({ description: 'ai.tools._shared.params.end_time' })),
|
||||
}
|
||||
|
||||
export const timeParamPropertiesNoHour = {
|
||||
year: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.year' })),
|
||||
month: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.month' })),
|
||||
day: Type.Optional(Type.Number({ description: 'ai.tools._shared.params.day' })),
|
||||
start_time: Type.Optional(Type.String({ description: 'ai.tools._shared.params.start_time' })),
|
||||
end_time: Type.Optional(Type.String({ description: 'ai.tools._shared.params.end_time' })),
|
||||
}
|
||||
78
electron/main/ai/tools/utils/time-params.ts
Normal file
78
electron/main/ai/tools/utils/time-params.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
/**
|
||||
* 时间参数解析工具
|
||||
*/
|
||||
|
||||
export interface ExtendedTimeParams {
|
||||
year?: number
|
||||
month?: number
|
||||
day?: number
|
||||
hour?: number
|
||||
start_time?: string // 格式: "YYYY-MM-DD HH:mm"
|
||||
end_time?: string // 格式: "YYYY-MM-DD HH:mm"
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析扩展的时间参数,返回时间过滤器
|
||||
* 优先级: start_time/end_time > year/month/day/hour 组合 > context.timeFilter
|
||||
*/
|
||||
export function parseExtendedTimeParams(
|
||||
params: ExtendedTimeParams,
|
||||
contextTimeFilter?: { startTs: number; endTs: number }
|
||||
): { startTs: number; endTs: number } | undefined {
|
||||
if (params.start_time || params.end_time) {
|
||||
let startTs: number | undefined
|
||||
let endTs: number | undefined
|
||||
|
||||
if (params.start_time) {
|
||||
const startDate = new Date(params.start_time.replace(' ', 'T'))
|
||||
if (!isNaN(startDate.getTime())) {
|
||||
startTs = Math.floor(startDate.getTime() / 1000)
|
||||
}
|
||||
}
|
||||
|
||||
if (params.end_time) {
|
||||
const endDate = new Date(params.end_time.replace(' ', 'T'))
|
||||
if (!isNaN(endDate.getTime())) {
|
||||
endTs = Math.floor(endDate.getTime() / 1000)
|
||||
}
|
||||
}
|
||||
|
||||
if (startTs !== undefined || endTs !== undefined) {
|
||||
return {
|
||||
startTs: startTs ?? 0,
|
||||
endTs: endTs ?? Math.floor(Date.now() / 1000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.year) {
|
||||
const year = params.year
|
||||
const month = params.month
|
||||
const day = params.day
|
||||
const hour = params.hour
|
||||
|
||||
let startDate: Date
|
||||
let endDate: Date
|
||||
|
||||
if (month && day && hour !== undefined) {
|
||||
startDate = new Date(year, month - 1, day, hour, 0, 0)
|
||||
endDate = new Date(year, month - 1, day, hour, 59, 59)
|
||||
} else if (month && day) {
|
||||
startDate = new Date(year, month - 1, day, 0, 0, 0)
|
||||
endDate = new Date(year, month - 1, day, 23, 59, 59)
|
||||
} else if (month) {
|
||||
startDate = new Date(year, month - 1, 1)
|
||||
endDate = new Date(year, month, 0, 23, 59, 59)
|
||||
} else {
|
||||
startDate = new Date(year, 0, 1)
|
||||
endDate = new Date(year, 11, 31, 23, 59, 59)
|
||||
}
|
||||
|
||||
return {
|
||||
startTs: Math.floor(startDate.getTime() / 1000),
|
||||
endTs: Math.floor(endDate.getTime() / 1000),
|
||||
}
|
||||
}
|
||||
|
||||
return contextTimeFilter
|
||||
}
|
||||
Reference in New Issue
Block a user