From 90afe5f087f4cbbaabce6f2f91804b1fb82acb83 Mon Sep 17 00:00:00 2001
From: digua
Date: Thu, 26 Feb 2026 21:05:28 +0800
Subject: [PATCH] refactor(llm): unify LLM access layer via pi-ai

- Refactor llm/index.ts: remove the chat()/chatStream() wrappers, inline the
  provider info from the deleted service files, and add buildPiModel() for
  direct pi-ai model construction
- Migrate summary/index.ts to call pi-ai's completeSimple directly
- Migrate rag/pipeline/semantic.ts to call pi-ai's completeSimple directly
---
 electron/main/ai/llm/index.ts             | 330 +++++++---------------
 electron/main/ai/rag/pipeline/semantic.ts |  34 ++-
 electron/main/ai/summary/index.ts         |  77 +++--
 3 files changed, 173 insertions(+), 268 deletions(-)
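[Annotation -- reviewer note, not part of the patch]
After this refactor every LLM call site follows the same pattern: resolve the
active config, build a pi-ai model with buildPiModel(), call completeSimple(),
then concatenate the text blocks of the result. A minimal sketch of that flow,
using only the pi-ai API shapes that appear in this patch; ask() and its
import path are hypothetical:

    import { completeSimple, type TextContent } from '@mariozechner/pi-ai'
    import { getActiveConfig, buildPiModel } from './electron/main/ai/llm'

    async function ask(question: string): Promise<string> {
      const config = getActiveConfig()
      if (!config) throw new Error('No active LLM config')

      const model = buildPiModel(config)
      const result = await completeSimple(model, {
        messages: [{ role: 'user', content: question, timestamp: Date.now() }],
      }, {
        apiKey: config.apiKey,
        maxTokens: 256,
      })

      // completeSimple returns mixed content blocks; keep only the text parts
      return result.content
        .filter((item): item is TextContent => item.type === 'text')
        .map((item) => item.text)
        .join('')
    }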
diff --git a/electron/main/ai/llm/index.ts b/electron/main/ai/llm/index.ts
index 965c51a..3b62184 100644
--- a/electron/main/ai/llm/index.ts
+++ b/electron/main/ai/llm/index.ts
@@ -8,22 +8,19 @@ import * as path from 'path'
 import { randomUUID } from 'crypto'
 import { getAiDataDir } from '../../paths'
 import type {
-  LLMConfig,
   LLMProvider,
-  ILLMService,
   ProviderInfo,
-  ChatMessage,
-  ChatOptions,
-  ChatStreamChunk,
   AIServiceConfig,
   AIConfigStore,
 } from './types'
 import { MAX_CONFIG_COUNT } from './types'
-import { GeminiService, GEMINI_INFO } from './gemini'
-import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
-import { aiLogger, extractErrorInfo, extractErrorStack } from '../logger'
+import { aiLogger } from '../logger'
 import { encryptApiKey, decryptApiKey, isEncrypted } from './crypto'
 import { t } from '../../i18n'
+import {
+  completeSimple,
+  type Model as PiModel,
+} from '@mariozechner/pi-ai'
 
 // Re-export types
 export * from './types'
@@ -110,6 +107,31 @@ const DOUBAO_INFO: ProviderInfo = {
   ],
 }
 
+/** Gemini provider info */
+const GEMINI_INFO: ProviderInfo = {
+  id: 'gemini',
+  name: 'Gemini',
+  description: 'Google Gemini 大语言模型',
+  defaultBaseUrl: 'https://generativelanguage.googleapis.com',
+  models: [
+    { id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: '高速预览版' },
+    { id: 'gemini-3-pro-preview', name: 'Gemini 3 Pro Preview', description: '专业预览版' },
+  ],
+}
+
+/** OpenAI-compatible provider info */
+const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
+  id: 'openai-compatible',
+  name: 'OpenAI 兼容',
+  description: '支持任何兼容 OpenAI API 的服务(如 Ollama、LocalAI、vLLM 等)',
+  defaultBaseUrl: 'http://localhost:11434/v1',
+  models: [
+    { id: 'llama3.2', name: 'Llama 3.2', description: 'Meta Llama 3.2 模型' },
+    { id: 'qwen2.5', name: 'Qwen 2.5', description: '通义千问 2.5 模型' },
+    { id: 'deepseek-r1', name: 'DeepSeek R1', description: 'DeepSeek R1 推理模型' },
+  ],
+}
+
 // All supported providers
 export const PROVIDERS: ProviderInfo[] = [
   DEEPSEEK_INFO,
@@ -392,15 +414,6 @@ export function hasActiveConfig(): boolean {
   return config !== null
 }
 
-/**
- * Extended LLM config (with options specific to local services)
- */
-interface ExtendedLLMConfig extends LLMConfig {
-  disableThinking?: boolean
-  /** Marks a reasoning model (e.g. DeepSeek-R1, QwQ) */
-  isReasoningModel?: boolean
-}
-
 /**
  * No longer auto-fills the Base URL; explicitly validates the URL format for DeepSeek/Qwen
  */
@@ -431,72 +444,6 @@ function validateProviderBaseUrl(provider: LLMProvider, baseUrl?: string): void
     }
   }
 }
 
-/**
- * Create an LLM service instance
- */
-export function createLLMService(config: ExtendedLLMConfig): ILLMService {
-  // Get the provider's default baseUrl
-  const providerInfo = getProviderInfo(config.provider)
-  const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl
-  // Fall back to the provider's first model when none is given explicitly
-  const resolvedModel = config.model || providerInfo?.models?.[0]?.id
-  // No auto-fill: throw straight to the user when the URL is invalid
-  validateProviderBaseUrl(config.provider, baseUrl)
-
-  switch (config.provider) {
-    case 'gemini':
-      return new GeminiService(config.apiKey, resolvedModel, config.baseUrl)
-    // The newly added official APIs all use the OpenAI-compatible format
-    case 'deepseek':
-    case 'qwen':
-    case 'minimax':
-    case 'glm':
-    case 'kimi':
-    case 'doubao':
-      // When DeepSeek/Qwen use the OpenAI-compatible implementation, disable local thinking injection
-      return new OpenAICompatibleService(
-        config.apiKey,
-        resolvedModel,
-        baseUrl,
-        config.provider === 'deepseek' || config.provider === 'qwen' ? false : undefined,
-        config.provider,
-        providerInfo?.models
-      )
-    case 'openai-compatible':
-      return new OpenAICompatibleService(
-        config.apiKey,
-        resolvedModel,
-        config.baseUrl,
-        config.disableThinking,
-        config.provider,
-        providerInfo?.models,
-        config.isReasoningModel
-      )
-    default:
-      throw new Error(`Unknown LLM provider: ${config.provider}`)
-  }
-}
-
-/**
- * Get an LLM service instance for the current active config
- */
-export function getCurrentLLMService(): ILLMService | null {
-  const activeConfig = getActiveConfig()
-  if (!activeConfig) {
-    return null
-  }
-
-  return createLLMService({
-    provider: activeConfig.provider,
-    apiKey: activeConfig.apiKey,
-    model: activeConfig.model,
-    baseUrl: activeConfig.baseUrl,
-    maxTokens: activeConfig.maxTokens,
-    disableThinking: activeConfig.disableThinking,
-    isReasoningModel: activeConfig.isReasoningModel,
-  })
-}
-
 /**
  * Get provider info
  */
@@ -504,159 +451,94 @@ export function getProviderInfo(provider: LLMProvider): ProviderInfo | null {
   return PROVIDERS.find((p) => p.id === provider) || null
 }
 
+// ==================== pi-ai Model construction ====================
+
 /**
- * Validate the API Key
+ * Convert an AIServiceConfig into a pi-ai Model object
+ */
+export function buildPiModel(
+  config: AIServiceConfig
+): PiModel<'openai-completions'> | PiModel<'google-generative-ai'> {
+  const providerInfo = getProviderInfo(config.provider)
+  const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl || ''
+  const modelId = config.model || providerInfo?.models?.[0]?.id || ''
+
+  validateProviderBaseUrl(config.provider, baseUrl)
+
+  if (config.provider === 'gemini') {
+    return {
+      id: modelId,
+      name: modelId,
+      api: 'google-generative-ai',
+      provider: 'google',
+      baseUrl,
+      reasoning: false,
+      input: ['text'],
+      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+      contextWindow: 1048576,
+      maxTokens: config.maxTokens ?? 8192,
+    }
+  }
+
+  return {
+    id: modelId,
+    name: modelId,
+    api: 'openai-completions',
+    provider: config.provider,
+    baseUrl,
+    reasoning: config.isReasoningModel ?? false,
+    input: ['text'],
+    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+    contextWindow: 128000,
+    maxTokens: config.maxTokens ?? 4096,
+    compat: config.disableThinking ? { thinkingFormat: 'qwen' } : undefined,
+  }
+}
+
+/**
+ * Validate an API Key (via pi-ai's completeSimple)
+ * Sends a minimal request to check whether the key is valid
  */
 export async function validateApiKey(
   provider: LLMProvider,
-  apiKey: string
+  apiKey: string,
+  baseUrl?: string,
+  model?: string
 ): Promise<{ success: boolean; error?: string }> {
-  const service = createLLMService({ provider, apiKey })
-  return service.validateApiKey()
-}
-
-/**
- * Send a chat request (using the current config)
- * Returns the full ChatResponse object, including finishReason and tool_calls
- */
-export async function chat(
-  messages: ChatMessage[],
-  options?: ChatOptions
-): Promise<{ content: string; finishReason: string; tool_calls?: import('./types').ToolCall[] }> {
-  const activeConfig = getActiveConfig()
-
-  aiLogger.info('LLM', 'Starting non-streaming chat request', {
-    messagesCount: messages.length,
-    firstMessageRole: messages[0]?.role,
-    firstMessageLength: messages[0]?.content?.length,
-    config: activeConfig
-      ? {
-          name: activeConfig.name,
-          provider: activeConfig.provider,
-          model: activeConfig.model,
-          baseUrl: activeConfig.baseUrl,
-        }
-      : null,
-    options,
-  })
-
-  const service = getCurrentLLMService()
-  if (!service) {
-    aiLogger.error('LLM', 'Service not configured')
-    throw new Error(t('llm.notConfigured'))
-  }
-
-  try {
-    const response = await service.chat(messages, options)
-    aiLogger.info('LLM', 'Non-streaming request succeeded', {
-      contentLength: response.content?.length,
-      finishReason: response.finishReason,
-      usage: response.usage,
-    })
-    return response
-  } catch (error) {
-    // Config info
-    const configStr = activeConfig
-      ? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
-      : '未配置'
-    // Error info
-    const errorInfo = extractErrorInfo(error)
-    const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
-
-    aiLogger.error('LLM', `Non-streaming request failed | config: ${configStr}`)
-    aiLogger.error('LLM', `Error: ${errorStr}`)
-
-    // Stack on its own line
-    const stack = extractErrorStack(error)
-    if (stack) {
-      aiLogger.error('LLM', `Stack:\n${stack}`)
-    }
-    throw error
-  }
-}
-
-/**
- * Send a chat request (streaming, using the current config)
- */
-export async function* chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
-  const activeConfig = getActiveConfig()
-
-  aiLogger.info('LLM', 'Starting streaming chat request', {
-    messagesCount: messages.length,
-    firstMessageRole: messages[0]?.role,
-    firstMessageLength: messages[0]?.content?.length,
-    config: activeConfig
-      ? {
-          name: activeConfig.name,
-          provider: activeConfig.provider,
-          model: activeConfig.model,
-          baseUrl: activeConfig.baseUrl,
-        }
-      : null,
-  })
-
-  const service = getCurrentLLMService()
-  if (!service) {
-    aiLogger.error('LLM', 'Service not configured (streaming)')
-    throw new Error(t('llm.notConfigured'))
-  }
-
-  let chunkCount = 0
-  let totalContent = ''
-
-  let receivedFinish = false
-  let contentChunkCount = 0
-
-  try {
-    for await (const chunk of service.chatStream(messages, options)) {
-      chunkCount++
-      totalContent += chunk.content
-
-      // Track content chunks
-      if (chunk.content) {
-        contentChunkCount++
-      }
-
-      yield chunk
-
-      if (chunk.isFinished) {
-        receivedFinish = true
-        aiLogger.info('LLM', 'Streaming request completed', {
-          chunkCount,
-          contentChunkCount,
-          totalContentLength: totalContent.length,
-          finishReason: chunk.finishReason,
-        })
-      }
-    }
-
-    // Warn if the loop ended normally without receiving an isFinished chunk
-    if (chunkCount > 0 && !receivedFinish) {
-      aiLogger.warn('LLM', 'Stream loop ended without completion signal', {
-        chunkCount,
-        totalContentLength: totalContent.length,
-      })
-    }
-  } catch (error) {
-    // Config info
-    const configStr = activeConfig
-      ? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
-      : '未配置'
-    // Error info
-    const errorInfo = extractErrorInfo(error)
-    const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
-
-    aiLogger.error(
-      'LLM',
-      `Stream request failed | config: ${configStr} | received: ${chunkCount} chunks/${totalContent.length} chars`
-    )
-    aiLogger.error('LLM', `Error: ${errorStr}`)
-
-    // Stack on its own line
-    const stack = extractErrorStack(error)
-    if (stack) {
-      aiLogger.error('LLM', `Stack:\n${stack}`)
-    }
-    throw error
+  try {
+    const providerInfo = getProviderInfo(provider)
+    const config: AIServiceConfig = {
+      id: 'validate-temp',
+      name: 'validate-temp',
+      provider,
+      apiKey,
+      baseUrl,
+      model: model || providerInfo?.models?.[0]?.id,
+      createdAt: 0,
+      updatedAt: 0,
+    }
+    const piModel = buildPiModel(config)
+
+    const abortController = new AbortController()
+    const timeout = setTimeout(() => abortController.abort(), 15000)
+
+    try {
+      await completeSimple(piModel, {
+        messages: [{ role: 'user', content: 'Hi', timestamp: Date.now() }],
+      }, {
+        apiKey,
+        maxTokens: 1,
+        signal: abortController.signal,
+      })
+      return { success: true }
+    } finally {
+      clearTimeout(timeout)
+    }
+  } catch (error) {
+    const message = error instanceof Error ? error.message : String(error)
+    if (message.includes('aborted') || message.includes('AbortError')) {
+      return { success: false, error: 'Request timed out (15s)' }
    }
+    return { success: false, error: message }
   }
 }
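[Annotation -- reviewer note, not part of the patch]
validateApiKey() now owns its own 15-second abort guard. If other call sites
grow the same need, the setTimeout/clearTimeout pairing could live in one
generic wrapper. A sketch under the same AbortSignal semantics completeSimple
accepts via its `signal` option; withTimeout() is hypothetical:

    // Run an abortable operation, aborting it after `ms` milliseconds.
    async function withTimeout<T>(
      run: (signal: AbortSignal) => Promise<T>,
      ms: number
    ): Promise<T> {
      const controller = new AbortController()
      const timer = setTimeout(() => controller.abort(), ms)
      try {
        return await run(controller.signal)
      } finally {
        clearTimeout(timer)
      }
    }

    // Usage, mirroring the guard above:
    // await withTimeout(
    //   (signal) => completeSimple(piModel, ctx, { apiKey, maxTokens: 1, signal }),
    //   15000
    // )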
diff --git a/electron/main/ai/rag/pipeline/semantic.ts b/electron/main/ai/rag/pipeline/semantic.ts
index a825043..14c656b 100644
--- a/electron/main/ai/rag/pipeline/semantic.ts
+++ b/electron/main/ai/rag/pipeline/semantic.ts
@@ -10,7 +10,8 @@ import { getEmbeddingService } from '../embedding'
 import { getVectorStore } from '../store'
 import { getSessionChunks } from '../chunking'
 import { loadRAGConfig } from '../config'
-import { chat } from '../../llm'
+import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
+import { getActiveConfig, buildPiModel } from '../../llm'
 import { aiLogger as logger } from '../../logger'
 
 /**
@@ -33,21 +34,28 @@ const QUERY_REWRITE_PROMPT = `你是一个查询优化专家。请将用户的
  */
 async function rewriteQuery(query: string, abortSignal?: AbortSignal): Promise<string> {
   try {
+    const activeConfig = getActiveConfig()
+    if (!activeConfig) return query
+
+    const piModel = buildPiModel(activeConfig)
     const prompt = QUERY_REWRITE_PROMPT.replace('{query}', query)
-    const response = await chat(
-      [
-        { role: 'system', content: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。' },
-        { role: 'user', content: prompt },
-      ],
-      {
-        temperature: 0.3,
-        maxTokens: 200,
-        abortSignal,
-      }
-    )
+    const result = await completeSimple(piModel, {
+      systemPrompt: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。',
+      messages: [{ role: 'user', content: prompt, timestamp: Date.now() }],
+    }, {
+      apiKey: activeConfig.apiKey,
+      temperature: 0.3,
+      maxTokens: 200,
+      signal: abortSignal,
+    })
+
+    const rewritten = result.content
+      .filter((item): item is PiTextContent => item.type === 'text')
+      .map((item) => item.text)
+      .join('')
+      .trim()
 
-    const rewritten = response.content.trim()
     return rewritten || query
   } catch (error) {
     logger.warn('[Semantic Pipeline] Query rewrite failed, using original query:', error)
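[Annotation -- reviewer note, not part of the patch]
The filter/map/join chain above is now repeated verbatim here and twice in
summary/index.ts below. A possible shared helper, built on the TextContent
type this patch already imports from pi-ai; extractText() is hypothetical:

    import type { TextContent } from '@mariozechner/pi-ai'

    /** Concatenate the text blocks of a pi-ai completion result. */
    function extractText(content: Array<TextContent | { type: string }>): string {
      return content
        .filter((item): item is TextContent => item.type === 'text')
        .map((item) => item.text)
        .join('')
    }

    // rewriteQuery() could then end with:
    // return extractText(result.content).trim() || query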
diff --git a/electron/main/ai/summary/index.ts b/electron/main/ai/summary/index.ts
index 9f4df3b..51d191a 100644
--- a/electron/main/ai/summary/index.ts
+++ b/electron/main/ai/summary/index.ts
@@ -8,11 +8,41 @@
  */
 import Database from 'better-sqlite3'
-import { chat } from '../llm'
+import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
+import { getActiveConfig, buildPiModel } from '../llm'
 import { getDbPath, openDatabase } from '../../database/core'
 import { aiLogger } from '../logger'
 import { t } from '../../i18n'
 
+/** Call the LLM to generate text (directly via pi-ai's completeSimple) */
+async function llmComplete(
+  systemPrompt: string,
+  userPrompt: string,
+  options?: { temperature?: number; maxTokens?: number }
+): Promise<string> {
+  const activeConfig = getActiveConfig()
+  if (!activeConfig) {
+    throw new Error(t('llm.notConfigured'))
+  }
+
+  const piModel = buildPiModel(activeConfig)
+  const now = Date.now()
+
+  const result = await completeSimple(piModel, {
+    systemPrompt,
+    messages: [{ role: 'user', content: userPrompt, timestamp: now }],
+  }, {
+    apiKey: activeConfig.apiKey,
+    temperature: options?.temperature,
+    maxTokens: options?.maxTokens,
+  })
+
+  return result.content
+    .filter((item): item is PiTextContent => item.type === 'text')
+    .map((item) => item.text)
+    .join('')
+}
+
 /** Minimum message-count threshold (no summary is generated below it) */
 const MIN_MESSAGE_COUNT = 3
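[Annotation -- reviewer note, not part of the patch]
A quick usage sketch of the new llmComplete() helper. It throws
t('llm.notConfigured') when no provider is configured, matching the old
chat() behavior, so existing try/catch blocks keep working. The prompt
strings below are illustrative only:

    const headline = await llmComplete(
      'You are a concise summarizer.',
      'Summarize the following conversation in one sentence: ...',
      { temperature: 0.3, maxTokens: 50 }
    )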
@@ -408,17 +438,12 @@ export async function generateSessionSummary(
  * Generate the summary directly (for short sessions)
  */
 async function generateDirectSummary(content: string, lengthLimit: number, locale: string): Promise<string> {
-  const response = await chat(
-    [
-      { role: 'system', content: t('summary.systemPromptDirect') },
-      { role: 'user', content: buildSummaryPrompt(content, lengthLimit, locale) },
-    ],
-    {
-      temperature: 0.3,
-      maxTokens: 300,
-    }
+  const result = await llmComplete(
+    t('summary.systemPromptDirect'),
+    buildSummaryPrompt(content, lengthLimit, locale),
+    { temperature: 0.3, maxTokens: 300 }
   )
-  return response.content.trim()
+  return result.trim()
 }
 
 /**
@@ -437,17 +462,12 @@ async function generateMapReduceSummary(
 
   for (let i = 0; i < segments.length; i++) {
     const segmentContent = formatMessages(segments[i])
-    const response = await chat(
-      [
-        { role: 'system', content: t('summary.systemPromptDirect') },
-        { role: 'user', content: buildSubSummaryPrompt(segmentContent, locale) },
-      ],
-      {
-        temperature: 0.3,
-        maxTokens: 100,
-      }
+    const result = await llmComplete(
+      t('summary.systemPromptDirect'),
+      buildSubSummaryPrompt(segmentContent, locale),
+      { temperature: 0.3, maxTokens: 100 }
     )
-    subSummaries.push(response.content.trim())
+    subSummaries.push(result.trim())
   }
 
   // 2. Reduce: merge the sub-summaries
   if (subSummaries.length === 1) {
     return subSummaries[0]
   }
 
-  const mergeResponse = await chat(
-    [
-      { role: 'system', content: t('summary.systemPromptMerge') },
-      { role: 'user', content: buildMergePrompt(subSummaries, lengthLimit, locale) },
-    ],
-    {
-      temperature: 0.3,
-      maxTokens: 300,
-    }
+  const mergeResult = await llmComplete(
+    t('summary.systemPromptMerge'),
+    buildMergePrompt(subSummaries, lengthLimit, locale),
+    { temperature: 0.3, maxTokens: 300 }
   )
 
-  return mergeResponse.content.trim()
+  return mergeResult.trim()
 }
 
 /**