mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-19 04:49:36 +08:00
feat: 新增获取聊天概览工具
This commit is contained in:
@@ -0,0 +1,61 @@
|
||||
import { Type } from '@mariozechner/pi-ai'
|
||||
import type { AgentTool } from '@mariozechner/pi-agent-core'
|
||||
import type { ToolContext } from '../types'
|
||||
import * as workerManager from '../../../worker/workerManager'
|
||||
import { isChineseLocale } from '../utils/format'
|
||||
|
||||
const schema = Type.Object({
|
||||
top_n: Type.Optional(Type.Number({ description: 'ai.tools.get_chat_overview.params.top_n' })),
|
||||
})
|
||||
|
||||
export function createTool(context: ToolContext): AgentTool<typeof schema> {
|
||||
return {
|
||||
name: 'get_chat_overview',
|
||||
label: 'get_chat_overview',
|
||||
description: 'ai.tools.get_chat_overview.desc',
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId, params) => {
|
||||
const { sessionId, locale } = context
|
||||
const topN = params.top_n || 10
|
||||
|
||||
const result = await workerManager.getChatOverview(sessionId, topN)
|
||||
if (!result) {
|
||||
const msg = isChineseLocale(locale) ? '无法获取聊天概览' : 'Unable to get chat overview'
|
||||
return {
|
||||
content: [{ type: 'text', text: msg }],
|
||||
details: { error: msg },
|
||||
}
|
||||
}
|
||||
|
||||
const msgSuffix = isChineseLocale(locale) ? '条' : ''
|
||||
const lines: string[] = [
|
||||
`name: ${result.name}`,
|
||||
`platform: ${result.platform}`,
|
||||
`type: ${result.type}`,
|
||||
`totalMessages: ${result.totalMessages}`,
|
||||
`totalMembers: ${result.totalMembers}`,
|
||||
]
|
||||
|
||||
if (result.firstMessageTs != null && result.lastMessageTs != null) {
|
||||
const start = new Date(result.firstMessageTs * 1000).toLocaleDateString()
|
||||
const end = new Date(result.lastMessageTs * 1000).toLocaleDateString()
|
||||
lines.push(`timeRange: ${start} ~ ${end}`)
|
||||
}
|
||||
|
||||
if (result.topMembers.length > 0) {
|
||||
lines.push(`topMembers:`)
|
||||
for (let i = 0; i < result.topMembers.length; i++) {
|
||||
const m = result.topMembers[i]
|
||||
const pct = result.totalMessages > 0 ? ((m.count / result.totalMessages) * 100).toFixed(1) : '0'
|
||||
lines.push(`${i + 1}. ${m.name} ${m.count}${msgSuffix}(${pct}%)`)
|
||||
}
|
||||
}
|
||||
|
||||
const text = lines.join('\n')
|
||||
return {
|
||||
content: [{ type: 'text', text }],
|
||||
details: result,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -14,9 +14,11 @@ export { createTool as createGetMessageContext } from './get-message-context'
|
||||
export { createTool as createSearchSessions } from './search-sessions'
|
||||
export { createTool as createGetSessionMessages } from './get-session-messages'
|
||||
export { createTool as createGetSessionSummaries } from './get-session-summaries'
|
||||
export { createTool as createGetChatOverview } from './get-chat-overview'
|
||||
export { sqlToolFactories, getSqlToolCatalog, SQL_TOOL_NAMES } from './sql-analysis'
|
||||
|
||||
export const TS_TOOL_NAMES = [
|
||||
'get_chat_overview',
|
||||
'search_messages',
|
||||
'get_recent_messages',
|
||||
'get_member_stats',
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
createSearchSessions,
|
||||
createGetSessionMessages,
|
||||
createGetSessionSummaries,
|
||||
createGetChatOverview,
|
||||
sqlToolFactories,
|
||||
} from './definitions'
|
||||
import { t as i18nT } from '../../i18n'
|
||||
@@ -33,6 +34,7 @@ export * from './types'
|
||||
type ToolFactory = (context: ToolContext) => AgentTool<any>
|
||||
|
||||
const coreFactories: ToolFactory[] = [
|
||||
createGetChatOverview,
|
||||
createSearchMessages,
|
||||
createGetRecentMessages,
|
||||
createGetMemberStats,
|
||||
|
||||
@@ -97,7 +97,7 @@ export function deleteSessionCache(sessionId: string, cacheDir: string): void {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Overview 缓存(业务层) ====================
|
||||
// ==================== Overview 缓存(聚合统计) ====================
|
||||
|
||||
export const CACHE_KEY_OVERVIEW = 'overview'
|
||||
|
||||
@@ -106,8 +106,6 @@ export interface OverviewCache {
|
||||
totalMembers: number
|
||||
firstMessageTs: number | null
|
||||
lastMessageTs: number | null
|
||||
/** member.id -> message count */
|
||||
memberMessageCounts: Record<number, number>
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -139,23 +137,59 @@ export function computeAndSetOverviewCache(
|
||||
}
|
||||
).count
|
||||
|
||||
const memberCounts = db
|
||||
.prepare('SELECT sender_id, COUNT(*) as count FROM message GROUP BY sender_id')
|
||||
.all() as Array<{ sender_id: number; count: number }>
|
||||
|
||||
const memberMessageCounts: Record<number, number> = {}
|
||||
for (const row of memberCounts) {
|
||||
memberMessageCounts[row.sender_id] = row.count
|
||||
}
|
||||
|
||||
const data: OverviewCache = {
|
||||
totalMessages,
|
||||
totalMembers,
|
||||
firstMessageTs: msgStats.first_ts,
|
||||
lastMessageTs: msgStats.last_ts,
|
||||
memberMessageCounts,
|
||||
}
|
||||
|
||||
setCache(sessionId, CACHE_KEY_OVERVIEW, data, cacheDir)
|
||||
|
||||
// 同步生成成员缓存
|
||||
computeAndSetMembersCache(db, sessionId, cacheDir)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// ==================== Members 缓存(成员维度) ====================
|
||||
|
||||
export const CACHE_KEY_MEMBERS = 'members'
|
||||
|
||||
export interface MemberStat {
|
||||
name: string
|
||||
count: number
|
||||
}
|
||||
|
||||
export interface MembersCache {
|
||||
/** member.id -> { name, count } */
|
||||
members: Record<number, MemberStat>
|
||||
}
|
||||
|
||||
/**
|
||||
* 从数据库计算成员统计并写入缓存
|
||||
*/
|
||||
export function computeAndSetMembersCache(
|
||||
db: Database.Database,
|
||||
sessionId: string,
|
||||
cacheDir: string
|
||||
): MembersCache {
|
||||
const rows = db
|
||||
.prepare(
|
||||
`SELECT msg.sender_id, COUNT(*) as count,
|
||||
COALESCE(m.group_nickname, m.account_name, m.platform_id) as name
|
||||
FROM message msg
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
GROUP BY msg.sender_id`
|
||||
)
|
||||
.all() as Array<{ sender_id: number; count: number; name: string }>
|
||||
|
||||
const members: Record<number, MemberStat> = {}
|
||||
for (const row of rows) {
|
||||
members[row.sender_id] = { name: row.name, count: row.count }
|
||||
}
|
||||
|
||||
const data: MembersCache = { members }
|
||||
setCache(sessionId, CACHE_KEY_MEMBERS, data, cacheDir)
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -86,6 +86,12 @@ export default {
|
||||
'End time, format "YYYY-MM-DD HH:mm", e.g. "2024-03-15 18:30". Overrides year/month/day/hour when specified',
|
||||
},
|
||||
},
|
||||
get_chat_overview: {
|
||||
desc: 'Get basic overview of the chat: name, platform, type, total messages, total members, time range, and top active members. Use this first to understand the data before deeper analysis.',
|
||||
params: {
|
||||
top_n: 'Return top N active members, default 10',
|
||||
},
|
||||
},
|
||||
get_member_stats: {
|
||||
desc: 'Get member activity statistics. Suitable for questions like "who is the most active" or "who sends the most messages".',
|
||||
params: {
|
||||
|
||||
@@ -87,6 +87,12 @@ export default {
|
||||
'終了時刻。形式 "YYYY-MM-DD HH:mm"(例:"2024-03-15 18:30")。指定すると year/month/day/hour パラメータを上書きする',
|
||||
},
|
||||
},
|
||||
get_chat_overview: {
|
||||
desc: 'チャット記録の基本概要を取得する:グループ名、プラットフォーム、タイプ、総メッセージ数、総メンバー数、期間、最もアクティブなメンバーランキング。分析前にデータの全体像を把握するのに適している。',
|
||||
params: {
|
||||
top_n: '上位 N 名のアクティブメンバーを返却。デフォルト 10',
|
||||
},
|
||||
},
|
||||
get_member_stats: {
|
||||
desc: 'グループメンバーのアクティビティ統計データを取得する。「最もアクティブなのは誰?」「発言数が一番多いのは?」などの質問に適している。',
|
||||
params: {
|
||||
|
||||
@@ -79,6 +79,12 @@ export default {
|
||||
end_time: '结束时间,格式 "YYYY-MM-DD HH:mm",如 "2024-03-15 18:30"。指定后会覆盖 year/month/day/hour 参数',
|
||||
},
|
||||
},
|
||||
get_chat_overview: {
|
||||
desc: '获取聊天记录的基本概览信息,包括群名/平台/类型/总消息数/总成员数/时间跨度/最活跃成员排名。适合在分析前先了解数据全貌。',
|
||||
params: {
|
||||
top_n: '返回前 N 名活跃成员,默认 10',
|
||||
},
|
||||
},
|
||||
get_member_stats: {
|
||||
desc: '获取群成员的活跃度统计数据。适用于回答"谁最活跃"、"发言最多的是谁"等问题。',
|
||||
params: {
|
||||
|
||||
@@ -79,6 +79,12 @@ export default {
|
||||
end_time: '結束時間,格式 "YYYY-MM-DD HH:mm",如 "2024-03-15 18:30"。指定後會覆蓋 year/month/day/hour 參數',
|
||||
},
|
||||
},
|
||||
get_chat_overview: {
|
||||
desc: '取得聊天記錄的基本概覽資訊,包括群名/平台/類型/總訊息數/總成員數/時間跨度/最活躍成員排名。適合在分析前先了解資料全貌。',
|
||||
params: {
|
||||
top_n: '回傳前 N 名活躍成員,預設 10',
|
||||
},
|
||||
},
|
||||
get_member_stats: {
|
||||
desc: '取得群成員的活躍度統計資料。適用於回答「誰最活躍」、「發言最多的是誰」等問題。',
|
||||
params: {
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
getMemberNameHistory,
|
||||
getAllSessions,
|
||||
getSession,
|
||||
getChatOverview,
|
||||
getCatchphraseAnalysis,
|
||||
getMentionAnalysis,
|
||||
getMentionGraph,
|
||||
@@ -96,6 +97,7 @@ const syncHandlers: Record<string, (payload: any) => any> = {
|
||||
// 会话管理
|
||||
getAllSessions: () => getAllSessions(),
|
||||
getSession: (p) => getSession(p.sessionId),
|
||||
getChatOverview: (p) => getChatOverview(p.sessionId, p.topN),
|
||||
closeDatabase: (p) => {
|
||||
closeDatabase(p.sessionId)
|
||||
return true
|
||||
|
||||
@@ -24,7 +24,7 @@ export {
|
||||
} from './basic'
|
||||
|
||||
// 会话管理(会话列表与基础信息)
|
||||
export { getAllSessions, getSession } from './sessions'
|
||||
export { getAllSessions, getSession, getChatOverview } from './sessions'
|
||||
|
||||
// 成员分页类型
|
||||
export type { MembersPaginationParams, MembersPaginatedResult } from './basic'
|
||||
|
||||
@@ -7,7 +7,15 @@ import Database from 'better-sqlite3'
|
||||
import * as fs from 'fs'
|
||||
import * as path from 'path'
|
||||
import { openDatabase, getDbDir, getDbPath, getCacheDir } from '../core'
|
||||
import { getCache, computeAndSetOverviewCache, CACHE_KEY_OVERVIEW, type OverviewCache } from '../../database/sessionCache'
|
||||
import {
|
||||
getCache,
|
||||
computeAndSetOverviewCache,
|
||||
computeAndSetMembersCache,
|
||||
CACHE_KEY_OVERVIEW,
|
||||
CACHE_KEY_MEMBERS,
|
||||
type OverviewCache,
|
||||
type MembersCache,
|
||||
} from '../../database/sessionCache'
|
||||
|
||||
interface DbMeta {
|
||||
name: string
|
||||
@@ -231,3 +239,78 @@ export function getSession(sessionId: string): any | null {
|
||||
ownerId: meta.owner_id || null,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取聊天概览(AI 工具使用)
|
||||
* 优先从缓存读取,miss 则计算回填
|
||||
*/
|
||||
export function getChatOverview(sessionId: string, topN: number = 10) {
|
||||
const db = openDatabase(sessionId)
|
||||
if (!db) return null
|
||||
|
||||
const meta = db.prepare('SELECT * FROM meta LIMIT 1').get() as DbMeta | undefined
|
||||
if (!meta) return null
|
||||
|
||||
const cacheDir = getCacheDir()
|
||||
|
||||
// 读取 overview 缓存
|
||||
let overview = getCache<OverviewCache>(sessionId, CACHE_KEY_OVERVIEW, cacheDir)
|
||||
if (!overview) {
|
||||
try {
|
||||
overview = computeAndSetOverviewCache(db, sessionId, cacheDir)
|
||||
} catch {
|
||||
// fallback: 实时查询
|
||||
}
|
||||
}
|
||||
|
||||
// 读取 members 缓存
|
||||
let membersCache = getCache<MembersCache>(sessionId, CACHE_KEY_MEMBERS, cacheDir)
|
||||
if (!membersCache) {
|
||||
try {
|
||||
membersCache = computeAndSetMembersCache(db, sessionId, cacheDir)
|
||||
} catch {
|
||||
// fallback: 无成员数据
|
||||
}
|
||||
}
|
||||
|
||||
// 从缓存计算 Top N 活跃成员
|
||||
let topMembers: Array<{ id: number; name: string; count: number }> = []
|
||||
if (membersCache?.members) {
|
||||
topMembers = Object.entries(membersCache.members)
|
||||
.map(([id, stat]) => ({ id: Number(id), name: stat.name, count: stat.count }))
|
||||
.sort((a, b) => b.count - a.count)
|
||||
.slice(0, topN)
|
||||
}
|
||||
|
||||
// fallback 统计值
|
||||
const totalMessages =
|
||||
overview?.totalMessages ??
|
||||
(
|
||||
db
|
||||
.prepare(
|
||||
`SELECT COUNT(*) as count FROM message msg
|
||||
JOIN member m ON msg.sender_id = m.id
|
||||
WHERE COALESCE(m.account_name, '') != '系统消息'`
|
||||
)
|
||||
.get() as { count: number }
|
||||
).count
|
||||
|
||||
const totalMembers =
|
||||
overview?.totalMembers ??
|
||||
(
|
||||
db.prepare(`SELECT COUNT(*) as count FROM member WHERE COALESCE(account_name, '') != '系统消息'`).get() as {
|
||||
count: number
|
||||
}
|
||||
).count
|
||||
|
||||
return {
|
||||
name: meta.name,
|
||||
platform: meta.platform,
|
||||
type: meta.type,
|
||||
totalMessages,
|
||||
totalMembers,
|
||||
firstMessageTs: overview?.firstMessageTs ?? null,
|
||||
lastMessageTs: overview?.lastMessageTs ?? null,
|
||||
topMembers,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,6 +327,22 @@ export async function getSession(sessionId: string): Promise<any | null> {
|
||||
return sendToWorker('getSession', { sessionId })
|
||||
}
|
||||
|
||||
export async function getChatOverview(
|
||||
sessionId: string,
|
||||
topN?: number
|
||||
): Promise<{
|
||||
name: string
|
||||
platform: string
|
||||
type: string
|
||||
totalMessages: number
|
||||
totalMembers: number
|
||||
firstMessageTs: number | null
|
||||
lastMessageTs: number | null
|
||||
topMembers: Array<{ id: number; name: string; count: number }>
|
||||
} | null> {
|
||||
return sendToWorker('getChatOverview', { sessionId, topN })
|
||||
}
|
||||
|
||||
export async function closeDatabase(sessionId: string): Promise<void> {
|
||||
return sendToWorker('closeDatabase', { sessionId })
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user