refactor(llm): unify LLM access layer via pi-ai

- Refactor llm/index.ts: remove chat()/chatStream() wrappers, inline
  provider info from deleted service files, add buildPiModel() for
  direct pi-ai model construction (a usage sketch follows below)
- Migrate summary/index.ts to use pi-ai's completeSimple directly
- Migrate rag/pipeline/semantic.ts to use pi-ai's completeSimple directly
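
For reference, a minimal call-site sketch of the new pattern, modeled on the rewriteQuery and llmComplete helpers in this diff. The askOnce name, the error message text, and the '../llm' import path are illustrative assumptions, not part of this commit:

import { completeSimple, type TextContent } from '@mariozechner/pi-ai'
import { getActiveConfig, buildPiModel } from '../llm' // import path depends on the caller's location (assumption)

// Hypothetical helper showing the post-refactor call pattern
async function askOnce(question: string): Promise<string> {
  const activeConfig = getActiveConfig()
  if (!activeConfig) throw new Error('No active LLM config') // error text is illustrative
  const piModel = buildPiModel(activeConfig)
  const result = await completeSimple(piModel, {
    messages: [{ role: 'user', content: question, timestamp: Date.now() }],
  }, {
    apiKey: activeConfig.apiKey,
    temperature: 0.3,
    maxTokens: 200,
  })
  // completeSimple returns mixed content items; keep only the text parts
  return result.content
    .filter((item): item is TextContent => item.type === 'text')
    .map((item) => item.text)
    .join('')
    .trim()
}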
digua
2026-02-26 21:05:28 +08:00
parent da3d2531f5
commit 90afe5f087
3 changed files with 173 additions and 268 deletions
llm/index.ts  +106 -224
@@ -8,22 +8,19 @@ import * as path from 'path'
import { randomUUID } from 'crypto'
import { getAiDataDir } from '../../paths'
import type {
LLMConfig,
LLMProvider,
ILLMService,
ProviderInfo,
ChatMessage,
ChatOptions,
ChatStreamChunk,
AIServiceConfig,
AIConfigStore,
} from './types'
import { MAX_CONFIG_COUNT } from './types'
import { GeminiService, GEMINI_INFO } from './gemini'
import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
import { aiLogger, extractErrorInfo, extractErrorStack } from '../logger'
import { aiLogger } from '../logger'
import { encryptApiKey, decryptApiKey, isEncrypted } from './crypto'
import { t } from '../../i18n'
import {
completeSimple,
type Model as PiModel,
} from '@mariozechner/pi-ai'
// Export types
export * from './types'
@@ -110,6 +107,31 @@ const DOUBAO_INFO: ProviderInfo = {
],
}
/** Gemini provider info */
const GEMINI_INFO: ProviderInfo = {
id: 'gemini',
name: 'Gemini',
description: 'Google Gemini 大语言模型',
defaultBaseUrl: 'https://generativelanguage.googleapis.com',
models: [
{ id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: '高速预览版' },
{ id: 'gemini-3-pro-preview', name: 'Gemini 3 Pro Preview', description: '专业预览版' },
],
}
/** OpenAI-compatible provider info */
const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
id: 'openai-compatible',
name: 'OpenAI 兼容',
description: '支持任何兼容 OpenAI API 的服务(如 Ollama、LocalAI、vLLM 等)',
defaultBaseUrl: 'http://localhost:11434/v1',
models: [
{ id: 'llama3.2', name: 'Llama 3.2', description: 'Meta Llama 3.2 模型' },
{ id: 'qwen2.5', name: 'Qwen 2.5', description: '通义千问 2.5 模型' },
{ id: 'deepseek-r1', name: 'DeepSeek R1', description: 'DeepSeek R1 推理模型' },
],
}
// All supported provider infos
export const PROVIDERS: ProviderInfo[] = [
DEEPSEEK_INFO,
@@ -392,15 +414,6 @@ export function hasActiveConfig(): boolean {
return config !== null
}
/**
* Extended LLM config (includes local-service-specific options)
*/
interface ExtendedLLMConfig extends LLMConfig {
disableThinking?: boolean
/** Marks a reasoning model (e.g. DeepSeek-R1, QwQ) */
isReasoningModel?: boolean
}
/**
* No longer auto-completes the Base URL; explicitly validates the DeepSeek/Qwen URL format
*/
@@ -431,72 +444,6 @@ function validateProviderBaseUrl(provider: LLMProvider, baseUrl?: string): void
}
}
/**
* Create an LLM service instance
*/
export function createLLMService(config: ExtendedLLMConfig): ILLMService {
// Get the provider's default baseUrl
const providerInfo = getProviderInfo(config.provider)
const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl
// When not explicitly specified, use the provider's first model as the default
const resolvedModel = config.model || providerInfo?.models?.[0]?.id
// No auto-completion; throw directly to the user if the URL is invalid
validateProviderBaseUrl(config.provider, baseUrl)
switch (config.provider) {
case 'gemini':
return new GeminiService(config.apiKey, resolvedModel, config.baseUrl)
// All newly added official APIs use the OpenAI-compatible format
case 'deepseek':
case 'qwen':
case 'minimax':
case 'glm':
case 'kimi':
case 'doubao':
// Disable local thinking injection when DeepSeek/Qwen go through the OpenAI-compatible implementation
return new OpenAICompatibleService(
config.apiKey,
resolvedModel,
baseUrl,
config.provider === 'deepseek' || config.provider === 'qwen' ? false : undefined,
config.provider,
providerInfo?.models
)
case 'openai-compatible':
return new OpenAICompatibleService(
config.apiKey,
resolvedModel,
config.baseUrl,
config.disableThinking,
config.provider,
providerInfo?.models,
config.isReasoningModel
)
default:
throw new Error(`Unknown LLM provider: ${config.provider}`)
}
}
/**
* Get the LLM service instance for the current config
*/
export function getCurrentLLMService(): ILLMService | null {
const activeConfig = getActiveConfig()
if (!activeConfig) {
return null
}
return createLLMService({
provider: activeConfig.provider,
apiKey: activeConfig.apiKey,
model: activeConfig.model,
baseUrl: activeConfig.baseUrl,
maxTokens: activeConfig.maxTokens,
disableThinking: activeConfig.disableThinking,
isReasoningModel: activeConfig.isReasoningModel,
})
}
/**
* Get provider info
*/
@@ -504,159 +451,94 @@ export function getProviderInfo(provider: LLMProvider): ProviderInfo | null {
return PROVIDERS.find((p) => p.id === provider) || null
}
// ==================== pi-ai Model construction ====================
/**
* Validate the API Key
* Convert an AIServiceConfig into a pi-ai Model object
*/
export function buildPiModel(
config: AIServiceConfig
): PiModel<'openai-completions'> | PiModel<'google-generative-ai'> {
const providerInfo = getProviderInfo(config.provider)
const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl || ''
const modelId = config.model || providerInfo?.models?.[0]?.id || ''
validateProviderBaseUrl(config.provider, baseUrl)
if (config.provider === 'gemini') {
return {
id: modelId,
name: modelId,
api: 'google-generative-ai',
provider: 'google',
baseUrl,
reasoning: false,
input: ['text'],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 1048576,
maxTokens: config.maxTokens ?? 8192,
}
}
return {
id: modelId,
name: modelId,
api: 'openai-completions',
provider: config.provider,
baseUrl,
reasoning: config.isReasoningModel ?? false,
input: ['text'],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128000,
maxTokens: config.maxTokens ?? 4096,
compat: config.disableThinking ? { thinkingFormat: 'qwen' } : undefined,
}
}
/**
* Validate the API Key (via pi-ai completeSimple)
* Sends a minimal request to check whether the API Key is valid
*/
export async function validateApiKey(
provider: LLMProvider,
apiKey: string
apiKey: string,
baseUrl?: string,
model?: string
): Promise<{ success: boolean; error?: string }> {
const service = createLLMService({ provider, apiKey })
return service.validateApiKey()
}
/**
* Send a chat request (using the current config)
* Returns the full ChatResponse object, including finishReason and tool_calls
*/
export async function chat(
messages: ChatMessage[],
options?: ChatOptions
): Promise<{ content: string; finishReason: string; tool_calls?: import('./types').ToolCall[] }> {
const activeConfig = getActiveConfig()
aiLogger.info('LLM', 'Starting non-streaming chat request', {
messagesCount: messages.length,
firstMessageRole: messages[0]?.role,
firstMessageLength: messages[0]?.content?.length,
config: activeConfig
? {
name: activeConfig.name,
provider: activeConfig.provider,
model: activeConfig.model,
baseUrl: activeConfig.baseUrl,
}
: null,
options,
})
const service = getCurrentLLMService()
if (!service) {
aiLogger.error('LLM', 'Service not configured')
throw new Error(t('llm.notConfigured'))
}
try {
const response = await service.chat(messages, options)
aiLogger.info('LLM', 'Non-streaming request succeeded', {
contentLength: response.content?.length,
finishReason: response.finishReason,
usage: response.usage,
})
return response
} catch (error) {
// Config info
const configStr = activeConfig
? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
: '未配置'
// Error info
const errorInfo = extractErrorInfo(error)
const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
aiLogger.error('LLM', `Non-streaming request failed | config: ${configStr}`)
aiLogger.error('LLM', `Error: ${errorStr}`)
// Stack trace on its own line
const stack = extractErrorStack(error)
if (stack) {
aiLogger.error('LLM', `Stack:\n${stack}`)
const providerInfo = getProviderInfo(provider)
const config: AIServiceConfig = {
id: 'validate-temp',
name: 'validate-temp',
provider,
apiKey,
baseUrl,
model: model || providerInfo?.models?.[0]?.id,
createdAt: 0,
updatedAt: 0,
}
throw error
}
}
const piModel = buildPiModel(config)
/**
* Send a chat request (streaming, using the current config)
*/
export async function* chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
const activeConfig = getActiveConfig()
const abortController = new AbortController()
const timeout = setTimeout(() => abortController.abort(), 15000)
aiLogger.info('LLM', 'Starting streaming chat request', {
messagesCount: messages.length,
firstMessageRole: messages[0]?.role,
firstMessageLength: messages[0]?.content?.length,
config: activeConfig
? {
name: activeConfig.name,
provider: activeConfig.provider,
model: activeConfig.model,
baseUrl: activeConfig.baseUrl,
}
: null,
})
const service = getCurrentLLMService()
if (!service) {
aiLogger.error('LLM', 'Service not configured (streaming)')
throw new Error(t('llm.notConfigured'))
}
let chunkCount = 0
let totalContent = ''
let receivedFinish = false
let contentChunkCount = 0
try {
for await (const chunk of service.chatStream(messages, options)) {
chunkCount++
totalContent += chunk.content
// Track content chunks
if (chunk.content) {
contentChunkCount++
}
yield chunk
if (chunk.isFinished) {
receivedFinish = true
aiLogger.info('LLM', 'Streaming request completed', {
chunkCount,
contentChunkCount,
totalContentLength: totalContent.length,
finishReason: chunk.finishReason,
})
}
}
// If the loop ends normally without an isFinished chunk, log a warning
if (chunkCount > 0 && !receivedFinish) {
aiLogger.warn('LLM', 'Stream loop ended without completion signal', {
chunkCount,
totalContentLength: totalContent.length,
try {
await completeSimple(piModel, {
messages: [{ role: 'user', content: 'Hi', timestamp: Date.now() }],
}, {
apiKey,
maxTokens: 1,
signal: abortController.signal,
})
return { success: true }
} finally {
clearTimeout(timeout)
}
} catch (error) {
// Config info
const configStr = activeConfig
? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
: '未配置'
// Error info
const errorInfo = extractErrorInfo(error)
const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
aiLogger.error(
'LLM',
`Stream request failed | config: ${configStr} | received: ${chunkCount} chunks/${totalContent.length} chars`
)
aiLogger.error('LLM', `Error: ${errorStr}`)
// Stack trace on its own line
const stack = extractErrorStack(error)
if (stack) {
aiLogger.error('LLM', `Stack:\n${stack}`)
const message = error instanceof Error ? error.message : String(error)
if (message.includes('aborted') || message.includes('AbortError')) {
return { success: false, error: 'Request timed out (15s)' }
}
throw error
return { success: false, error: message }
}
}
rag/pipeline/semantic.ts  +21 -13
@@ -10,7 +10,8 @@ import { getEmbeddingService } from '../embedding'
import { getVectorStore } from '../store'
import { getSessionChunks } from '../chunking'
import { loadRAGConfig } from '../config'
import { chat } from '../../llm'
import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
import { getActiveConfig, buildPiModel } from '../../llm'
import { aiLogger as logger } from '../../logger'
/**
@@ -33,21 +34,28 @@ const QUERY_REWRITE_PROMPT = `你是一个查询优化专家。请将用户的
*/
async function rewriteQuery(query: string, abortSignal?: AbortSignal): Promise<string> {
try {
const activeConfig = getActiveConfig()
if (!activeConfig) return query
const piModel = buildPiModel(activeConfig)
const prompt = QUERY_REWRITE_PROMPT.replace('{query}', query)
const response = await chat(
[
{ role: 'system', content: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。' },
{ role: 'user', content: prompt },
],
{
temperature: 0.3,
maxTokens: 200,
abortSignal,
}
)
const result = await completeSimple(piModel, {
systemPrompt: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。',
messages: [{ role: 'user', content: prompt, timestamp: Date.now() }],
}, {
apiKey: activeConfig.apiKey,
temperature: 0.3,
maxTokens: 200,
signal: abortSignal,
})
const rewritten = result.content
.filter((item): item is PiTextContent => item.type === 'text')
.map((item) => item.text)
.join('')
.trim()
const rewritten = response.content.trim()
return rewritten || query
} catch (error) {
logger.warn('[Semantic Pipeline] Query rewrite failed, using original query:', error)
summary/index.ts  +46 -31
@@ -8,11 +8,41 @@
*/
import Database from 'better-sqlite3'
import { chat } from '../llm'
import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
import { getActiveConfig, buildPiModel } from '../llm'
import { getDbPath, openDatabase } from '../../database/core'
import { aiLogger } from '../logger'
import { t } from '../../i18n'
/** Call the LLM to generate text (directly via pi-ai completeSimple) */
async function llmComplete(
systemPrompt: string,
userPrompt: string,
options?: { temperature?: number; maxTokens?: number }
): Promise<string> {
const activeConfig = getActiveConfig()
if (!activeConfig) {
throw new Error(t('llm.notConfigured'))
}
const piModel = buildPiModel(activeConfig)
const now = Date.now()
const result = await completeSimple(piModel, {
systemPrompt,
messages: [{ role: 'user', content: userPrompt, timestamp: now }],
}, {
apiKey: activeConfig.apiKey,
temperature: options?.temperature,
maxTokens: options?.maxTokens,
})
return result.content
.filter((item): item is PiTextContent => item.type === 'text')
.map((item) => item.text)
.join('')
}
/** Minimum message count threshold (no summary is generated below this) */
const MIN_MESSAGE_COUNT = 3
@@ -408,17 +438,12 @@ export async function generateSessionSummary(
* Generate a summary directly (for short sessions)
*/
async function generateDirectSummary(content: string, lengthLimit: number, locale: string): Promise<string> {
const response = await chat(
[
{ role: 'system', content: t('summary.systemPromptDirect') },
{ role: 'user', content: buildSummaryPrompt(content, lengthLimit, locale) },
],
{
temperature: 0.3,
maxTokens: 300,
}
const result = await llmComplete(
t('summary.systemPromptDirect'),
buildSummaryPrompt(content, lengthLimit, locale),
{ temperature: 0.3, maxTokens: 300 }
)
return response.content.trim()
return result.trim()
}
/**
@@ -437,17 +462,12 @@ async function generateMapReduceSummary(
for (let i = 0; i < segments.length; i++) {
const segmentContent = formatMessages(segments[i])
const response = await chat(
[
{ role: 'system', content: t('summary.systemPromptDirect') },
{ role: 'user', content: buildSubSummaryPrompt(segmentContent, locale) },
],
{
temperature: 0.3,
maxTokens: 100,
}
const result = await llmComplete(
t('summary.systemPromptDirect'),
buildSubSummaryPrompt(segmentContent, locale),
{ temperature: 0.3, maxTokens: 100 }
)
subSummaries.push(response.content.trim())
subSummaries.push(result.trim())
}
// 2. Reduce: merge sub-summaries
@@ -455,18 +475,13 @@ async function generateMapReduceSummary(
return subSummaries[0]
}
const mergeResponse = await chat(
[
{ role: 'system', content: t('summary.systemPromptMerge') },
{ role: 'user', content: buildMergePrompt(subSummaries, lengthLimit, locale) },
],
{
temperature: 0.3,
maxTokens: 300,
}
const mergeResult = await llmComplete(
t('summary.systemPromptMerge'),
buildMergePrompt(subSummaries, lengthLimit, locale),
{ temperature: 0.3, maxTokens: 300 }
)
return mergeResponse.content.trim()
return mergeResult.trim()
}
/**