mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-06 05:01:19 +08:00
feat: 接入AI sdk
This commit is contained in:
@@ -96,6 +96,8 @@ type StreamMode = 'text' | 'think' | 'tool_call'
|
||||
function createStreamParser(handlers: {
|
||||
onText: (text: string) => void
|
||||
onThink: (text: string, tag: string) => void
|
||||
onThinkStart?: (tag: string) => void
|
||||
onThinkEnd?: (tag: string) => void
|
||||
}): { push: (text: string) => void; flush: () => void } {
|
||||
let buffer = ''
|
||||
let mode: StreamMode = 'text'
|
||||
@@ -165,6 +167,7 @@ function createStreamParser(handlers: {
|
||||
// 进入思考模式
|
||||
currentThinkTag = hit.tag.slice(1, -1)
|
||||
mode = 'think'
|
||||
handlers.onThinkStart?.(currentThinkTag)
|
||||
buffer = buffer.slice(startTags[startTagsLower.indexOf(hit.tag)].length)
|
||||
continue
|
||||
}
|
||||
@@ -194,6 +197,7 @@ function createStreamParser(handlers: {
|
||||
|
||||
buffer = buffer.slice(endIndex + endTag.length)
|
||||
mode = 'text'
|
||||
handlers.onThinkEnd?.(currentThinkTag)
|
||||
currentThinkTag = ''
|
||||
continue
|
||||
}
|
||||
@@ -265,6 +269,8 @@ export interface AgentStreamChunk {
|
||||
content?: string
|
||||
/** 思考标签名称(type=think 时) */
|
||||
thinkTag?: string
|
||||
/** 思考耗时(毫秒,type=think 时可选) */
|
||||
thinkDurationMs?: number
|
||||
/** 工具名称(type=tool_start/tool_result 时) */
|
||||
toolName?: string
|
||||
/** 工具调用参数(type=tool_start 时) */
|
||||
@@ -680,14 +686,24 @@ export class Agent {
|
||||
let accumulatedContent = ''
|
||||
let roundContent = ''
|
||||
let toolCalls: ToolCall[] | undefined
|
||||
let thinkStartAt: number | null = null // 记录思考开始时间
|
||||
const parser = createStreamParser({
|
||||
onText: (text) => {
|
||||
roundContent += text
|
||||
onChunk({ type: 'content', content: text })
|
||||
},
|
||||
onThinkStart: () => {
|
||||
thinkStartAt = Date.now()
|
||||
},
|
||||
onThink: (text, tag) => {
|
||||
onChunk({ type: 'think', content: text, thinkTag: tag })
|
||||
},
|
||||
onThinkEnd: (tag) => {
|
||||
if (thinkStartAt === null) return
|
||||
const durationMs = Date.now() - thinkStartAt
|
||||
thinkStartAt = null
|
||||
onChunk({ type: 'think', content: '', thinkTag: tag, thinkDurationMs: durationMs })
|
||||
},
|
||||
})
|
||||
|
||||
// 流式调用 LLM(传入 abortSignal)
|
||||
@@ -851,14 +867,24 @@ export class Agent {
|
||||
|
||||
// 最后一轮不带 tools(传入 abortSignal)
|
||||
let finalRawContent = ''
|
||||
let finalThinkStartAt: number | null = null // 记录最终思考开始时间
|
||||
const finalParser = createStreamParser({
|
||||
onText: (text) => {
|
||||
finalContent += text
|
||||
onChunk({ type: 'content', content: text })
|
||||
},
|
||||
onThinkStart: () => {
|
||||
finalThinkStartAt = Date.now()
|
||||
},
|
||||
onThink: (text, tag) => {
|
||||
onChunk({ type: 'think', content: text, thinkTag: tag })
|
||||
},
|
||||
onThinkEnd: (tag) => {
|
||||
if (finalThinkStartAt === null) return
|
||||
const durationMs = Date.now() - finalThinkStartAt
|
||||
finalThinkStartAt = null
|
||||
onChunk({ type: 'think', content: '', thinkTag: tag, thinkDurationMs: durationMs })
|
||||
},
|
||||
})
|
||||
for await (const chunk of chatStream(this.messages, {
|
||||
...this.config.llmOptions,
|
||||
|
||||
@@ -117,7 +117,7 @@ export interface AIConversation {
|
||||
*/
|
||||
export type ContentBlock =
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'think'; tag: string; text: string } // 思考内容块
|
||||
| { type: 'think'; tag: string; text: string; durationMs?: number } // 思考内容块
|
||||
| {
|
||||
type: 'tool'
|
||||
tool: {
|
||||
|
||||
@@ -1,341 +0,0 @@
|
||||
/**
|
||||
* DeepSeek LLM Provider
|
||||
* 使用 OpenAI 兼容的 API 格式,支持 Function Calling
|
||||
*/
|
||||
|
||||
import type {
|
||||
ILLMService,
|
||||
LLMProvider,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatStreamChunk,
|
||||
ProviderInfo,
|
||||
ToolCall,
|
||||
} from './types'
|
||||
|
||||
const DEFAULT_BASE_URL = 'https://api.deepseek.com'
|
||||
|
||||
const MODELS = [
|
||||
{ id: 'deepseek-chat', name: 'DeepSeek Chat', description: '通用对话模型' },
|
||||
{ id: 'deepseek-coder', name: 'DeepSeek Coder', description: '代码生成模型' },
|
||||
]
|
||||
|
||||
export const DEEPSEEK_INFO: ProviderInfo = {
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
description: 'DeepSeek AI 大语言模型',
|
||||
defaultBaseUrl: DEFAULT_BASE_URL,
|
||||
models: MODELS,
|
||||
}
|
||||
|
||||
export class DeepSeekService implements ILLMService {
|
||||
private apiKey: string
|
||||
private baseUrl: string
|
||||
private model: string
|
||||
|
||||
constructor(apiKey: string, model?: string, baseUrl?: string) {
|
||||
this.apiKey = apiKey
|
||||
this.baseUrl = baseUrl || DEFAULT_BASE_URL
|
||||
this.model = model || 'deepseek-chat'
|
||||
}
|
||||
|
||||
getProvider(): LLMProvider {
|
||||
return 'deepseek'
|
||||
}
|
||||
|
||||
getModels(): string[] {
|
||||
return MODELS.map((m) => m.id)
|
||||
}
|
||||
|
||||
getDefaultModel(): string {
|
||||
return 'deepseek-chat'
|
||||
}
|
||||
|
||||
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
|
||||
// 构建请求体
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
// 处理 tool 消息
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
// 处理 assistant 消息中的 tool_calls
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
msg.tool_calls = m.tool_calls
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: false,
|
||||
}
|
||||
|
||||
// 如果有 tools,添加到请求体
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`DeepSeek API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
const choice = data.choices?.[0]
|
||||
const message = choice?.message
|
||||
|
||||
// 解析 finish_reason
|
||||
let finishReason: ChatResponse['finishReason'] = 'error'
|
||||
if (choice?.finish_reason === 'stop') {
|
||||
finishReason = 'stop'
|
||||
} else if (choice?.finish_reason === 'length') {
|
||||
finishReason = 'length'
|
||||
} else if (choice?.finish_reason === 'tool_calls') {
|
||||
finishReason = 'tool_calls'
|
||||
}
|
||||
|
||||
// 解析 tool_calls
|
||||
let toolCalls: ToolCall[] | undefined
|
||||
if (message?.tool_calls && Array.isArray(message.tool_calls)) {
|
||||
toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
|
||||
id: tc.id as string,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: (tc.function as Record<string, unknown>)?.name as string,
|
||||
arguments: (tc.function as Record<string, unknown>)?.arguments as string,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
return {
|
||||
content: message?.content || '',
|
||||
finishReason,
|
||||
tool_calls: toolCalls,
|
||||
usage: data.usage
|
||||
? {
|
||||
promptTokens: data.usage.prompt_tokens,
|
||||
completionTokens: data.usage.completion_tokens,
|
||||
totalTokens: data.usage.total_tokens,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
|
||||
// 构建请求体
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
msg.tool_calls = m.tool_calls
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: true,
|
||||
// 启用流式响应中的 usage 统计
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`DeepSeek API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get response reader')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
// 用于收集流式 tool_calls
|
||||
const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// 检查是否已中止
|
||||
if (options?.abortSignal?.aborted) {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue
|
||||
|
||||
const data = trimmed.slice(6)
|
||||
if (data === '[DONE]') {
|
||||
// 如果有累积的 tool_calls,返回它们
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
const delta = parsed.choices?.[0]?.delta
|
||||
const finishReason = parsed.choices?.[0]?.finish_reason
|
||||
|
||||
// 处理文本内容
|
||||
if (delta?.content) {
|
||||
yield {
|
||||
content: delta.content,
|
||||
isFinished: false,
|
||||
}
|
||||
}
|
||||
|
||||
// 处理流式 tool_calls(增量累积)
|
||||
if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const index = tc.index ?? 0
|
||||
const existing = toolCallsAccumulator.get(index)
|
||||
if (existing) {
|
||||
// 累积 arguments
|
||||
if (tc.function?.arguments) {
|
||||
existing.arguments += tc.function.arguments
|
||||
}
|
||||
} else {
|
||||
// 新的 tool_call
|
||||
toolCallsAccumulator.set(index, {
|
||||
id: tc.id || '',
|
||||
name: tc.function?.name || '',
|
||||
arguments: tc.function?.arguments || '',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (finishReason) {
|
||||
let reason: ChatStreamChunk['finishReason'] = 'error'
|
||||
if (finishReason === 'stop') {
|
||||
reason = 'stop'
|
||||
} else if (finishReason === 'length') {
|
||||
reason = 'length'
|
||||
} else if (finishReason === 'tool_calls') {
|
||||
reason = 'tool_calls'
|
||||
}
|
||||
|
||||
// 解析 usage 信息
|
||||
const usage = parsed.usage
|
||||
? {
|
||||
promptTokens: parsed.usage.prompt_tokens,
|
||||
completionTokens: parsed.usage.completion_tokens,
|
||||
totalTokens: parsed.usage.total_tokens,
|
||||
}
|
||||
: undefined
|
||||
|
||||
// 如果有 tool_calls,返回它们
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: reason, usage }
|
||||
}
|
||||
return
|
||||
}
|
||||
} catch {
|
||||
// 忽略解析错误,继续处理下一行
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果是中止错误,正常返回
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
|
||||
try {
|
||||
// 发送一个简单请求验证 API Key
|
||||
const response = await fetch(`${this.baseUrl}/v1/models`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
})
|
||||
if (response.ok) {
|
||||
return { success: true }
|
||||
}
|
||||
// 尝试获取错误详情
|
||||
const errorText = await response.text()
|
||||
let errorMessage = `HTTP ${response.status}`
|
||||
try {
|
||||
const errorJson = JSON.parse(errorText)
|
||||
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
|
||||
} catch {
|
||||
if (errorText) {
|
||||
errorMessage = errorText.slice(0, 200)
|
||||
}
|
||||
}
|
||||
return { success: false, error: errorMessage }
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
return { success: false, error: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
+159
-423
@@ -1,8 +1,11 @@
|
||||
/**
|
||||
* Google Gemini LLM Provider
|
||||
* 使用 Gemini REST API 格式,支持 Function Calling
|
||||
* 使用 AI SDK 的 Google Generative AI Provider,支持 Function Calling
|
||||
*/
|
||||
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { generateText, streamText } from 'ai'
|
||||
import type { ContentPart, ModelMessage, ToolSet, TypedToolCall } from 'ai'
|
||||
import type {
|
||||
ILLMService,
|
||||
LLMProvider,
|
||||
@@ -10,13 +13,15 @@ import type {
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatStreamChunk,
|
||||
ProviderInfo,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ProviderInfo,
|
||||
} from './types'
|
||||
import { aiLogger } from '../logger'
|
||||
import { buildModelMessages, buildToolSet, mapFinishReason, mapUsage } from './sdkUtils'
|
||||
|
||||
const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com'
|
||||
const DEFAULT_API_VERSION = '/v1beta'
|
||||
const GEMINI_MAX_RETRIES = 1
|
||||
|
||||
const MODELS = [
|
||||
{ id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: '高速预览版' },
|
||||
@@ -31,85 +36,133 @@ export const GEMINI_INFO: ProviderInfo = {
|
||||
models: MODELS,
|
||||
}
|
||||
|
||||
// ==================== Gemini API 类型定义 ====================
|
||||
/**
|
||||
* 统一处理 Gemini 的 baseUrl,确保包含 /v1beta
|
||||
*/
|
||||
function normalizeBaseUrl(baseUrl?: string): string {
|
||||
let normalized = (baseUrl || DEFAULT_BASE_URL).replace(/\/+$/, '')
|
||||
|
||||
/** Gemini 消息 part(支持多种类型) */
|
||||
interface GeminiPart {
|
||||
text?: string
|
||||
functionCall?: {
|
||||
name: string
|
||||
args: Record<string, unknown>
|
||||
if (normalized.endsWith(DEFAULT_API_VERSION)) {
|
||||
return normalized
|
||||
}
|
||||
functionResponse?: {
|
||||
name: string
|
||||
response: unknown
|
||||
}
|
||||
/** Gemini 3+ 模型的思考签名 */
|
||||
thoughtSignature?: string
|
||||
|
||||
return `${normalized}${DEFAULT_API_VERSION}`
|
||||
}
|
||||
|
||||
/** Gemini 消息内容 */
|
||||
interface GeminiContent {
|
||||
role: 'user' | 'model'
|
||||
parts: GeminiPart[]
|
||||
const GEMINI_PROVIDER_KEYS = ['google', 'vertex', 'gemini']
|
||||
|
||||
/**
|
||||
* 从 AI SDK 的 providerMetadata 中提取 thoughtSignature
|
||||
*/
|
||||
function getThoughtSignature(metadata?: unknown): string | undefined {
|
||||
if (!metadata || typeof metadata !== 'object') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const record = metadata as Record<string, unknown>
|
||||
for (const key of GEMINI_PROVIDER_KEYS) {
|
||||
const providerMeta = record[key]
|
||||
if (providerMeta && typeof providerMeta === 'object') {
|
||||
const signature = (providerMeta as Record<string, unknown>).thoughtSignature
|
||||
if (typeof signature === 'string' && signature) {
|
||||
return signature
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const directSignature = record.thoughtSignature
|
||||
if (typeof directSignature === 'string' && directSignature) {
|
||||
return directSignature
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/** Gemini 函数声明(对应 OpenAI 的 ToolDefinition) */
|
||||
interface GeminiFunctionDeclaration {
|
||||
name: string
|
||||
description: string
|
||||
parameters: {
|
||||
type: string
|
||||
properties: Record<string, unknown>
|
||||
required?: string[]
|
||||
}
|
||||
/**
|
||||
* Gemini 工具调用需要 thoughtSignature
|
||||
*/
|
||||
function mapGeminiToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
|
||||
return toolCalls.map((tc) => ({
|
||||
id: tc.toolCallId,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.toolName,
|
||||
arguments: JSON.stringify(tc.input ?? {}),
|
||||
},
|
||||
thoughtSignature: getThoughtSignature(tc.providerMetadata),
|
||||
}))
|
||||
}
|
||||
|
||||
/** Gemini 请求体 */
|
||||
interface GeminiRequest {
|
||||
contents: GeminiContent[]
|
||||
generationConfig?: {
|
||||
temperature?: number
|
||||
maxOutputTokens?: number
|
||||
}
|
||||
systemInstruction?: {
|
||||
parts: Array<{ text: string }>
|
||||
}
|
||||
tools?: Array<{
|
||||
functionDeclarations: GeminiFunctionDeclaration[]
|
||||
}>
|
||||
}
|
||||
/**
|
||||
* 构建 Gemini 的模型消息,补充 thoughtSignature
|
||||
*/
|
||||
function buildGeminiModelMessages(messages: ChatMessage[]): ModelMessage[] {
|
||||
const modelMessages = buildModelMessages(messages)
|
||||
const signatureMap = new Map<string, string>()
|
||||
|
||||
/** Gemini 响应候选项 */
|
||||
interface GeminiCandidate {
|
||||
content?: {
|
||||
parts?: GeminiPart[]
|
||||
role?: string
|
||||
for (const message of messages) {
|
||||
if (message.role !== 'assistant' || !message.tool_calls) continue
|
||||
for (const toolCall of message.tool_calls) {
|
||||
if (toolCall.thoughtSignature) {
|
||||
signatureMap.set(toolCall.id, toolCall.thoughtSignature)
|
||||
}
|
||||
}
|
||||
}
|
||||
finishReason?: string
|
||||
}
|
||||
|
||||
/** Gemini API 响应 */
|
||||
interface GeminiResponse {
|
||||
candidates?: GeminiCandidate[]
|
||||
usageMetadata?: {
|
||||
promptTokenCount?: number
|
||||
candidatesTokenCount?: number
|
||||
totalTokenCount?: number
|
||||
if (signatureMap.size === 0) {
|
||||
return modelMessages
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== GeminiService 类 ====================
|
||||
return modelMessages.map((message) => {
|
||||
if (message.role !== 'assistant' || !Array.isArray(message.content)) {
|
||||
return message
|
||||
}
|
||||
|
||||
const contentParts = message.content.map((part) => {
|
||||
if (part.type !== 'tool-call') {
|
||||
return part
|
||||
}
|
||||
|
||||
const signature = signatureMap.get(part.toolCallId)
|
||||
if (!signature) {
|
||||
return part
|
||||
}
|
||||
|
||||
const nextPart: ContentPart<ToolSet> & {
|
||||
providerOptions?: Record<string, { thoughtSignature?: string }>
|
||||
} = {
|
||||
...part,
|
||||
providerOptions: {
|
||||
google: { thoughtSignature: signature },
|
||||
},
|
||||
}
|
||||
|
||||
return nextPart
|
||||
})
|
||||
|
||||
return {
|
||||
...message,
|
||||
content: contentParts,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export class GeminiService implements ILLMService {
|
||||
private apiKey: string
|
||||
private baseUrl: string
|
||||
private model: string
|
||||
|
||||
private provider = createGoogleGenerativeAI()
|
||||
|
||||
constructor(apiKey: string, model?: string, baseUrl?: string) {
|
||||
this.apiKey = apiKey
|
||||
this.baseUrl = baseUrl || DEFAULT_BASE_URL
|
||||
this.baseUrl = normalizeBaseUrl(baseUrl)
|
||||
this.model = model || 'gemini-3-flash-preview'
|
||||
this.provider = createGoogleGenerativeAI({
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.baseUrl,
|
||||
name: 'gemini',
|
||||
})
|
||||
}
|
||||
|
||||
getProvider(): LLMProvider {
|
||||
@@ -124,404 +177,87 @@ export class GeminiService implements ILLMService {
|
||||
return 'gemini-3-flash-preview'
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 OpenAI 格式的 tools 转换为 Gemini 格式
|
||||
*/
|
||||
private convertTools(tools?: ToolDefinition[]): Array<{ functionDeclarations: GeminiFunctionDeclaration[] }> | undefined {
|
||||
if (!tools || tools.length === 0) return undefined
|
||||
|
||||
const functionDeclarations: GeminiFunctionDeclaration[] = tools.map((tool) => ({
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
parameters: tool.function.parameters,
|
||||
}))
|
||||
|
||||
return [{ functionDeclarations }]
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 OpenAI 格式消息转换为 Gemini 格式
|
||||
*/
|
||||
private convertMessages(messages: ChatMessage[]): {
|
||||
contents: GeminiContent[]
|
||||
systemInstruction?: { parts: Array<{ text: string }> }
|
||||
} {
|
||||
const contents: GeminiContent[] = []
|
||||
let systemInstruction: { parts: Array<{ text: string }> } | undefined
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === 'system') {
|
||||
// Gemini 使用 systemInstruction 处理系统提示
|
||||
systemInstruction = {
|
||||
parts: [{ text: msg.content }],
|
||||
}
|
||||
} else if (msg.role === 'user') {
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: msg.content }],
|
||||
})
|
||||
} else if (msg.role === 'assistant') {
|
||||
// 处理 assistant 消息(可能包含 tool_calls)
|
||||
if (msg.tool_calls && msg.tool_calls.length > 0) {
|
||||
// 有工具调用的 assistant 消息
|
||||
const parts: GeminiPart[] = []
|
||||
if (msg.content) {
|
||||
parts.push({ text: msg.content })
|
||||
}
|
||||
for (const tc of msg.tool_calls) {
|
||||
const part: GeminiPart = {
|
||||
functionCall: {
|
||||
name: tc.function.name,
|
||||
args: JSON.parse(tc.function.arguments),
|
||||
},
|
||||
}
|
||||
// Gemini 3+ 需要包含 thoughtSignature
|
||||
if (tc.thoughtSignature) {
|
||||
part.thoughtSignature = tc.thoughtSignature
|
||||
}
|
||||
parts.push(part)
|
||||
}
|
||||
contents.push({ role: 'model', parts })
|
||||
} else {
|
||||
// 普通文本消息
|
||||
contents.push({
|
||||
role: 'model',
|
||||
parts: [{ text: msg.content }],
|
||||
})
|
||||
}
|
||||
} else if (msg.role === 'tool') {
|
||||
// 工具结果消息 - 在 Gemini 中作为 user 角色的 functionResponse
|
||||
// 注意:需要从消息内容解析工具名称和结果
|
||||
// tool_call_id 格式通常是 "call_xxx",我们需要从上下文获取工具名
|
||||
// 这里简化处理:假设内容是 JSON 格式的结果
|
||||
try {
|
||||
const result = JSON.parse(msg.content)
|
||||
// 尝试从上一条 assistant 消息中找到对应的 tool_call
|
||||
// 由于 Gemini 需要 name,我们从 tool_call_id 推断或使用默认值
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: msg.tool_call_id?.replace('call_', '') || 'unknown',
|
||||
response: result,
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
} catch {
|
||||
// 如果不是 JSON,直接作为文本结果
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: msg.tool_call_id?.replace('call_', '') || 'unknown',
|
||||
response: { result: msg.content },
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { contents, systemInstruction }
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 API URL
|
||||
*/
|
||||
private buildUrl(stream: boolean): string {
|
||||
const action = stream ? 'streamGenerateContent' : 'generateContent'
|
||||
const base = this.baseUrl.replace(/\/$/, '')
|
||||
return `${base}/v1beta/models/${this.model}:${action}?key=${this.apiKey}`
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 Gemini parts 中提取工具调用
|
||||
*/
|
||||
private extractToolCalls(parts?: GeminiPart[]): ToolCall[] | undefined {
|
||||
if (!parts) return undefined
|
||||
|
||||
const toolCalls: ToolCall[] = []
|
||||
for (const part of parts) {
|
||||
if (part.functionCall) {
|
||||
toolCalls.push({
|
||||
id: `call_${part.functionCall.name}_${Date.now()}`,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: part.functionCall.name,
|
||||
arguments: JSON.stringify(part.functionCall.args),
|
||||
},
|
||||
// 保存 Gemini 3+ 的思考签名
|
||||
thoughtSignature: part.thoughtSignature,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls.length > 0 ? toolCalls : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 Gemini parts 中提取文本内容
|
||||
*/
|
||||
private extractText(parts?: GeminiPart[]): string {
|
||||
if (!parts) return ''
|
||||
return parts
|
||||
.filter((p) => p.text)
|
||||
.map((p) => p.text)
|
||||
.join('')
|
||||
}
|
||||
|
||||
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
|
||||
const { contents, systemInstruction } = this.convertMessages(messages)
|
||||
const model = this.provider.chat(this.model)
|
||||
const toolSet = buildToolSet(options?.tools)
|
||||
|
||||
const requestBody: GeminiRequest = {
|
||||
contents,
|
||||
generationConfig: {
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
maxOutputTokens: options?.maxTokens ?? 2048,
|
||||
},
|
||||
}
|
||||
|
||||
if (systemInstruction) {
|
||||
requestBody.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
// 添加工具定义
|
||||
const geminiTools = this.convertTools(options?.tools)
|
||||
if (geminiTools) {
|
||||
requestBody.tools = geminiTools
|
||||
}
|
||||
|
||||
const response = await fetch(this.buildUrl(false), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
const result = await generateText({
|
||||
model,
|
||||
messages: buildGeminiModelMessages(messages),
|
||||
tools: toolSet,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
maxTokens: options?.maxTokens ?? 2048,
|
||||
// 降低 Gemini 自动重试次数,避免触发免费配额限制
|
||||
maxRetries: GEMINI_MAX_RETRIES,
|
||||
abortSignal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`Gemini API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const data: GeminiResponse = await response.json()
|
||||
const candidate = data.candidates?.[0]
|
||||
const parts = candidate?.content?.parts
|
||||
const content = this.extractText(parts)
|
||||
const toolCalls = this.extractToolCalls(parts)
|
||||
|
||||
// 解析 finish_reason
|
||||
let finishReason: ChatResponse['finishReason'] = 'error'
|
||||
const reason = candidate?.finishReason
|
||||
if (reason === 'STOP') {
|
||||
finishReason = toolCalls ? 'tool_calls' : 'stop'
|
||||
} else if (reason === 'MAX_TOKENS') {
|
||||
finishReason = 'length'
|
||||
}
|
||||
const toolCalls = result.toolCalls.length > 0 ? mapGeminiToolCalls(result.toolCalls) : undefined
|
||||
|
||||
return {
|
||||
content,
|
||||
finishReason,
|
||||
content: result.text,
|
||||
finishReason: mapFinishReason(result.finishReason),
|
||||
tool_calls: toolCalls,
|
||||
usage: data.usageMetadata
|
||||
? {
|
||||
promptTokens: data.usageMetadata.promptTokenCount || 0,
|
||||
completionTokens: data.usageMetadata.candidatesTokenCount || 0,
|
||||
totalTokens: data.usageMetadata.totalTokenCount || 0,
|
||||
}
|
||||
: undefined,
|
||||
usage: mapUsage(result.usage),
|
||||
}
|
||||
}
|
||||
|
||||
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
|
||||
const { contents, systemInstruction } = this.convertMessages(messages)
|
||||
const model = this.provider.chat(this.model)
|
||||
const toolSet = buildToolSet(options?.tools)
|
||||
|
||||
const requestBody: GeminiRequest = {
|
||||
contents,
|
||||
generationConfig: {
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
maxOutputTokens: options?.maxTokens ?? 2048,
|
||||
},
|
||||
}
|
||||
|
||||
if (systemInstruction) {
|
||||
requestBody.systemInstruction = systemInstruction
|
||||
}
|
||||
|
||||
// 添加工具定义
|
||||
const geminiTools = this.convertTools(options?.tools)
|
||||
if (geminiTools) {
|
||||
requestBody.tools = geminiTools
|
||||
}
|
||||
|
||||
// Gemini 流式需要添加 alt=sse 参数
|
||||
const url = this.buildUrl(true) + '&alt=sse'
|
||||
|
||||
aiLogger.info('Gemini', '开始流式请求', {
|
||||
url: url.replace(/key=[^&]+/, 'key=***'),
|
||||
model: this.model,
|
||||
messagesCount: contents.length,
|
||||
hasSystemInstruction: !!systemInstruction,
|
||||
hasTools: !!geminiTools,
|
||||
const result = streamText({
|
||||
model,
|
||||
messages: buildGeminiModelMessages(messages),
|
||||
tools: toolSet,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
maxTokens: options?.maxTokens ?? 2048,
|
||||
// 降低 Gemini 自动重试次数,避免触发免费配额限制
|
||||
maxRetries: GEMINI_MAX_RETRIES,
|
||||
abortSignal: options?.abortSignal,
|
||||
})
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
aiLogger.error('Gemini', 'API 请求失败', { status: response.status, error: error.slice(0, 500) })
|
||||
throw new Error(`Gemini API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
aiLogger.info('Gemini', 'API 响应成功,开始读取流')
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get response reader')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
// 用于累积工具调用(可能跨多个 chunk)
|
||||
const toolCallsAccumulator: ToolCall[] = []
|
||||
// 用于追踪最新的 usage 信息
|
||||
let latestUsage: { promptTokens: number; completionTokens: number; totalTokens: number } | undefined
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// 检查是否已中止
|
||||
for await (const chunk of result.textStream) {
|
||||
if (options?.abortSignal?.aborted) {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue
|
||||
|
||||
const data = trimmed.slice(6)
|
||||
if (data === '[DONE]') {
|
||||
if (toolCallsAccumulator.length > 0) {
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed: GeminiResponse = JSON.parse(data)
|
||||
const candidate = parsed.candidates?.[0]
|
||||
const parts = candidate?.content?.parts
|
||||
|
||||
// 处理文本内容
|
||||
const text = this.extractText(parts)
|
||||
if (text) {
|
||||
yield { content: text, isFinished: false }
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
const toolCalls = this.extractToolCalls(parts)
|
||||
if (toolCalls) {
|
||||
aiLogger.info('Gemini', '检测到工具调用', { toolCalls: toolCalls.map((tc) => tc.function.name) })
|
||||
toolCallsAccumulator.push(...toolCalls)
|
||||
}
|
||||
|
||||
// 更新 usage 信息
|
||||
if (parsed.usageMetadata) {
|
||||
latestUsage = {
|
||||
promptTokens: parsed.usageMetadata.promptTokenCount || 0,
|
||||
completionTokens: parsed.usageMetadata.candidatesTokenCount || 0,
|
||||
totalTokens: parsed.usageMetadata.totalTokenCount || 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否完成
|
||||
const finishReason = candidate?.finishReason
|
||||
if (finishReason) {
|
||||
aiLogger.info('Gemini', '流式响应完成', { finishReason, toolCallsCount: toolCallsAccumulator.length, usage: latestUsage })
|
||||
|
||||
if (toolCallsAccumulator.length > 0) {
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator, usage: latestUsage }
|
||||
} else {
|
||||
let reason: ChatStreamChunk['finishReason'] = 'stop'
|
||||
if (finishReason === 'MAX_TOKENS') {
|
||||
reason = 'length'
|
||||
}
|
||||
yield { content: '', isFinished: true, finishReason: reason, usage: latestUsage }
|
||||
}
|
||||
return
|
||||
}
|
||||
} catch (e) {
|
||||
// 记录解析错误
|
||||
aiLogger.warn('Gemini', 'SSE 数据解析失败', { data: data.slice(0, 200), error: String(e) })
|
||||
}
|
||||
if (chunk) {
|
||||
yield { content: chunk, isFinished: false }
|
||||
}
|
||||
}
|
||||
|
||||
// 如果循环正常结束,发送完成信号
|
||||
if (toolCallsAccumulator.length > 0) {
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator, usage: latestUsage }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop', usage: latestUsage }
|
||||
const finishReason = mapFinishReason(await result.finishReason)
|
||||
const toolCalls = await result.toolCalls
|
||||
const usage = mapUsage(await result.totalUsage)
|
||||
|
||||
yield {
|
||||
content: '',
|
||||
isFinished: true,
|
||||
finishReason,
|
||||
tool_calls: toolCalls.length > 0 ? mapGeminiToolCalls(toolCalls) : undefined,
|
||||
usage,
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果是中止错误,正常返回
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
aiLogger.error('Gemini', '流式请求失败', { error: String(error) })
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
|
||||
try {
|
||||
// 使用 models.list API 验证 API Key
|
||||
const url = `${this.baseUrl.replace(/\/$/, '')}/v1beta/models?key=${this.apiKey}`
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: 'GET',
|
||||
const model = this.provider.chat(this.model)
|
||||
await generateText({
|
||||
model,
|
||||
prompt: 'Hi',
|
||||
maxTokens: 1,
|
||||
})
|
||||
|
||||
if (response.ok) {
|
||||
return { success: true }
|
||||
}
|
||||
|
||||
// 尝试获取错误详情
|
||||
const errorText = await response.text()
|
||||
let errorMessage = `HTTP ${response.status}`
|
||||
try {
|
||||
const errorJson = JSON.parse(errorText)
|
||||
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
|
||||
} catch {
|
||||
if (errorText) {
|
||||
errorMessage = errorText.slice(0, 200)
|
||||
}
|
||||
}
|
||||
return { success: false, error: errorMessage }
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
return { success: false, error: errorMessage }
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import * as fs from 'fs'
|
||||
import * as path from 'path'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { getAiDataDir, ensureDir } from '../../paths'
|
||||
import { getAiDataDir } from '../../paths'
|
||||
import type {
|
||||
LLMConfig,
|
||||
LLMProvider,
|
||||
@@ -19,8 +19,6 @@ import type {
|
||||
AIConfigStore,
|
||||
} from './types'
|
||||
import { MAX_CONFIG_COUNT } from './types'
|
||||
import { DeepSeekService, DEEPSEEK_INFO } from './deepseek'
|
||||
import { QwenService, QWEN_INFO } from './qwen'
|
||||
import { GeminiService, GEMINI_INFO } from './gemini'
|
||||
import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
|
||||
import { aiLogger } from '../logger'
|
||||
@@ -30,6 +28,31 @@ export * from './types'
|
||||
|
||||
// ==================== 新增提供商信息 ====================
|
||||
|
||||
/** DeepSeek 提供商信息 */
|
||||
const DEEPSEEK_INFO: ProviderInfo = {
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
description: 'DeepSeek AI 大语言模型',
|
||||
defaultBaseUrl: 'https://api.deepseek.com/v1',
|
||||
models: [
|
||||
{ id: 'deepseek-chat', name: 'DeepSeek Chat', description: '通用对话模型' },
|
||||
{ id: 'deepseek-coder', name: 'DeepSeek Coder', description: '代码生成模型' },
|
||||
],
|
||||
}
|
||||
|
||||
/** 通义千问 (Qwen) 提供商信息 */
|
||||
const QWEN_INFO: ProviderInfo = {
|
||||
id: 'qwen',
|
||||
name: '通义千问',
|
||||
description: '阿里云通义千问大语言模型',
|
||||
defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
models: [
|
||||
{ id: 'qwen-turbo', name: 'Qwen Turbo', description: '通义千问超大规模语言模型,速度快' },
|
||||
{ id: 'qwen-plus', name: 'Qwen Plus', description: '通义千问超大规模语言模型,效果好' },
|
||||
{ id: 'qwen-max', name: 'Qwen Max', description: '通义千问千亿级别超大规模语言模型' },
|
||||
],
|
||||
}
|
||||
|
||||
/** MiniMax 提供商信息 */
|
||||
const MINIMAX_INFO: ProviderInfo = {
|
||||
id: 'minimax',
|
||||
@@ -324,6 +347,36 @@ interface ExtendedLLMConfig extends LLMConfig {
|
||||
disableThinking?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 不再自动补齐 Base URL,对 DeepSeek/Qwen 的格式做显式校验
|
||||
*/
|
||||
function validateProviderBaseUrl(provider: LLMProvider, baseUrl?: string): void {
|
||||
if (!baseUrl) return
|
||||
|
||||
const normalized = baseUrl.replace(/\/+$/, '')
|
||||
|
||||
if (provider === 'deepseek') {
|
||||
if (normalized.endsWith('/chat/completions')) {
|
||||
throw new Error('DeepSeek Base URL 请填写到 /v1 层级,不要包含 /chat/completions')
|
||||
}
|
||||
if (!normalized.endsWith('/v1')) {
|
||||
throw new Error('DeepSeek Base URL 需要以 /v1 结尾')
|
||||
}
|
||||
}
|
||||
|
||||
if (provider === 'qwen') {
|
||||
if (normalized.endsWith('/chat/completions')) {
|
||||
throw new Error('通义千问 Base URL 请填写到 /v1 层级,不要包含 /chat/completions')
|
||||
}
|
||||
if (!normalized.endsWith('/v1')) {
|
||||
throw new Error('通义千问 Base URL 需要以 /v1 结尾')
|
||||
}
|
||||
if (normalized.includes('dashscope.aliyuncs.com') && !normalized.includes('/compatible-mode/')) {
|
||||
throw new Error('通义千问 Base URL 需要包含 /compatible-mode/v1')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 LLM 服务实例
|
||||
*/
|
||||
@@ -331,22 +384,39 @@ export function createLLMService(config: ExtendedLLMConfig): ILLMService {
|
||||
// 获取提供商的默认 baseUrl
|
||||
const providerInfo = getProviderInfo(config.provider)
|
||||
const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl
|
||||
// 未显式指定时使用提供商的首个模型作为默认模型
|
||||
const resolvedModel = config.model || providerInfo?.models?.[0]?.id
|
||||
// 不自动补齐,发现不合法直接抛错给用户
|
||||
validateProviderBaseUrl(config.provider, baseUrl)
|
||||
|
||||
switch (config.provider) {
|
||||
case 'deepseek':
|
||||
return new DeepSeekService(config.apiKey, config.model, config.baseUrl)
|
||||
case 'qwen':
|
||||
return new QwenService(config.apiKey, config.model, config.baseUrl)
|
||||
case 'gemini':
|
||||
return new GeminiService(config.apiKey, config.model, config.baseUrl)
|
||||
return new GeminiService(config.apiKey, resolvedModel, config.baseUrl)
|
||||
// 新增的官方API都使用 OpenAI 兼容格式
|
||||
case 'deepseek':
|
||||
case 'qwen':
|
||||
case 'minimax':
|
||||
case 'glm':
|
||||
case 'kimi':
|
||||
case 'doubao':
|
||||
return new OpenAICompatibleService(config.apiKey, config.model, baseUrl)
|
||||
// DeepSeek/Qwen 走 OpenAI 兼容实现时,禁用本地思考注入
|
||||
return new OpenAICompatibleService(
|
||||
config.apiKey,
|
||||
resolvedModel,
|
||||
baseUrl,
|
||||
config.provider === 'deepseek' || config.provider === 'qwen' ? false : undefined,
|
||||
config.provider,
|
||||
providerInfo?.models
|
||||
)
|
||||
case 'openai-compatible':
|
||||
return new OpenAICompatibleService(config.apiKey, config.model, config.baseUrl, config.disableThinking)
|
||||
return new OpenAICompatibleService(
|
||||
config.apiKey,
|
||||
resolvedModel,
|
||||
config.baseUrl,
|
||||
config.disableThinking,
|
||||
config.provider,
|
||||
providerInfo?.models
|
||||
)
|
||||
default:
|
||||
throw new Error(`Unknown LLM provider: ${config.provider}`)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
* 支持任何兼容 OpenAI API 格式的服务(如 Ollama、LocalAI、vLLM 等)
|
||||
*/
|
||||
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { generateText, streamText } from 'ai'
|
||||
import type { ModelMessage, ToolSet, TypedToolCall } from 'ai'
|
||||
import type {
|
||||
ILLMService,
|
||||
LLMProvider,
|
||||
@@ -10,12 +13,14 @@ import type {
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatStreamChunk,
|
||||
ProviderInfo,
|
||||
ToolCall,
|
||||
ProviderInfo,
|
||||
} from './types'
|
||||
import { aiLogger } from '../logger'
|
||||
import { buildModelMessages, buildToolSet, mapFinishReason, mapToolCalls, mapUsage } from './sdkUtils'
|
||||
|
||||
const DEFAULT_BASE_URL = 'http://localhost:11434/v1'
|
||||
const DEFAULT_THOUGHT_SIGNATURE = 'context_engineering_is_the_way_to_go'
|
||||
|
||||
export const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
|
||||
id: 'openai-compatible',
|
||||
@@ -29,364 +34,267 @@ export const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
|
||||
],
|
||||
}
|
||||
|
||||
/**
|
||||
* 统一处理 baseUrl:去掉尾部斜杠和多余路径
|
||||
*/
|
||||
function normalizeBaseUrl(baseUrl?: string): string {
|
||||
let processed = baseUrl || DEFAULT_BASE_URL
|
||||
processed = processed.replace(/\/+$/, '')
|
||||
if (processed.endsWith('/chat/completions')) {
|
||||
processed = processed.slice(0, -'/chat/completions'.length)
|
||||
}
|
||||
return processed
|
||||
}
|
||||
|
||||
/**
|
||||
* MiniMax 流式返回可能是累计文本,这里按前缀增量去重
|
||||
*/
|
||||
function dedupeCumulativeStreamChunk(
|
||||
chunk: string,
|
||||
previousText: string
|
||||
): { delta: string; nextText: string } {
|
||||
if (!previousText) {
|
||||
return { delta: chunk, nextText: chunk }
|
||||
}
|
||||
|
||||
if (chunk.startsWith(previousText)) {
|
||||
return { delta: chunk.slice(previousText.length), nextText: chunk }
|
||||
}
|
||||
|
||||
if (previousText.startsWith(chunk)) {
|
||||
// 偶发回退或重复帧,保持已输出内容
|
||||
return { delta: '', nextText: previousText }
|
||||
}
|
||||
|
||||
// 无法判定为累计时,退化为增量追加
|
||||
return { delta: chunk, nextText: previousText + chunk }
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 包装 fetch:注入思考开关和 thought_signature
|
||||
*/
|
||||
function createCompatFetch(disableThinking: boolean): typeof fetch {
|
||||
return async (input, init) => {
|
||||
if (!init?.body || typeof init.body !== 'string') {
|
||||
return fetch(input, init)
|
||||
}
|
||||
|
||||
let parsedBody: Record<string, unknown> | null = null
|
||||
try {
|
||||
parsedBody = JSON.parse(init.body) as Record<string, unknown>
|
||||
} catch {
|
||||
return fetch(input, init)
|
||||
}
|
||||
|
||||
if (!parsedBody) {
|
||||
return fetch(input, init)
|
||||
}
|
||||
|
||||
let changed = false
|
||||
|
||||
if (Array.isArray(parsedBody.messages)) {
|
||||
const messages = parsedBody.messages as Array<Record<string, unknown>>
|
||||
// 为 Gemini 兼容后端补充 thought_signature
|
||||
for (const message of messages) {
|
||||
if (
|
||||
message &&
|
||||
typeof message === 'object' &&
|
||||
(message as { role?: string }).role === 'assistant' &&
|
||||
Array.isArray((message as { tool_calls?: unknown[] }).tool_calls)
|
||||
) {
|
||||
const toolCalls = (message as { tool_calls: Array<Record<string, unknown>> }).tool_calls
|
||||
for (const toolCall of toolCalls) {
|
||||
const typedCall = toolCall as Record<string, unknown> & {
|
||||
thought_signature?: string
|
||||
thoughtSignature?: string
|
||||
}
|
||||
if (!typedCall.thought_signature && !typedCall.thoughtSignature) {
|
||||
typedCall.thought_signature = DEFAULT_THOUGHT_SIGNATURE
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 禁用思考模式(用于本地模型)
|
||||
if (disableThinking) {
|
||||
const chatTemplate = parsedBody.chat_template_kwargs
|
||||
if (!chatTemplate || typeof chatTemplate !== 'object') {
|
||||
parsedBody.chat_template_kwargs = { enable_thinking: false }
|
||||
changed = true
|
||||
} else if (!(chatTemplate as { enable_thinking?: boolean }).enable_thinking) {
|
||||
parsedBody.chat_template_kwargs = {
|
||||
...(chatTemplate as Record<string, unknown>),
|
||||
enable_thinking: false,
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!changed) {
|
||||
return fetch(input, init)
|
||||
}
|
||||
|
||||
const nextInit: RequestInit = {
|
||||
...init,
|
||||
body: JSON.stringify(parsedBody),
|
||||
}
|
||||
|
||||
return fetch(input, nextInit)
|
||||
}
|
||||
}
|
||||
|
||||
export class OpenAICompatibleService implements ILLMService {
|
||||
private apiKey: string
|
||||
private baseUrl: string
|
||||
private model: string
|
||||
private disableThinking: boolean
|
||||
private providerId: LLMProvider
|
||||
private models: ProviderInfo['models']
|
||||
private defaultModel: string
|
||||
private provider = createOpenAI()
|
||||
|
||||
constructor(apiKey: string, model?: string, baseUrl?: string, disableThinking?: boolean) {
|
||||
this.apiKey = apiKey || 'sk-no-key-required' // 本地服务可能不需要 API Key
|
||||
// 智能处理 baseUrl:如果用户已经包含 /chat/completions,则去掉它
|
||||
let processedBaseUrl = baseUrl || DEFAULT_BASE_URL
|
||||
processedBaseUrl = processedBaseUrl.replace(/\/+$/, '') // 去掉尾部斜杠
|
||||
if (processedBaseUrl.endsWith('/chat/completions')) {
|
||||
processedBaseUrl = processedBaseUrl.slice(0, -'/chat/completions'.length)
|
||||
}
|
||||
this.baseUrl = processedBaseUrl
|
||||
this.model = model || 'llama3.2'
|
||||
this.disableThinking = disableThinking ?? true // 默认禁用思考模式
|
||||
}
|
||||
constructor(
|
||||
apiKey: string,
|
||||
model?: string,
|
||||
baseUrl?: string,
|
||||
disableThinking?: boolean,
|
||||
providerId?: LLMProvider,
|
||||
models?: ProviderInfo['models']
|
||||
) {
|
||||
const normalizedBaseUrl = normalizeBaseUrl(baseUrl)
|
||||
const resolvedApiKey = apiKey || 'sk-no-key-required'
|
||||
const resolvedDisableThinking = disableThinking ?? true
|
||||
const resolvedProviderId = providerId ?? 'openai-compatible'
|
||||
const resolvedModels = models && models.length > 0 ? models : OPENAI_COMPATIBLE_INFO.models
|
||||
const defaultModel = resolvedModels[0]?.id || 'llama3.2'
|
||||
const resolvedModel = model || defaultModel
|
||||
|
||||
/**
|
||||
* 设置 Bearer Token 认证头
|
||||
*/
|
||||
private setAuthHeaders(headers: Record<string, string>): void {
|
||||
if (this.apiKey && this.apiKey !== 'sk-no-key-required') {
|
||||
headers['Authorization'] = `Bearer ${this.apiKey}`
|
||||
}
|
||||
this.apiKey = resolvedApiKey
|
||||
this.baseUrl = normalizedBaseUrl
|
||||
this.providerId = resolvedProviderId
|
||||
this.models = resolvedModels
|
||||
this.defaultModel = defaultModel
|
||||
this.model = resolvedModel
|
||||
|
||||
this.provider = createOpenAI({
|
||||
apiKey: resolvedApiKey,
|
||||
baseURL: normalizedBaseUrl,
|
||||
name: 'openai-compatible',
|
||||
fetch: createCompatFetch(resolvedDisableThinking),
|
||||
})
|
||||
}
|
||||
|
||||
getProvider(): LLMProvider {
|
||||
return 'openai-compatible'
|
||||
return this.providerId
|
||||
}
|
||||
|
||||
getModels(): string[] {
|
||||
return OPENAI_COMPATIBLE_INFO.models.map((m) => m.id)
|
||||
return this.models.map((m) => m.id)
|
||||
}
|
||||
|
||||
getDefaultModel(): string {
|
||||
return 'llama3.2'
|
||||
return this.defaultModel
|
||||
}
|
||||
|
||||
// 统一处理消息映射,保持与旧实现一致
|
||||
private buildMessages(messages: ChatMessage[]): ModelMessage[] {
|
||||
return buildModelMessages(messages)
|
||||
}
|
||||
|
||||
// 统一映射工具调用结果
|
||||
private mapToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
|
||||
return mapToolCalls(toolCalls)
|
||||
}
|
||||
|
||||
// 仅 MiniMax 需要累计去重,避免其他模型缓存全文
|
||||
private shouldTrackStreamText(): boolean {
|
||||
return this.providerId === 'minimax'
|
||||
}
|
||||
|
||||
// MiniMax 流式返回可能是累计文本,按前缀增量去重
|
||||
private getStreamChunkDelta(chunk: string, previousText: string): { delta: string; nextText: string } {
|
||||
if (this.providerId !== 'minimax') {
|
||||
return { delta: chunk, nextText: previousText + chunk }
|
||||
}
|
||||
return dedupeCumulativeStreamChunk(chunk, previousText)
|
||||
}
|
||||
|
||||
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
// 确保 thoughtSignature 被传递(Gemini 3+ 通过 OpenAI 兼容 API 需要)
|
||||
msg.tool_calls = m.tool_calls.map((tc) => ({
|
||||
...tc,
|
||||
// 如果没有签名,使用虚拟签名(用于 Vertex AI/Gemini 后端)
|
||||
thought_signature: tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
|
||||
}))
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
const model = this.provider.chat(this.model)
|
||||
const toolSet = buildToolSet(options?.tools)
|
||||
|
||||
const result = await generateText({
|
||||
model,
|
||||
messages: this.buildMessages(messages),
|
||||
tools: toolSet,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: false,
|
||||
}
|
||||
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
// 禁用思考模式(用于 Qwen3、DeepSeek-R1 等本地模型)
|
||||
if (this.disableThinking) {
|
||||
requestBody.chat_template_kwargs = { enable_thinking: false }
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
this.setAuthHeaders(headers)
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
maxTokens: options?.maxTokens ?? 2048,
|
||||
abortSignal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`OpenAI Compatible API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
const choice = data.choices?.[0]
|
||||
const message = choice?.message
|
||||
|
||||
let finishReason: ChatResponse['finishReason'] = 'error'
|
||||
if (choice?.finish_reason === 'stop') {
|
||||
finishReason = 'stop'
|
||||
} else if (choice?.finish_reason === 'length') {
|
||||
finishReason = 'length'
|
||||
} else if (choice?.finish_reason === 'tool_calls') {
|
||||
finishReason = 'tool_calls'
|
||||
}
|
||||
|
||||
let toolCalls: ToolCall[] | undefined
|
||||
if (message?.tool_calls && Array.isArray(message.tool_calls)) {
|
||||
toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
|
||||
id: tc.id as string,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: (tc.function as Record<string, unknown>)?.name as string,
|
||||
arguments: (tc.function as Record<string, unknown>)?.arguments as string,
|
||||
},
|
||||
// 提取 thoughtSignature(Gemini 3+ 通过 OpenAI 兼容 API 可能返回此字段)
|
||||
thoughtSignature: (tc.thought_signature || tc.thoughtSignature) as string | undefined,
|
||||
}))
|
||||
}
|
||||
const toolCalls = result.toolCalls.length > 0 ? this.mapToolCalls(result.toolCalls) : undefined
|
||||
|
||||
return {
|
||||
content: message?.content || '',
|
||||
finishReason,
|
||||
content: result.text,
|
||||
finishReason: mapFinishReason(result.finishReason),
|
||||
tool_calls: toolCalls,
|
||||
usage: data.usage
|
||||
? {
|
||||
promptTokens: data.usage.prompt_tokens,
|
||||
completionTokens: data.usage.completion_tokens,
|
||||
totalTokens: data.usage.total_tokens,
|
||||
}
|
||||
: undefined,
|
||||
usage: mapUsage(result.usage),
|
||||
}
|
||||
}
|
||||
|
||||
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
// 确保 thoughtSignature 被传递(Gemini 3+ 通过 OpenAI 兼容 API 需要)
|
||||
msg.tool_calls = m.tool_calls.map((tc) => ({
|
||||
...tc,
|
||||
// 如果没有签名,使用虚拟签名(用于 Vertex AI/Gemini 后端)
|
||||
thought_signature: tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
|
||||
}))
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
const model = this.provider.chat(this.model)
|
||||
const toolSet = buildToolSet(options?.tools)
|
||||
const shouldTrack = this.shouldTrackStreamText()
|
||||
let streamedText = ''
|
||||
|
||||
const result = streamText({
|
||||
model,
|
||||
messages: this.buildMessages(messages),
|
||||
tools: toolSet,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: true,
|
||||
// 启用流式响应中的 usage 统计(OpenAI API 兼容)
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
// 禁用思考模式(用于 Qwen3、DeepSeek-R1 等本地模型)
|
||||
if (this.disableThinking) {
|
||||
requestBody.chat_template_kwargs = { enable_thinking: false }
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
this.setAuthHeaders(headers)
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
maxTokens: options?.maxTokens ?? 2048,
|
||||
abortSignal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`OpenAI Compatible API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get response reader')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
|
||||
|
||||
let totalChunks = 0
|
||||
let totalContent = ''
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// 检查是否已中止
|
||||
for await (const chunk of result.textStream) {
|
||||
if (options?.abortSignal?.aborted) {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue
|
||||
|
||||
const data = trimmed.slice(6)
|
||||
|
||||
if (data === '[DONE]') {
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
// 传递 thoughtSignature(如果存在)
|
||||
thoughtSignature: tc.thoughtSignature,
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
}
|
||||
return
|
||||
if (chunk) {
|
||||
const { delta, nextText } = this.getStreamChunkDelta(chunk, streamedText)
|
||||
if (shouldTrack) {
|
||||
streamedText = nextText
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
const delta = parsed.choices?.[0]?.delta
|
||||
const finishReason = parsed.choices?.[0]?.finish_reason
|
||||
|
||||
// 调试:如果有 delta 但没有 content,记录其他可能的内容字段(只写日志文件,不输出控制台)
|
||||
if (delta && !delta.content && !delta.tool_calls && !finishReason) {
|
||||
const deltaKeys = Object.keys(delta)
|
||||
if (deltaKeys.length > 0 && !deltaKeys.every((k) => ['role', 'name', 'audio_content'].includes(k))) {
|
||||
aiLogger.debug('OpenAI-Compatible', '检测到未处理的 delta 字段', { deltaKeys, delta })
|
||||
}
|
||||
}
|
||||
|
||||
if (delta?.content) {
|
||||
totalChunks++
|
||||
totalContent += delta.content
|
||||
yield {
|
||||
content: delta.content,
|
||||
isFinished: false,
|
||||
}
|
||||
}
|
||||
|
||||
if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const index = tc.index ?? 0
|
||||
const existing = toolCallsAccumulator.get(index)
|
||||
if (existing) {
|
||||
if (tc.function?.arguments) {
|
||||
existing.arguments += tc.function.arguments
|
||||
}
|
||||
// 更新 thoughtSignature(如果存在)
|
||||
if (tc.thought_signature || tc.thoughtSignature) {
|
||||
existing.thoughtSignature = tc.thought_signature || tc.thoughtSignature
|
||||
}
|
||||
} else {
|
||||
toolCallsAccumulator.set(index, {
|
||||
id: tc.id || '',
|
||||
name: tc.function?.name || '',
|
||||
arguments: tc.function?.arguments || '',
|
||||
// 提取 thoughtSignature(Gemini 3+ 通过 OpenAI 兼容 API 可能返回此字段)
|
||||
// 如果 API 不返回,使用 Gemini 文档提供的虚拟签名绕过验证
|
||||
thoughtSignature: tc.thought_signature || tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (finishReason) {
|
||||
let reason: ChatStreamChunk['finishReason'] = 'error'
|
||||
if (finishReason === 'stop') {
|
||||
reason = 'stop'
|
||||
} else if (finishReason === 'length') {
|
||||
reason = 'length'
|
||||
} else if (finishReason === 'tool_calls') {
|
||||
reason = 'tool_calls'
|
||||
}
|
||||
|
||||
// 解析 usage 信息
|
||||
const usage = parsed.usage
|
||||
? {
|
||||
promptTokens: parsed.usage.prompt_tokens,
|
||||
completionTokens: parsed.usage.completion_tokens,
|
||||
totalTokens: parsed.usage.total_tokens,
|
||||
}
|
||||
: undefined
|
||||
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
// 传递 thoughtSignature(如果存在)
|
||||
thoughtSignature: tc.thoughtSignature,
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: reason, usage }
|
||||
}
|
||||
return
|
||||
}
|
||||
} catch {
|
||||
// 忽略解析错误,继续处理下一行
|
||||
if (delta) {
|
||||
yield { content: delta, isFinished: false }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流读取完成后的处理(如果没有收到 [DONE] 或 finish_reason)
|
||||
// 这种情况可能发生在某些 API 不发送标准结束标记时
|
||||
aiLogger.info('OpenAI-Compatible', '流循环结束,执行兜底处理', {
|
||||
totalChunks,
|
||||
totalContentLength: totalContent.length,
|
||||
toolCallsCount: toolCallsAccumulator.size,
|
||||
bufferRemaining: buffer.length,
|
||||
})
|
||||
const finishReason = mapFinishReason(await result.finishReason)
|
||||
const toolCalls = await result.toolCalls
|
||||
const usage = mapUsage(await result.totalUsage)
|
||||
|
||||
// 如果有累积的 tool_calls,发送它们
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
// 传递 thoughtSignature(如果存在)
|
||||
thoughtSignature: tc.thoughtSignature,
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
|
||||
} else {
|
||||
// 没有 tool_calls,发送普通完成信号
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
yield {
|
||||
content: '',
|
||||
isFinished: true,
|
||||
finishReason,
|
||||
tool_calls: toolCalls.length > 0 ? this.mapToolCalls(toolCalls) : undefined,
|
||||
usage,
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果是中止错误,正常返回
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
aiLogger.error('OpenAI-Compatible', '流处理异常', { error: String(error) })
|
||||
// 使用 providerId 作为日志前缀,便于区分兼容服务来源
|
||||
aiLogger.error(this.providerId, '流式请求失败', { error: String(error) })
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -395,12 +303,11 @@ export class OpenAICompatibleService implements ILLMService {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
this.setAuthHeaders(headers)
|
||||
if (this.apiKey && this.apiKey !== 'sk-no-key-required') {
|
||||
headers['Authorization'] = `Bearer ${this.apiKey}`
|
||||
}
|
||||
|
||||
const url = `${this.baseUrl}/chat/completions`
|
||||
|
||||
// 发送一个简单的测试请求来验证连接和认证
|
||||
const response = await fetch(url, {
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
@@ -410,16 +317,14 @@ export class OpenAICompatibleService implements ILLMService {
|
||||
}),
|
||||
})
|
||||
|
||||
// 200 表示成功,401/403 表示认证失败,其他状态可能是参数问题但服务可达
|
||||
if (response.ok) {
|
||||
return { success: true }
|
||||
}
|
||||
|
||||
// 尝试获取错误详情
|
||||
const errorText = await response.text()
|
||||
let errorMessage = `HTTP ${response.status}`
|
||||
try {
|
||||
const errorJson = JSON.parse(errorText)
|
||||
const errorJson = JSON.parse(errorText) as { error?: { message?: string }; message?: string }
|
||||
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
|
||||
} catch {
|
||||
if (errorText) {
|
||||
@@ -427,13 +332,10 @@ export class OpenAICompatibleService implements ILLMService {
|
||||
}
|
||||
}
|
||||
|
||||
// 认证失败
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
return { success: false, error: errorMessage }
|
||||
}
|
||||
|
||||
// 其他错误(如 400 参数错误)但服务可达,认为验证通过
|
||||
// 因为这说明认证成功了,只是请求参数有问题
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
|
||||
@@ -1,319 +0,0 @@
|
||||
/**
|
||||
* 通义千问 (Qwen) LLM Provider
|
||||
* 使用阿里云 DashScope 兼容 OpenAI 格式的 API,支持 Function Calling
|
||||
*/
|
||||
|
||||
import type {
|
||||
ILLMService,
|
||||
LLMProvider,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatStreamChunk,
|
||||
ProviderInfo,
|
||||
ToolCall,
|
||||
} from './types'
|
||||
|
||||
const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
|
||||
|
||||
const MODELS = [
|
||||
{ id: 'qwen-turbo', name: 'Qwen Turbo', description: '通义千问超大规模语言模型,速度快' },
|
||||
{ id: 'qwen-plus', name: 'Qwen Plus', description: '通义千问超大规模语言模型,效果好' },
|
||||
{ id: 'qwen-max', name: 'Qwen Max', description: '通义千问千亿级别超大规模语言模型' },
|
||||
]
|
||||
|
||||
export const QWEN_INFO: ProviderInfo = {
|
||||
id: 'qwen',
|
||||
name: '通义千问',
|
||||
description: '阿里云通义千问大语言模型',
|
||||
defaultBaseUrl: DEFAULT_BASE_URL,
|
||||
models: MODELS,
|
||||
}
|
||||
|
||||
export class QwenService implements ILLMService {
|
||||
private apiKey: string
|
||||
private baseUrl: string
|
||||
private model: string
|
||||
|
||||
constructor(apiKey: string, model?: string, baseUrl?: string) {
|
||||
this.apiKey = apiKey
|
||||
this.baseUrl = baseUrl || DEFAULT_BASE_URL
|
||||
this.model = model || 'qwen-turbo'
|
||||
}
|
||||
|
||||
getProvider(): LLMProvider {
|
||||
return 'qwen'
|
||||
}
|
||||
|
||||
getModels(): string[] {
|
||||
return MODELS.map((m) => m.id)
|
||||
}
|
||||
|
||||
getDefaultModel(): string {
|
||||
return 'qwen-turbo'
|
||||
}
|
||||
|
||||
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
|
||||
// 构建请求体
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
msg.tool_calls = m.tool_calls
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: false,
|
||||
}
|
||||
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`Qwen API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
const choice = data.choices?.[0]
|
||||
const message = choice?.message
|
||||
|
||||
// 解析 finish_reason
|
||||
let finishReason: ChatResponse['finishReason'] = 'error'
|
||||
if (choice?.finish_reason === 'stop') {
|
||||
finishReason = 'stop'
|
||||
} else if (choice?.finish_reason === 'length') {
|
||||
finishReason = 'length'
|
||||
} else if (choice?.finish_reason === 'tool_calls') {
|
||||
finishReason = 'tool_calls'
|
||||
}
|
||||
|
||||
// 解析 tool_calls
|
||||
let toolCalls: ToolCall[] | undefined
|
||||
if (message?.tool_calls && Array.isArray(message.tool_calls)) {
|
||||
toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
|
||||
id: tc.id as string,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: (tc.function as Record<string, unknown>)?.name as string,
|
||||
arguments: (tc.function as Record<string, unknown>)?.arguments as string,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
return {
|
||||
content: message?.content || '',
|
||||
finishReason,
|
||||
tool_calls: toolCalls,
|
||||
usage: data.usage
|
||||
? {
|
||||
promptTokens: data.usage.prompt_tokens,
|
||||
completionTokens: data.usage.completion_tokens,
|
||||
totalTokens: data.usage.total_tokens,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
|
||||
// 构建请求体
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages: messages.map((m) => {
|
||||
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
||||
if (m.role === 'tool' && m.tool_call_id) {
|
||||
msg.tool_call_id = m.tool_call_id
|
||||
}
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
msg.tool_calls = m.tool_calls
|
||||
}
|
||||
return msg
|
||||
}),
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens ?? 2048,
|
||||
stream: true,
|
||||
// 启用流式响应中的 usage 统计
|
||||
stream_options: { include_usage: true },
|
||||
}
|
||||
|
||||
if (options?.tools && options.tools.length > 0) {
|
||||
requestBody.tools = options.tools
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: options?.abortSignal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text()
|
||||
throw new Error(`Qwen API error: ${response.status} - ${error}`)
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('Failed to get response reader')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// 检查是否已中止
|
||||
if (options?.abortSignal?.aborted) {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed || !trimmed.startsWith('data: ')) continue
|
||||
|
||||
const data = trimmed.slice(6)
|
||||
if (data === '[DONE]') {
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: { name: tc.name, arguments: tc.arguments },
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data)
|
||||
const delta = parsed.choices?.[0]?.delta
|
||||
const finishReason = parsed.choices?.[0]?.finish_reason
|
||||
|
||||
if (delta?.content) {
|
||||
yield { content: delta.content, isFinished: false }
|
||||
}
|
||||
|
||||
// 处理流式 tool_calls
|
||||
if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const index = tc.index ?? 0
|
||||
const existing = toolCallsAccumulator.get(index)
|
||||
if (existing) {
|
||||
if (tc.function?.arguments) {
|
||||
existing.arguments += tc.function.arguments
|
||||
}
|
||||
} else {
|
||||
toolCallsAccumulator.set(index, {
|
||||
id: tc.id || '',
|
||||
name: tc.function?.name || '',
|
||||
arguments: tc.function?.arguments || '',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (finishReason) {
|
||||
let reason: ChatStreamChunk['finishReason'] = 'error'
|
||||
if (finishReason === 'stop') reason = 'stop'
|
||||
else if (finishReason === 'length') reason = 'length'
|
||||
else if (finishReason === 'tool_calls') reason = 'tool_calls'
|
||||
|
||||
// 解析 usage 信息
|
||||
const usage = parsed.usage
|
||||
? {
|
||||
promptTokens: parsed.usage.prompt_tokens,
|
||||
completionTokens: parsed.usage.completion_tokens,
|
||||
totalTokens: parsed.usage.total_tokens,
|
||||
}
|
||||
: undefined
|
||||
|
||||
if (toolCallsAccumulator.size > 0) {
|
||||
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
||||
id: tc.id,
|
||||
type: 'function' as const,
|
||||
function: { name: tc.name, arguments: tc.arguments },
|
||||
}))
|
||||
yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
|
||||
} else {
|
||||
yield { content: '', isFinished: true, finishReason: reason, usage }
|
||||
}
|
||||
return
|
||||
}
|
||||
} catch {
|
||||
// 忽略解析错误,继续处理下一行
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果是中止错误,正常返回
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
yield { content: '', isFinished: true, finishReason: 'stop' }
|
||||
return
|
||||
}
|
||||
throw error
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
|
||||
try {
|
||||
// 发送一个简单请求验证 API Key
|
||||
const response = await fetch(`${this.baseUrl}/models`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
},
|
||||
})
|
||||
if (response.ok) {
|
||||
return { success: true }
|
||||
}
|
||||
// 尝试获取错误详情
|
||||
const errorText = await response.text()
|
||||
let errorMessage = `HTTP ${response.status}`
|
||||
try {
|
||||
const errorJson = JSON.parse(errorText)
|
||||
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
|
||||
} catch {
|
||||
if (errorText) {
|
||||
errorMessage = errorText.slice(0, 200)
|
||||
}
|
||||
}
|
||||
return { success: false, error: errorMessage }
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
return { success: false, error: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
/**
|
||||
* AI SDK 适配通用工具
|
||||
*/
|
||||
|
||||
import { jsonSchema } from 'ai'
|
||||
import type {
|
||||
ContentPart,
|
||||
FinishReason,
|
||||
LanguageModelUsage,
|
||||
ModelMessage,
|
||||
ToolSet,
|
||||
TypedToolCall,
|
||||
JSONValue,
|
||||
} from 'ai'
|
||||
import type { ChatMessage, ChatResponse, ToolCall, ToolDefinition } from './types'
|
||||
|
||||
// Fallback tool name used when a tool message's tool_call_id can't be resolved
const UNKNOWN_TOOL_NAME = 'unknown_tool'
|
||||
|
||||
/**
|
||||
* 将 OpenAI 风格的工具定义转换为 AI SDK ToolSet
|
||||
*/
|
||||
export function buildToolSet(tools?: ToolDefinition[]): ToolSet | undefined {
|
||||
if (!tools || tools.length === 0) return undefined
|
||||
|
||||
const toolSet: ToolSet = {}
|
||||
for (const tool of tools) {
|
||||
toolSet[tool.function.name] = {
|
||||
description: tool.function.description,
|
||||
inputSchema: jsonSchema(tool.function.parameters),
|
||||
}
|
||||
}
|
||||
|
||||
return toolSet
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析 JSON 字符串,失败返回 null
|
||||
*/
|
||||
export function safeParseJson(value: string): JSONValue | null {
|
||||
try {
|
||||
return JSON.parse(value) as JSONValue
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 AI SDK finishReason 映射为现有的 LLM finishReason
|
||||
*/
|
||||
export function mapFinishReason(reason: FinishReason): ChatResponse['finishReason'] {
|
||||
if (reason === 'stop') return 'stop'
|
||||
if (reason === 'length') return 'length'
|
||||
if (reason === 'tool-calls') return 'tool_calls'
|
||||
return 'error'
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 AI SDK usage 映射为现有的 usage 结构
|
||||
*/
|
||||
export function mapUsage(usage?: LanguageModelUsage): ChatResponse['usage'] | undefined {
|
||||
if (!usage) return undefined
|
||||
|
||||
const promptTokens = usage.inputTokens
|
||||
const completionTokens = usage.outputTokens
|
||||
const totalTokens = usage.totalTokens
|
||||
|
||||
if (promptTokens == null && completionTokens == null && totalTokens == null) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
promptTokens: promptTokens ?? 0,
|
||||
completionTokens: completionTokens ?? 0,
|
||||
totalTokens: totalTokens ?? 0,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 AI SDK toolCalls 映射为现有 ToolCall 结构
|
||||
*/
|
||||
export function mapToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
|
||||
return toolCalls.map((tc) => ({
|
||||
id: tc.toolCallId,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tc.toolName,
|
||||
arguments: JSON.stringify(tc.input ?? {}),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* 将现有 ChatMessage 转为 AI SDK 的 ModelMessage
|
||||
*/
|
||||
export function buildModelMessages(messages: ChatMessage[]): ModelMessage[] {
|
||||
const toolCallNameMap = new Map<string, string>()
|
||||
|
||||
// 先建立 tool_call_id 到 toolName 的映射,供 tool 消息使用
|
||||
for (const message of messages) {
|
||||
if (message.role !== 'assistant' || !message.tool_calls) continue
|
||||
for (const toolCall of message.tool_calls) {
|
||||
toolCallNameMap.set(toolCall.id, toolCall.function.name)
|
||||
}
|
||||
}
|
||||
|
||||
return messages.map((message) => {
|
||||
if (message.role === 'assistant') {
|
||||
if (message.tool_calls && message.tool_calls.length > 0) {
|
||||
const contentParts: Array<ContentPart<ToolSet>> = []
|
||||
|
||||
if (message.content) {
|
||||
contentParts.push({ type: 'text', text: message.content })
|
||||
}
|
||||
|
||||
for (const toolCall of message.tool_calls) {
|
||||
const input = safeParseJson(toolCall.function.arguments || '{}') ?? {}
|
||||
contentParts.push({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.function.name,
|
||||
input,
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
role: 'assistant',
|
||||
content: contentParts,
|
||||
}
|
||||
}
|
||||
|
||||
return { role: 'assistant', content: message.content }
|
||||
}
|
||||
|
||||
if (message.role === 'tool') {
|
||||
const toolCallId = message.tool_call_id || ''
|
||||
const toolName = toolCallNameMap.get(toolCallId) || UNKNOWN_TOOL_NAME
|
||||
const parsed = safeParseJson(message.content)
|
||||
|
||||
return {
|
||||
role: 'tool',
|
||||
content: [
|
||||
{
|
||||
type: 'tool-result',
|
||||
toolCallId,
|
||||
toolName,
|
||||
// 工具结果只允许 JSON/文本,保持与旧实现一致
|
||||
output: parsed ? { type: 'json', value: parsed } : { type: 'text', value: message.content },
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
}
|
||||
})
|
||||
}
|
||||
+99
-3
@@ -11,6 +11,98 @@ import type { IpcContext } from './types'
|
||||
// 用于跟踪活跃的 Agent 请求,支持中止操作
|
||||
const activeAgentRequests = new Map<string, AbortController>()
|
||||
|
||||
/**
|
||||
* 格式化 AI 报错信息,输出更友好的提示
|
||||
*/
|
||||
function formatAIError(error: unknown): string {
|
||||
const candidates: unknown[] = []
|
||||
if (error) {
|
||||
candidates.push(error)
|
||||
}
|
||||
|
||||
const errorObj = error as {
|
||||
lastError?: unknown
|
||||
errors?: unknown[]
|
||||
}
|
||||
|
||||
if (errorObj?.lastError) {
|
||||
candidates.push(errorObj.lastError)
|
||||
}
|
||||
|
||||
if (Array.isArray(errorObj?.errors)) {
|
||||
candidates.push(...errorObj.errors)
|
||||
}
|
||||
|
||||
let rawMessage = ''
|
||||
let statusCode: number | undefined
|
||||
let retrySeconds: number | undefined
|
||||
|
||||
for (const candidate of candidates) {
|
||||
if (!candidate || typeof candidate !== 'object') {
|
||||
if (!rawMessage && typeof candidate === 'string') {
|
||||
rawMessage = candidate
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
const record = candidate as Record<string, unknown>
|
||||
if (typeof record.statusCode === 'number') {
|
||||
statusCode = record.statusCode
|
||||
}
|
||||
|
||||
if (!rawMessage && typeof record.message === 'string') {
|
||||
rawMessage = record.message
|
||||
}
|
||||
|
||||
if (!rawMessage && record.data && typeof record.data === 'object') {
|
||||
const data = record.data as { error?: { message?: string } }
|
||||
if (data.error?.message) {
|
||||
rawMessage = data.error.message
|
||||
}
|
||||
}
|
||||
|
||||
if (record.responseBody && typeof record.responseBody === 'string') {
|
||||
const responseBody = record.responseBody
|
||||
try {
|
||||
const parsed = JSON.parse(responseBody) as { error?: { message?: string } }
|
||||
if (!rawMessage && parsed.error?.message) {
|
||||
rawMessage = parsed.error.message
|
||||
}
|
||||
} catch {
|
||||
if (!rawMessage) {
|
||||
rawMessage = responseBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (rawMessage) {
|
||||
const retryMatch = rawMessage.match(/retry in ([0-9.]+)s/i)
|
||||
if (retryMatch) {
|
||||
retrySeconds = Math.ceil(Number(retryMatch[1]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fallbackMessage = rawMessage || String(error)
|
||||
const lowerMessage = fallbackMessage.toLowerCase()
|
||||
|
||||
if (statusCode === 429 || lowerMessage.includes('quota') || lowerMessage.includes('resource_exhausted')) {
|
||||
return retrySeconds
|
||||
? `Gemini 配额已用尽,请等待 ${retrySeconds} 秒后重试,或更换/升级配额。`
|
||||
: 'Gemini 配额已用尽,请稍后重试或更换/升级配额。'
|
||||
}
|
||||
|
||||
if (statusCode === 503 || lowerMessage.includes('overloaded') || lowerMessage.includes('unavailable')) {
|
||||
return 'Gemini 模型繁忙,请稍后重试。'
|
||||
}
|
||||
|
||||
if (fallbackMessage.length > 300) {
|
||||
return `${fallbackMessage.slice(0, 300)}...`
|
||||
}
|
||||
|
||||
return fallbackMessage
|
||||
}
|
||||
|
||||
export function registerAIHandlers({ win }: IpcContext): void {
|
||||
console.log('[IPC] Registering AI handlers...')
|
||||
|
||||
@@ -464,11 +556,15 @@ export function registerAIHandlers({ win }: IpcContext): void {
|
||||
aiLogger.info('IPC', `Agent 请求已中止: ${requestId}`)
|
||||
return
|
||||
}
|
||||
aiLogger.error('IPC', `Agent 执行出错: ${requestId}`, { error: String(error) })
|
||||
const friendlyError = formatAIError(error)
|
||||
aiLogger.error('IPC', `Agent 执行出错: ${requestId}`, {
|
||||
error: String(error),
|
||||
friendlyError,
|
||||
})
|
||||
// 发送错误 chunk
|
||||
win.webContents.send('agent:streamChunk', {
|
||||
requestId,
|
||||
chunk: { type: 'error', error: String(error), isFinished: true },
|
||||
chunk: { type: 'error', error: friendlyError, isFinished: true },
|
||||
})
|
||||
// 发送完成事件(带错误信息),确保前端 promise 能 resolve
|
||||
win.webContents.send('agent:complete', {
|
||||
@@ -478,7 +574,7 @@ export function registerAIHandlers({ win }: IpcContext): void {
|
||||
toolsUsed: [],
|
||||
toolRounds: 0,
|
||||
totalUsage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
|
||||
error: String(error),
|
||||
error: friendlyError,
|
||||
},
|
||||
})
|
||||
} finally {
|
||||
|
||||
Reference in New Issue
Block a user