Mirror of https://github.com/hellodigua/ChatLab.git (synced 2026-05-07 13:30:57 +08:00)
refactor(llm): unify LLM access layer via pi-ai
- Refactor llm/index.ts: remove the chat()/chatStream() wrappers, inline provider info from the deleted service files, and add buildPiModel() for direct pi-ai model construction
- Migrate summary/index.ts to use pi-ai's completeSimple directly
- Migrate rag/pipeline/semantic.ts to use pi-ai's completeSimple directly
Diffstat: +106 −224
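
Before the per-file hunks, the shape of the new access path in one place: every migrated call site builds a pi-ai Model from the active config, calls completeSimple, and flattens the returned content blocks to text. The sketch below is assembled from the calls visible in this diff; the option and result shapes are inferred from these usages rather than from pi-ai's documentation, and `ask` is an illustrative name:

    import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
    import { getActiveConfig, buildPiModel } from './llm'

    // Illustrative one-shot completion through the unified pi-ai layer.
    async function ask(userPrompt: string): Promise<string> {
      const activeConfig = getActiveConfig()
      if (!activeConfig) throw new Error('LLM not configured')

      // Convert the stored AIServiceConfig into a pi-ai Model descriptor.
      const piModel = buildPiModel(activeConfig)
      const result = await completeSimple(piModel, {
        messages: [{ role: 'user', content: userPrompt, timestamp: Date.now() }],
      }, {
        apiKey: activeConfig.apiKey,
      })

      // pi-ai returns mixed content blocks; keep only the text parts.
      return result.content
        .filter((item): item is PiTextContent => item.type === 'text')
        .map((item) => item.text)
        .join('')
    }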
@@ -8,22 +8,19 @@ import * as path from 'path'
 import { randomUUID } from 'crypto'
 import { getAiDataDir } from '../../paths'
 import type {
   LLMConfig,
   LLMProvider,
-  ILLMService,
   ProviderInfo,
-  ChatMessage,
-  ChatOptions,
-  ChatStreamChunk,
   AIServiceConfig,
   AIConfigStore,
 } from './types'
 import { MAX_CONFIG_COUNT } from './types'
-import { GeminiService, GEMINI_INFO } from './gemini'
-import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
-import { aiLogger, extractErrorInfo, extractErrorStack } from '../logger'
+import { aiLogger } from '../logger'
 import { encryptApiKey, decryptApiKey, isEncrypted } from './crypto'
 import { t } from '../../i18n'
+import {
+  completeSimple,
+  type Model as PiModel,
+} from '@mariozechner/pi-ai'
 
 // Export types
 export * from './types'
@@ -110,6 +107,31 @@ const DOUBAO_INFO: ProviderInfo = {
   ],
 }
 
+/** Gemini provider info */
+const GEMINI_INFO: ProviderInfo = {
+  id: 'gemini',
+  name: 'Gemini',
+  description: 'Google Gemini 大语言模型',
+  defaultBaseUrl: 'https://generativelanguage.googleapis.com',
+  models: [
+    { id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: '高速预览版' },
+    { id: 'gemini-3-pro-preview', name: 'Gemini 3 Pro Preview', description: '专业预览版' },
+  ],
+}
+
+/** OpenAI-compatible provider info */
+const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
+  id: 'openai-compatible',
+  name: 'OpenAI 兼容',
+  description: '支持任何兼容 OpenAI API 的服务(如 Ollama、LocalAI、vLLM 等)',
+  defaultBaseUrl: 'http://localhost:11434/v1',
+  models: [
+    { id: 'llama3.2', name: 'Llama 3.2', description: 'Meta Llama 3.2 模型' },
+    { id: 'qwen2.5', name: 'Qwen 2.5', description: '通义千问 2.5 模型' },
+    { id: 'deepseek-r1', name: 'DeepSeek R1', description: 'DeepSeek R1 推理模型' },
+  ],
+}
+
 // All supported providers
 export const PROVIDERS: ProviderInfo[] = [
   DEEPSEEK_INFO,
@@ -392,15 +414,6 @@ export function hasActiveConfig(): boolean {
   return config !== null
 }
 
-/**
- * Extended LLM config (includes local-service-specific options)
- */
-interface ExtendedLLMConfig extends LLMConfig {
-  disableThinking?: boolean
-  /** Marks the model as a reasoning model (e.g. DeepSeek-R1, QwQ) */
-  isReasoningModel?: boolean
-}
-
 /**
  * No longer auto-completes the Base URL; explicitly validates the DeepSeek/Qwen URL format
  */
@@ -431,72 +444,6 @@ function validateProviderBaseUrl(provider: LLMProvider, baseUrl?: string): void
   }
 }
 
-/**
- * Create an LLM service instance
- */
-export function createLLMService(config: ExtendedLLMConfig): ILLMService {
-  // Get the provider's default baseUrl
-  const providerInfo = getProviderInfo(config.provider)
-  const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl
-  // Fall back to the provider's first model when none is specified
-  const resolvedModel = config.model || providerInfo?.models?.[0]?.id
-  // No auto-completion: throw to the user when the URL is invalid
-  validateProviderBaseUrl(config.provider, baseUrl)
-
-  switch (config.provider) {
-    case 'gemini':
-      return new GeminiService(config.apiKey, resolvedModel, config.baseUrl)
-    // All newly added official APIs use the OpenAI-compatible format
-    case 'deepseek':
-    case 'qwen':
-    case 'minimax':
-    case 'glm':
-    case 'kimi':
-    case 'doubao':
-      // Disable local thinking injection when DeepSeek/Qwen go through the OpenAI-compatible implementation
-      return new OpenAICompatibleService(
-        config.apiKey,
-        resolvedModel,
-        baseUrl,
-        config.provider === 'deepseek' || config.provider === 'qwen' ? false : undefined,
-        config.provider,
-        providerInfo?.models
-      )
-    case 'openai-compatible':
-      return new OpenAICompatibleService(
-        config.apiKey,
-        resolvedModel,
-        config.baseUrl,
-        config.disableThinking,
-        config.provider,
-        providerInfo?.models,
-        config.isReasoningModel
-      )
-    default:
-      throw new Error(`Unknown LLM provider: ${config.provider}`)
-  }
-}
-
-/**
- * Get the LLM service instance for the current config
- */
-export function getCurrentLLMService(): ILLMService | null {
-  const activeConfig = getActiveConfig()
-  if (!activeConfig) {
-    return null
-  }
-
-  return createLLMService({
-    provider: activeConfig.provider,
-    apiKey: activeConfig.apiKey,
-    model: activeConfig.model,
-    baseUrl: activeConfig.baseUrl,
-    maxTokens: activeConfig.maxTokens,
-    disableThinking: activeConfig.disableThinking,
-    isReasoningModel: activeConfig.isReasoningModel,
-  })
-}
-
 /**
  * Get provider info
  */
@@ -504,159 +451,94 @@ export function getProviderInfo(provider: LLMProvider): ProviderInfo | null {
   return PROVIDERS.find((p) => p.id === provider) || null
 }
 
+// ==================== pi-ai Model construction ====================
+
 /**
- * Validate an API Key
+ * Convert an AIServiceConfig into a pi-ai Model object
  */
+export function buildPiModel(
+  config: AIServiceConfig
+): PiModel<'openai-completions'> | PiModel<'google-generative-ai'> {
+  const providerInfo = getProviderInfo(config.provider)
+  const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl || ''
+  const modelId = config.model || providerInfo?.models?.[0]?.id || ''
+
+  validateProviderBaseUrl(config.provider, baseUrl)
+
+  if (config.provider === 'gemini') {
+    return {
+      id: modelId,
+      name: modelId,
+      api: 'google-generative-ai',
+      provider: 'google',
+      baseUrl,
+      reasoning: false,
+      input: ['text'],
+      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+      contextWindow: 1048576,
+      maxTokens: config.maxTokens ?? 8192,
+    }
+  }
+
+  return {
+    id: modelId,
+    name: modelId,
+    api: 'openai-completions',
+    provider: config.provider,
+    baseUrl,
+    reasoning: config.isReasoningModel ?? false,
+    input: ['text'],
+    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+    contextWindow: 128000,
+    maxTokens: config.maxTokens ?? 4096,
+    compat: config.disableThinking ? { thinkingFormat: 'qwen' } : undefined,
+  }
+}
+
+/**
+ * Validate an API Key (via pi-ai completeSimple)
+ * Sends a minimal request to check whether the API Key is valid
+ */
 export async function validateApiKey(
   provider: LLMProvider,
-  apiKey: string
+  apiKey: string,
+  baseUrl?: string,
+  model?: string
 ): Promise<{ success: boolean; error?: string }> {
-  const service = createLLMService({ provider, apiKey })
-  return service.validateApiKey()
-}
-
-/**
- * Send a chat request (using the current config)
- * Returns the full ChatResponse object, including finishReason and tool_calls
- */
-export async function chat(
-  messages: ChatMessage[],
-  options?: ChatOptions
-): Promise<{ content: string; finishReason: string; tool_calls?: import('./types').ToolCall[] }> {
-  const activeConfig = getActiveConfig()
-
-  aiLogger.info('LLM', 'Starting non-streaming chat request', {
-    messagesCount: messages.length,
-    firstMessageRole: messages[0]?.role,
-    firstMessageLength: messages[0]?.content?.length,
-    config: activeConfig
-      ? {
-          name: activeConfig.name,
-          provider: activeConfig.provider,
-          model: activeConfig.model,
-          baseUrl: activeConfig.baseUrl,
-        }
-      : null,
-    options,
-  })
-
-  const service = getCurrentLLMService()
-  if (!service) {
-    aiLogger.error('LLM', 'Service not configured')
-    throw new Error(t('llm.notConfigured'))
-  }
-
-  try {
-    const response = await service.chat(messages, options)
-    aiLogger.info('LLM', 'Non-streaming request succeeded', {
-      contentLength: response.content?.length,
-      finishReason: response.finishReason,
-      usage: response.usage,
-    })
-    return response
-  } catch (error) {
-    // Config info
-    const configStr = activeConfig
-      ? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
-      : '未配置'
-    // Error info
-    const errorInfo = extractErrorInfo(error)
-    const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
-
-    aiLogger.error('LLM', `Non-streaming request failed | config: ${configStr}`)
-    aiLogger.error('LLM', `Error: ${errorStr}`)
-
-    // Stack on its own line
-    const stack = extractErrorStack(error)
-    if (stack) {
-      aiLogger.error('LLM', `Stack:\n${stack}`)
-    }
-    throw error
-  }
-}
-
-/**
- * Send a chat request (streaming, using the current config)
- */
-export async function* chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
-  const activeConfig = getActiveConfig()
-
-  aiLogger.info('LLM', 'Starting streaming chat request', {
-    messagesCount: messages.length,
-    firstMessageRole: messages[0]?.role,
-    firstMessageLength: messages[0]?.content?.length,
-    config: activeConfig
-      ? {
-          name: activeConfig.name,
-          provider: activeConfig.provider,
-          model: activeConfig.model,
-          baseUrl: activeConfig.baseUrl,
-        }
-      : null,
-  })
-
-  const service = getCurrentLLMService()
-  if (!service) {
-    aiLogger.error('LLM', 'Service not configured (streaming)')
-    throw new Error(t('llm.notConfigured'))
-  }
-
-  let chunkCount = 0
-  let totalContent = ''
-
-  let receivedFinish = false
-  let contentChunkCount = 0
-
-  try {
-    for await (const chunk of service.chatStream(messages, options)) {
-      chunkCount++
-      totalContent += chunk.content
-
-      // Track content chunks
-      if (chunk.content) {
-        contentChunkCount++
-      }
-
-      yield chunk
-
-      if (chunk.isFinished) {
-        receivedFinish = true
-        aiLogger.info('LLM', 'Streaming request completed', {
-          chunkCount,
-          contentChunkCount,
-          totalContentLength: totalContent.length,
-          finishReason: chunk.finishReason,
-        })
-      }
-    }
-
-    // Log a warning if the loop ends normally without an isFinished chunk
-    if (chunkCount > 0 && !receivedFinish) {
-      aiLogger.warn('LLM', 'Stream loop ended without completion signal', {
-        chunkCount,
-        totalContentLength: totalContent.length,
-      })
-    }
-  } catch (error) {
-    // Config info
-    const configStr = activeConfig
-      ? `${activeConfig.name} (${activeConfig.provider}/${activeConfig.model}) baseUrl=${activeConfig.baseUrl || '默认'}`
-      : '未配置'
-    // Error info
-    const errorInfo = extractErrorInfo(error)
-    const errorStr = `${errorInfo.name || 'Error'}: ${errorInfo.message}`
-
-    aiLogger.error(
-      'LLM',
-      `Stream request failed | config: ${configStr} | received: ${chunkCount} chunks/${totalContent.length} chars`
-    )
-    aiLogger.error('LLM', `Error: ${errorStr}`)
-
-    // Stack on its own line
-    const stack = extractErrorStack(error)
-    if (stack) {
-      aiLogger.error('LLM', `Stack:\n${stack}`)
-    }
-    throw error
-  }
-}
+  const providerInfo = getProviderInfo(provider)
+  const config: AIServiceConfig = {
+    id: 'validate-temp',
+    name: 'validate-temp',
+    provider,
+    apiKey,
+    baseUrl,
+    model: model || providerInfo?.models?.[0]?.id,
+    createdAt: 0,
+    updatedAt: 0,
+  }
+  const piModel = buildPiModel(config)
+
+  const abortController = new AbortController()
+  const timeout = setTimeout(() => abortController.abort(), 15000)
+
+  try {
+    try {
+      await completeSimple(piModel, {
+        messages: [{ role: 'user', content: 'Hi', timestamp: Date.now() }],
+      }, {
+        apiKey,
+        maxTokens: 1,
+        signal: abortController.signal,
+      })
+      return { success: true }
+    } finally {
+      clearTimeout(timeout)
+    }
+  } catch (error) {
+    const message = error instanceof Error ? error.message : String(error)
+    if (message.includes('aborted') || message.includes('AbortError')) {
+      return { success: false, error: 'Request timed out (15s)' }
+    }
+    return { success: false, error: message }
+  }
+}
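
A note on the validateApiKey rewrite above: the probe request is bounded by an AbortController armed with a 15-second timer, and the timer is cleared in a finally block so a fast response does not leave a stray timeout behind. The same guard generalizes to any completeSimple call; a hypothetical helper (not part of this commit) might look like:

    // Hypothetical helper: run a cancellable call under a fixed time budget.
    async function withTimeout<T>(
      run: (signal: AbortSignal) => Promise<T>,
      ms = 15000
    ): Promise<T> {
      const controller = new AbortController()
      const timer = setTimeout(() => controller.abort(), ms)
      try {
        return await run(controller.signal)
      } finally {
        clearTimeout(timer)
      }
    }

validateApiKey inlines this logic instead of using a helper, which keeps the change local to one function.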
@@ -10,7 +10,8 @@ import { getEmbeddingService } from '../embedding'
 import { getVectorStore } from '../store'
 import { getSessionChunks } from '../chunking'
 import { loadRAGConfig } from '../config'
-import { chat } from '../../llm'
+import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
+import { getActiveConfig, buildPiModel } from '../../llm'
 import { aiLogger as logger } from '../../logger'
 
 /**
@@ -33,21 +34,28 @@ const QUERY_REWRITE_PROMPT = `你是一个查询优化专家。请将用户的
  */
 async function rewriteQuery(query: string, abortSignal?: AbortSignal): Promise<string> {
   try {
+    const activeConfig = getActiveConfig()
+    if (!activeConfig) return query
+
+    const piModel = buildPiModel(activeConfig)
     const prompt = QUERY_REWRITE_PROMPT.replace('{query}', query)
 
-    const response = await chat(
-      [
-        { role: 'system', content: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。' },
-        { role: 'user', content: prompt },
-      ],
-      {
-        temperature: 0.3,
-        maxTokens: 200,
-        abortSignal,
-      }
-    )
+    const result = await completeSimple(piModel, {
+      systemPrompt: '你是一个查询优化专家,专门将用户问题改写为更适合语义检索的形式。',
+      messages: [{ role: 'user', content: prompt, timestamp: Date.now() }],
+    }, {
+      apiKey: activeConfig.apiKey,
+      temperature: 0.3,
+      maxTokens: 200,
+      signal: abortSignal,
+    })
+
+    const rewritten = result.content
+      .filter((item): item is PiTextContent => item.type === 'text')
+      .map((item) => item.text)
+      .join('')
+      .trim()
 
-    const rewritten = response.content.trim()
     return rewritten || query
   } catch (error) {
     logger.warn('[Semantic Pipeline] Query rewrite failed, using original query:', error)
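
rewriteQuery above repeats the same filter/map/join sequence as the summary module below to flatten pi-ai's mixed content blocks into a string. A shared helper could factor that out; the sketch is hypothetical (the commit keeps the logic inline at each call site):

    import type { TextContent as PiTextContent } from '@mariozechner/pi-ai'

    // Hypothetical helper: keep only the text blocks and join them.
    function textOf(content: Array<{ type: string }>): string {
      return content
        .filter((item): item is PiTextContent => item.type === 'text')
        .map((item) => item.text)
        .join('')
    }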
@@ -8,11 +8,41 @@
  */
 
 import Database from 'better-sqlite3'
-import { chat } from '../llm'
+import { completeSimple, type TextContent as PiTextContent } from '@mariozechner/pi-ai'
+import { getActiveConfig, buildPiModel } from '../llm'
 import { getDbPath, openDatabase } from '../../database/core'
 import { aiLogger } from '../logger'
 import { t } from '../../i18n'
 
+/** Generate text with the LLM (directly via pi-ai completeSimple) */
+async function llmComplete(
+  systemPrompt: string,
+  userPrompt: string,
+  options?: { temperature?: number; maxTokens?: number }
+): Promise<string> {
+  const activeConfig = getActiveConfig()
+  if (!activeConfig) {
+    throw new Error(t('llm.notConfigured'))
+  }
+
+  const piModel = buildPiModel(activeConfig)
+  const now = Date.now()
+
+  const result = await completeSimple(piModel, {
+    systemPrompt,
+    messages: [{ role: 'user', content: userPrompt, timestamp: now }],
+  }, {
+    apiKey: activeConfig.apiKey,
+    temperature: options?.temperature,
+    maxTokens: options?.maxTokens,
+  })
+
+  return result.content
+    .filter((item): item is PiTextContent => item.type === 'text')
+    .map((item) => item.text)
+    .join('')
+}
+
 /** Minimum message count (no summary is generated below this threshold) */
 const MIN_MESSAGE_COUNT = 3
 
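
llmComplete throws t('llm.notConfigured') when no provider is active, so summary generation fails loudly rather than silently producing nothing. A caller that prefers to degrade could wrap it; an illustrative guard (not part of this commit):

    // Illustrative: return null instead of throwing when no LLM is configured.
    async function tryLlmComplete(
      systemPrompt: string,
      userPrompt: string,
      options?: { temperature?: number; maxTokens?: number }
    ): Promise<string | null> {
      try {
        return await llmComplete(systemPrompt, userPrompt, options)
      } catch (error) {
        aiLogger.warn('Summary', `Summary skipped, LLM unavailable: ${String(error)}`)
        return null
      }
    }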
@@ -408,17 +438,12 @@ export async function generateSessionSummary(
  * Generate the summary directly (for short sessions)
  */
 async function generateDirectSummary(content: string, lengthLimit: number, locale: string): Promise<string> {
-  const response = await chat(
-    [
-      { role: 'system', content: t('summary.systemPromptDirect') },
-      { role: 'user', content: buildSummaryPrompt(content, lengthLimit, locale) },
-    ],
-    {
-      temperature: 0.3,
-      maxTokens: 300,
-    }
+  const result = await llmComplete(
+    t('summary.systemPromptDirect'),
+    buildSummaryPrompt(content, lengthLimit, locale),
+    { temperature: 0.3, maxTokens: 300 }
   )
-  return response.content.trim()
+  return result.trim()
 }
 
 /**
@@ -437,17 +462,12 @@ async function generateMapReduceSummary(
 
   for (let i = 0; i < segments.length; i++) {
     const segmentContent = formatMessages(segments[i])
-    const response = await chat(
-      [
-        { role: 'system', content: t('summary.systemPromptDirect') },
-        { role: 'user', content: buildSubSummaryPrompt(segmentContent, locale) },
-      ],
-      {
-        temperature: 0.3,
-        maxTokens: 100,
-      }
+    const result = await llmComplete(
+      t('summary.systemPromptDirect'),
+      buildSubSummaryPrompt(segmentContent, locale),
+      { temperature: 0.3, maxTokens: 100 }
     )
-    subSummaries.push(response.content.trim())
+    subSummaries.push(result.trim())
   }
 
   // 2. Reduce: merge the sub-summaries
@@ -455,18 +475,13 @@ async function generateMapReduceSummary(
     return subSummaries[0]
   }
 
-  const mergeResponse = await chat(
-    [
-      { role: 'system', content: t('summary.systemPromptMerge') },
-      { role: 'user', content: buildMergePrompt(subSummaries, lengthLimit, locale) },
-    ],
-    {
-      temperature: 0.3,
-      maxTokens: 300,
-    }
+  const mergeResult = await llmComplete(
+    t('summary.systemPromptMerge'),
+    buildMergePrompt(subSummaries, lengthLimit, locale),
+    { temperature: 0.3, maxTokens: 300 }
   )
 
-  return mergeResponse.content.trim()
+  return mergeResult.trim()
 }
 
 /**