feat: 接入AI sdk

This commit is contained in:
digua
2026-01-21 00:18:59 +08:00
parent fc76602604
commit 8f3f3e62f2
15 changed files with 960 additions and 1438 deletions
+26
View File
@@ -96,6 +96,8 @@ type StreamMode = 'text' | 'think' | 'tool_call'
function createStreamParser(handlers: {
onText: (text: string) => void
onThink: (text: string, tag: string) => void
onThinkStart?: (tag: string) => void
onThinkEnd?: (tag: string) => void
}): { push: (text: string) => void; flush: () => void } {
let buffer = ''
let mode: StreamMode = 'text'
@@ -165,6 +167,7 @@ function createStreamParser(handlers: {
// 进入思考模式
currentThinkTag = hit.tag.slice(1, -1)
mode = 'think'
handlers.onThinkStart?.(currentThinkTag)
buffer = buffer.slice(startTags[startTagsLower.indexOf(hit.tag)].length)
continue
}
@@ -194,6 +197,7 @@ function createStreamParser(handlers: {
buffer = buffer.slice(endIndex + endTag.length)
mode = 'text'
handlers.onThinkEnd?.(currentThinkTag)
currentThinkTag = ''
continue
}
@@ -265,6 +269,8 @@ export interface AgentStreamChunk {
content?: string
/** 思考标签名称(type=think 时) */
thinkTag?: string
/** 思考耗时(毫秒,type=think 时可选) */
thinkDurationMs?: number
/** 工具名称(type=tool_start/tool_result 时) */
toolName?: string
/** 工具调用参数(type=tool_start 时) */
@@ -680,14 +686,24 @@ export class Agent {
let accumulatedContent = ''
let roundContent = ''
let toolCalls: ToolCall[] | undefined
let thinkStartAt: number | null = null // 记录思考开始时间
const parser = createStreamParser({
onText: (text) => {
roundContent += text
onChunk({ type: 'content', content: text })
},
onThinkStart: () => {
thinkStartAt = Date.now()
},
onThink: (text, tag) => {
onChunk({ type: 'think', content: text, thinkTag: tag })
},
onThinkEnd: (tag) => {
if (thinkStartAt === null) return
const durationMs = Date.now() - thinkStartAt
thinkStartAt = null
onChunk({ type: 'think', content: '', thinkTag: tag, thinkDurationMs: durationMs })
},
})
// 流式调用 LLM(传入 abortSignal
@@ -851,14 +867,24 @@ export class Agent {
// 最后一轮不带 tools(传入 abortSignal
let finalRawContent = ''
let finalThinkStartAt: number | null = null // 记录最终思考开始时间
const finalParser = createStreamParser({
onText: (text) => {
finalContent += text
onChunk({ type: 'content', content: text })
},
onThinkStart: () => {
finalThinkStartAt = Date.now()
},
onThink: (text, tag) => {
onChunk({ type: 'think', content: text, thinkTag: tag })
},
onThinkEnd: (tag) => {
if (finalThinkStartAt === null) return
const durationMs = Date.now() - finalThinkStartAt
finalThinkStartAt = null
onChunk({ type: 'think', content: '', thinkTag: tag, thinkDurationMs: durationMs })
},
})
for await (const chunk of chatStream(this.messages, {
...this.config.llmOptions,
+1 -1
View File
@@ -117,7 +117,7 @@ export interface AIConversation {
*/
export type ContentBlock =
| { type: 'text'; text: string }
| { type: 'think'; tag: string; text: string } // 思考内容块
| { type: 'think'; tag: string; text: string; durationMs?: number } // 思考内容块
| {
type: 'tool'
tool: {
-341
View File
@@ -1,341 +0,0 @@
/**
* DeepSeek LLM Provider
* 使用 OpenAI 兼容的 API 格式,支持 Function Calling
*/
import type {
ILLMService,
LLMProvider,
ChatMessage,
ChatOptions,
ChatResponse,
ChatStreamChunk,
ProviderInfo,
ToolCall,
} from './types'
// Default API origin; the /v1 path segment is appended at request time.
const DEFAULT_BASE_URL = 'https://api.deepseek.com'

// Models offered by this provider; `id` is sent verbatim as the `model` field.
const MODELS = [
  { id: 'deepseek-chat', name: 'DeepSeek Chat', description: '通用对话模型' },
  { id: 'deepseek-coder', name: 'DeepSeek Coder', description: '代码生成模型' },
]

// Static provider descriptor for DeepSeek (id, display name, default URL, model list).
export const DEEPSEEK_INFO: ProviderInfo = {
  id: 'deepseek',
  name: 'DeepSeek',
  description: 'DeepSeek AI 大语言模型',
  defaultBaseUrl: DEFAULT_BASE_URL,
  models: MODELS,
}
/**
 * DeepSeek LLM service using the OpenAI-compatible chat-completions API.
 *
 * Supports blocking chat, SSE streaming (with incremental tool_call
 * accumulation), Function Calling, and API-key validation.
 */
export class DeepSeekService implements ILLMService {
  private apiKey: string
  private baseUrl: string
  private model: string

  /**
   * @param apiKey  DeepSeek API key (sent as a Bearer token)
   * @param model   model id; defaults to 'deepseek-chat'
   * @param baseUrl API origin; defaults to https://api.deepseek.com
   */
  constructor(apiKey: string, model?: string, baseUrl?: string) {
    this.apiKey = apiKey
    this.baseUrl = baseUrl || DEFAULT_BASE_URL
    this.model = model || 'deepseek-chat'
  }

  /** Provider identifier used by the service registry. */
  getProvider(): LLMProvider {
    return 'deepseek'
  }

  /** Ids of the models this provider exposes. */
  getModels(): string[] {
    return MODELS.map((m) => m.id)
  }

  /** Model used when the caller does not specify one. */
  getDefaultModel(): string {
    return 'deepseek-chat'
  }

  /**
   * Blocking (non-streaming) chat completion.
   *
   * @param messages conversation history in OpenAI message format
   * @param options  temperature / maxTokens / tools / abortSignal
   * @returns normalized response with content, finishReason, tool_calls and usage
   * @throws Error when the HTTP response is not OK
   */
  async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
    // Build the request body (OpenAI-compatible schema)
    const requestBody: Record<string, unknown> = {
      model: this.model,
      messages: messages.map((m) => {
        const msg: Record<string, unknown> = { role: m.role, content: m.content }
        // Tool-result messages must echo the originating tool_call_id
        if (m.role === 'tool' && m.tool_call_id) {
          msg.tool_call_id = m.tool_call_id
        }
        // Assistant messages may carry the tool_calls they previously issued
        if (m.role === 'assistant' && m.tool_calls) {
          msg.tool_calls = m.tool_calls
        }
        return msg
      }),
      temperature: options?.temperature ?? 0.7,
      max_tokens: options?.maxTokens ?? 2048,
      stream: false,
    }

    // Attach tool definitions when provided
    if (options?.tools && options.tools.length > 0) {
      requestBody.tools = options.tools
    }

    const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(requestBody),
      signal: options?.abortSignal,
    })

    if (!response.ok) {
      const error = await response.text()
      throw new Error(`DeepSeek API error: ${response.status} - ${error}`)
    }

    const data = await response.json()
    const choice = data.choices?.[0]
    const message = choice?.message

    // Map the upstream finish_reason onto the normalized enum;
    // anything unrecognized is reported as 'error'.
    let finishReason: ChatResponse['finishReason'] = 'error'
    if (choice?.finish_reason === 'stop') {
      finishReason = 'stop'
    } else if (choice?.finish_reason === 'length') {
      finishReason = 'length'
    } else if (choice?.finish_reason === 'tool_calls') {
      finishReason = 'tool_calls'
    }

    // Convert upstream tool_calls into the internal ToolCall shape
    let toolCalls: ToolCall[] | undefined
    if (message?.tool_calls && Array.isArray(message.tool_calls)) {
      toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
        id: tc.id as string,
        type: 'function' as const,
        function: {
          name: (tc.function as Record<string, unknown>)?.name as string,
          arguments: (tc.function as Record<string, unknown>)?.arguments as string,
        },
      }))
    }

    return {
      content: message?.content || '',
      finishReason,
      tool_calls: toolCalls,
      usage: data.usage
        ? {
            promptTokens: data.usage.prompt_tokens,
            completionTokens: data.usage.completion_tokens,
            totalTokens: data.usage.total_tokens,
          }
        : undefined,
    }
  }

  /**
   * Streaming chat completion over SSE.
   *
   * Yields incremental text chunks, accumulates partial tool_calls across
   * frames (DeepSeek streams `arguments` incrementally, keyed by `index`),
   * and finishes with a single terminal chunk carrying finishReason,
   * accumulated tool_calls and usage. Abort is treated as a normal stop.
   *
   * @param messages conversation history in OpenAI message format
   * @param options  temperature / maxTokens / tools / abortSignal
   * @throws Error when the HTTP response is not OK or the body has no reader
   */
  async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
    // Build the request body (OpenAI-compatible schema, streaming enabled)
    const requestBody: Record<string, unknown> = {
      model: this.model,
      messages: messages.map((m) => {
        const msg: Record<string, unknown> = { role: m.role, content: m.content }
        if (m.role === 'tool' && m.tool_call_id) {
          msg.tool_call_id = m.tool_call_id
        }
        if (m.role === 'assistant' && m.tool_calls) {
          msg.tool_calls = m.tool_calls
        }
        return msg
      }),
      temperature: options?.temperature ?? 0.7,
      max_tokens: options?.maxTokens ?? 2048,
      stream: true,
      // Ask the API to include token usage in the stream
      stream_options: { include_usage: true },
    }

    if (options?.tools && options.tools.length > 0) {
      requestBody.tools = options.tools
    }

    const response = await fetch(`${this.baseUrl}/v1/chat/completions`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(requestBody),
      signal: options?.abortSignal,
    })

    if (!response.ok) {
      const error = await response.text()
      throw new Error(`DeepSeek API error: ${response.status} - ${error}`)
    }

    const reader = response.body?.getReader()
    if (!reader) {
      throw new Error('Failed to get response reader')
    }

    const decoder = new TextDecoder()
    let buffer = ''
    // Accumulates streamed tool_calls; keyed by the `index` field so that
    // incremental `arguments` fragments are appended to the right call.
    const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()

    try {
      while (true) {
        // Bail out early if the caller aborted
        if (options?.abortSignal?.aborted) {
          yield { content: '', isFinished: true, finishReason: 'stop' }
          return
        }
        const { done, value } = await reader.read()
        if (done) break
        buffer += decoder.decode(value, { stream: true })
        // Split into complete lines; keep any trailing partial line in the buffer
        const lines = buffer.split('\n')
        buffer = lines.pop() || ''
        for (const line of lines) {
          const trimmed = line.trim()
          if (!trimmed || !trimmed.startsWith('data: ')) continue
          const data = trimmed.slice(6)
          if (data === '[DONE]') {
            // Stream ended: flush any accumulated tool_calls
            if (toolCallsAccumulator.size > 0) {
              const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
                id: tc.id,
                type: 'function' as const,
                function: {
                  name: tc.name,
                  arguments: tc.arguments,
                },
              }))
              yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
            } else {
              yield { content: '', isFinished: true, finishReason: 'stop' }
            }
            return
          }
          try {
            const parsed = JSON.parse(data)
            const delta = parsed.choices?.[0]?.delta
            const finishReason = parsed.choices?.[0]?.finish_reason
            // Incremental text content
            if (delta?.content) {
              yield {
                content: delta.content,
                isFinished: false,
              }
            }
            // Streamed tool_calls (incremental accumulation)
            if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
              for (const tc of delta.tool_calls) {
                const index = tc.index ?? 0
                const existing = toolCallsAccumulator.get(index)
                if (existing) {
                  // Append incremental arguments to the existing call
                  if (tc.function?.arguments) {
                    existing.arguments += tc.function.arguments
                  }
                } else {
                  // First frame of a new tool_call
                  toolCallsAccumulator.set(index, {
                    id: tc.id || '',
                    name: tc.function?.name || '',
                    arguments: tc.function?.arguments || '',
                  })
                }
              }
            }
            if (finishReason) {
              let reason: ChatStreamChunk['finishReason'] = 'error'
              if (finishReason === 'stop') {
                reason = 'stop'
              } else if (finishReason === 'length') {
                reason = 'length'
              } else if (finishReason === 'tool_calls') {
                reason = 'tool_calls'
              }
              // Token usage (present because stream_options.include_usage is set)
              const usage = parsed.usage
                ? {
                    promptTokens: parsed.usage.prompt_tokens,
                    completionTokens: parsed.usage.completion_tokens,
                    totalTokens: parsed.usage.total_tokens,
                  }
                : undefined
              // Emit the terminal chunk, with tool_calls when any were accumulated
              if (toolCallsAccumulator.size > 0) {
                const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
                  id: tc.id,
                  type: 'function' as const,
                  function: {
                    name: tc.name,
                    arguments: tc.arguments,
                  },
                }))
                yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
              } else {
                yield { content: '', isFinished: true, finishReason: reason, usage }
              }
              return
            }
          } catch {
            // Ignore malformed SSE lines and continue with the next one
          }
        }
      }
    } catch (error) {
      // Abort is an expected termination, not a failure
      if (error instanceof Error && error.name === 'AbortError') {
        yield { content: '', isFinished: true, finishReason: 'stop' }
        return
      }
      throw error
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Validates the configured API key by listing models.
   *
   * Never throws: failures (HTTP or network) are reported via the returned
   * `{ success, error }` object with a short error message.
   */
  async validateApiKey(): Promise<{ success: boolean; error?: string }> {
    try {
      // Lightweight authenticated request to verify the key
      const response = await fetch(`${this.baseUrl}/v1/models`, {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
        },
      })
      if (response.ok) {
        return { success: true }
      }
      // Try to surface a useful error message from the body
      const errorText = await response.text()
      let errorMessage = `HTTP ${response.status}`
      try {
        const errorJson = JSON.parse(errorText)
        errorMessage = errorJson.error?.message || errorJson.message || errorMessage
      } catch {
        if (errorText) {
          errorMessage = errorText.slice(0, 200)
        }
      }
      return { success: false, error: errorMessage }
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error)
      return { success: false, error: errorMessage }
    }
  }
}
+159 -423
View File
@@ -1,8 +1,11 @@
/**
* Google Gemini LLM Provider
* 使用 Gemini REST API 格式,支持 Function Calling
* 使用 AI SDK 的 Google Generative AI Provider,支持 Function Calling
*/
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { generateText, streamText } from 'ai'
import type { ContentPart, ModelMessage, ToolSet, TypedToolCall } from 'ai'
import type {
ILLMService,
LLMProvider,
@@ -10,13 +13,15 @@ import type {
ChatOptions,
ChatResponse,
ChatStreamChunk,
ProviderInfo,
ToolCall,
ToolDefinition,
ProviderInfo,
} from './types'
import { aiLogger } from '../logger'
import { buildModelMessages, buildToolSet, mapFinishReason, mapUsage } from './sdkUtils'
const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com'
const DEFAULT_API_VERSION = '/v1beta'
const GEMINI_MAX_RETRIES = 1
const MODELS = [
{ id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: '高速预览版' },
@@ -31,85 +36,133 @@ export const GEMINI_INFO: ProviderInfo = {
models: MODELS,
}
// ==================== Gemini API 类型定义 ====================
/**
* 统一处理 Gemini 的 baseUrl,确保包含 /v1beta
*/
function normalizeBaseUrl(baseUrl?: string): string {
let normalized = (baseUrl || DEFAULT_BASE_URL).replace(/\/+$/, '')
/** Gemini 消息 part(支持多种类型) */
interface GeminiPart {
text?: string
functionCall?: {
name: string
args: Record<string, unknown>
if (normalized.endsWith(DEFAULT_API_VERSION)) {
return normalized
}
functionResponse?: {
name: string
response: unknown
}
/** Gemini 3+ 模型的思考签名 */
thoughtSignature?: string
return `${normalized}${DEFAULT_API_VERSION}`
}
/** Gemini 消息内容 */
interface GeminiContent {
role: 'user' | 'model'
parts: GeminiPart[]
// Provider-metadata keys under which Gemini backends may nest their data.
const GEMINI_PROVIDER_KEYS = ['google', 'vertex', 'gemini']

/**
 * Extracts a Gemini `thoughtSignature` from an AI SDK `providerMetadata`
 * value. Checks each known provider namespace first, then falls back to a
 * top-level `thoughtSignature` field. Returns undefined when the metadata
 * is absent, not an object, or carries no non-empty signature string.
 */
function getThoughtSignature(metadata?: unknown): string | undefined {
  if (typeof metadata !== 'object' || metadata === null) {
    return undefined
  }
  const record = metadata as Record<string, unknown>
  // Accept only non-empty strings as a valid signature.
  const asSignature = (candidate: unknown): string | undefined =>
    typeof candidate === 'string' && candidate.length > 0 ? candidate : undefined
  for (const providerKey of GEMINI_PROVIDER_KEYS) {
    const providerEntry = record[providerKey]
    if (providerEntry !== null && typeof providerEntry === 'object') {
      const nested = asSignature((providerEntry as Record<string, unknown>).thoughtSignature)
      if (nested !== undefined) {
        return nested
      }
    }
  }
  // Fallback: some backends put the signature directly on the metadata root.
  return asSignature(record.thoughtSignature)
}
/** Gemini 函数声明(对应 OpenAI 的 ToolDefinition */
interface GeminiFunctionDeclaration {
name: string
description: string
parameters: {
type: string
properties: Record<string, unknown>
required?: string[]
}
/**
 * Converts AI SDK tool calls into the internal ToolCall shape, carrying
 * along the Gemini `thoughtSignature` extracted from provider metadata
 * (Gemini 3+ requires it when the call is replayed back to the model).
 */
function mapGeminiToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
  const mapped: ToolCall[] = []
  for (const call of toolCalls) {
    // `input` may be absent; serialize an empty object in that case.
    const serializedArgs = JSON.stringify(call.input ?? {})
    mapped.push({
      id: call.toolCallId,
      type: 'function' as const,
      function: {
        name: call.toolName,
        arguments: serializedArgs,
      },
      thoughtSignature: getThoughtSignature(call.providerMetadata),
    })
  }
  return mapped
}
/** Gemini 请求体 */
interface GeminiRequest {
contents: GeminiContent[]
generationConfig?: {
temperature?: number
maxOutputTokens?: number
}
systemInstruction?: {
parts: Array<{ text: string }>
}
tools?: Array<{
functionDeclarations: GeminiFunctionDeclaration[]
}>
}
/**
* 构建 Gemini 的模型消息,补充 thoughtSignature
*/
function buildGeminiModelMessages(messages: ChatMessage[]): ModelMessage[] {
const modelMessages = buildModelMessages(messages)
const signatureMap = new Map<string, string>()
/** Gemini 响应候选项 */
interface GeminiCandidate {
content?: {
parts?: GeminiPart[]
role?: string
for (const message of messages) {
if (message.role !== 'assistant' || !message.tool_calls) continue
for (const toolCall of message.tool_calls) {
if (toolCall.thoughtSignature) {
signatureMap.set(toolCall.id, toolCall.thoughtSignature)
}
}
}
finishReason?: string
}
/** Gemini API 响应 */
interface GeminiResponse {
candidates?: GeminiCandidate[]
usageMetadata?: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
if (signatureMap.size === 0) {
return modelMessages
}
}
// ==================== GeminiService 类 ====================
return modelMessages.map((message) => {
if (message.role !== 'assistant' || !Array.isArray(message.content)) {
return message
}
const contentParts = message.content.map((part) => {
if (part.type !== 'tool-call') {
return part
}
const signature = signatureMap.get(part.toolCallId)
if (!signature) {
return part
}
const nextPart: ContentPart<ToolSet> & {
providerOptions?: Record<string, { thoughtSignature?: string }>
} = {
...part,
providerOptions: {
google: { thoughtSignature: signature },
},
}
return nextPart
})
return {
...message,
content: contentParts,
}
})
}
export class GeminiService implements ILLMService {
private apiKey: string
private baseUrl: string
private model: string
private provider = createGoogleGenerativeAI()
constructor(apiKey: string, model?: string, baseUrl?: string) {
this.apiKey = apiKey
this.baseUrl = baseUrl || DEFAULT_BASE_URL
this.baseUrl = normalizeBaseUrl(baseUrl)
this.model = model || 'gemini-3-flash-preview'
this.provider = createGoogleGenerativeAI({
apiKey: this.apiKey,
baseURL: this.baseUrl,
name: 'gemini',
})
}
getProvider(): LLMProvider {
@@ -124,404 +177,87 @@ export class GeminiService implements ILLMService {
return 'gemini-3-flash-preview'
}
/**
* 将 OpenAI 格式的 tools 转换为 Gemini 格式
*/
private convertTools(tools?: ToolDefinition[]): Array<{ functionDeclarations: GeminiFunctionDeclaration[] }> | undefined {
if (!tools || tools.length === 0) return undefined
const functionDeclarations: GeminiFunctionDeclaration[] = tools.map((tool) => ({
name: tool.function.name,
description: tool.function.description,
parameters: tool.function.parameters,
}))
return [{ functionDeclarations }]
}
/**
* 将 OpenAI 格式消息转换为 Gemini 格式
*/
private convertMessages(messages: ChatMessage[]): {
contents: GeminiContent[]
systemInstruction?: { parts: Array<{ text: string }> }
} {
const contents: GeminiContent[] = []
let systemInstruction: { parts: Array<{ text: string }> } | undefined
for (const msg of messages) {
if (msg.role === 'system') {
// Gemini 使用 systemInstruction 处理系统提示
systemInstruction = {
parts: [{ text: msg.content }],
}
} else if (msg.role === 'user') {
contents.push({
role: 'user',
parts: [{ text: msg.content }],
})
} else if (msg.role === 'assistant') {
// 处理 assistant 消息(可能包含 tool_calls
if (msg.tool_calls && msg.tool_calls.length > 0) {
// 有工具调用的 assistant 消息
const parts: GeminiPart[] = []
if (msg.content) {
parts.push({ text: msg.content })
}
for (const tc of msg.tool_calls) {
const part: GeminiPart = {
functionCall: {
name: tc.function.name,
args: JSON.parse(tc.function.arguments),
},
}
// Gemini 3+ 需要包含 thoughtSignature
if (tc.thoughtSignature) {
part.thoughtSignature = tc.thoughtSignature
}
parts.push(part)
}
contents.push({ role: 'model', parts })
} else {
// 普通文本消息
contents.push({
role: 'model',
parts: [{ text: msg.content }],
})
}
} else if (msg.role === 'tool') {
// 工具结果消息 - 在 Gemini 中作为 user 角色的 functionResponse
// 注意:需要从消息内容解析工具名称和结果
// tool_call_id 格式通常是 "call_xxx",我们需要从上下文获取工具名
// 这里简化处理:假设内容是 JSON 格式的结果
try {
const result = JSON.parse(msg.content)
// 尝试从上一条 assistant 消息中找到对应的 tool_call
// 由于 Gemini 需要 name,我们从 tool_call_id 推断或使用默认值
contents.push({
role: 'user',
parts: [
{
functionResponse: {
name: msg.tool_call_id?.replace('call_', '') || 'unknown',
response: result,
},
},
],
})
} catch {
// 如果不是 JSON,直接作为文本结果
contents.push({
role: 'user',
parts: [
{
functionResponse: {
name: msg.tool_call_id?.replace('call_', '') || 'unknown',
response: { result: msg.content },
},
},
],
})
}
}
}
return { contents, systemInstruction }
}
/**
* 构建 API URL
*/
private buildUrl(stream: boolean): string {
const action = stream ? 'streamGenerateContent' : 'generateContent'
const base = this.baseUrl.replace(/\/$/, '')
return `${base}/v1beta/models/${this.model}:${action}?key=${this.apiKey}`
}
/**
* 从 Gemini parts 中提取工具调用
*/
private extractToolCalls(parts?: GeminiPart[]): ToolCall[] | undefined {
if (!parts) return undefined
const toolCalls: ToolCall[] = []
for (const part of parts) {
if (part.functionCall) {
toolCalls.push({
id: `call_${part.functionCall.name}_${Date.now()}`,
type: 'function',
function: {
name: part.functionCall.name,
arguments: JSON.stringify(part.functionCall.args),
},
// 保存 Gemini 3+ 的思考签名
thoughtSignature: part.thoughtSignature,
})
}
}
return toolCalls.length > 0 ? toolCalls : undefined
}
/**
* 从 Gemini parts 中提取文本内容
*/
private extractText(parts?: GeminiPart[]): string {
if (!parts) return ''
return parts
.filter((p) => p.text)
.map((p) => p.text)
.join('')
}
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
const { contents, systemInstruction } = this.convertMessages(messages)
const model = this.provider.chat(this.model)
const toolSet = buildToolSet(options?.tools)
const requestBody: GeminiRequest = {
contents,
generationConfig: {
temperature: options?.temperature ?? 0.7,
maxOutputTokens: options?.maxTokens ?? 2048,
},
}
if (systemInstruction) {
requestBody.systemInstruction = systemInstruction
}
// 添加工具定义
const geminiTools = this.convertTools(options?.tools)
if (geminiTools) {
requestBody.tools = geminiTools
}
const response = await fetch(this.buildUrl(false), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
const result = await generateText({
model,
messages: buildGeminiModelMessages(messages),
tools: toolSet,
temperature: options?.temperature ?? 0.7,
maxTokens: options?.maxTokens ?? 2048,
// 降低 Gemini 自动重试次数,避免触发免费配额限制
maxRetries: GEMINI_MAX_RETRIES,
abortSignal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
throw new Error(`Gemini API error: ${response.status} - ${error}`)
}
const data: GeminiResponse = await response.json()
const candidate = data.candidates?.[0]
const parts = candidate?.content?.parts
const content = this.extractText(parts)
const toolCalls = this.extractToolCalls(parts)
// 解析 finish_reason
let finishReason: ChatResponse['finishReason'] = 'error'
const reason = candidate?.finishReason
if (reason === 'STOP') {
finishReason = toolCalls ? 'tool_calls' : 'stop'
} else if (reason === 'MAX_TOKENS') {
finishReason = 'length'
}
const toolCalls = result.toolCalls.length > 0 ? mapGeminiToolCalls(result.toolCalls) : undefined
return {
content,
finishReason,
content: result.text,
finishReason: mapFinishReason(result.finishReason),
tool_calls: toolCalls,
usage: data.usageMetadata
? {
promptTokens: data.usageMetadata.promptTokenCount || 0,
completionTokens: data.usageMetadata.candidatesTokenCount || 0,
totalTokens: data.usageMetadata.totalTokenCount || 0,
}
: undefined,
usage: mapUsage(result.usage),
}
}
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
const { contents, systemInstruction } = this.convertMessages(messages)
const model = this.provider.chat(this.model)
const toolSet = buildToolSet(options?.tools)
const requestBody: GeminiRequest = {
contents,
generationConfig: {
temperature: options?.temperature ?? 0.7,
maxOutputTokens: options?.maxTokens ?? 2048,
},
}
if (systemInstruction) {
requestBody.systemInstruction = systemInstruction
}
// 添加工具定义
const geminiTools = this.convertTools(options?.tools)
if (geminiTools) {
requestBody.tools = geminiTools
}
// Gemini 流式需要添加 alt=sse 参数
const url = this.buildUrl(true) + '&alt=sse'
aiLogger.info('Gemini', '开始流式请求', {
url: url.replace(/key=[^&]+/, 'key=***'),
model: this.model,
messagesCount: contents.length,
hasSystemInstruction: !!systemInstruction,
hasTools: !!geminiTools,
const result = streamText({
model,
messages: buildGeminiModelMessages(messages),
tools: toolSet,
temperature: options?.temperature ?? 0.7,
maxTokens: options?.maxTokens ?? 2048,
// 降低 Gemini 自动重试次数,避免触发免费配额限制
maxRetries: GEMINI_MAX_RETRIES,
abortSignal: options?.abortSignal,
})
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
aiLogger.error('Gemini', 'API 请求失败', { status: response.status, error: error.slice(0, 500) })
throw new Error(`Gemini API error: ${response.status} - ${error}`)
}
aiLogger.info('Gemini', 'API 响应成功,开始读取流')
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get response reader')
}
const decoder = new TextDecoder()
let buffer = ''
// 用于累积工具调用(可能跨多个 chunk)
const toolCallsAccumulator: ToolCall[] = []
// 用于追踪最新的 usage 信息
let latestUsage: { promptTokens: number; completionTokens: number; totalTokens: number } | undefined
try {
while (true) {
// 检查是否已中止
for await (const chunk of result.textStream) {
if (options?.abortSignal?.aborted) {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
buffer = lines.pop() || ''
for (const line of lines) {
const trimmed = line.trim()
if (!trimmed || !trimmed.startsWith('data: ')) continue
const data = trimmed.slice(6)
if (data === '[DONE]') {
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator }
} else {
yield { content: '', isFinished: true, finishReason: 'stop' }
}
return
}
try {
const parsed: GeminiResponse = JSON.parse(data)
const candidate = parsed.candidates?.[0]
const parts = candidate?.content?.parts
// 处理文本内容
const text = this.extractText(parts)
if (text) {
yield { content: text, isFinished: false }
}
// 处理工具调用
const toolCalls = this.extractToolCalls(parts)
if (toolCalls) {
aiLogger.info('Gemini', '检测到工具调用', { toolCalls: toolCalls.map((tc) => tc.function.name) })
toolCallsAccumulator.push(...toolCalls)
}
// 更新 usage 信息
if (parsed.usageMetadata) {
latestUsage = {
promptTokens: parsed.usageMetadata.promptTokenCount || 0,
completionTokens: parsed.usageMetadata.candidatesTokenCount || 0,
totalTokens: parsed.usageMetadata.totalTokenCount || 0,
}
}
// 检查是否完成
const finishReason = candidate?.finishReason
if (finishReason) {
aiLogger.info('Gemini', '流式响应完成', { finishReason, toolCallsCount: toolCallsAccumulator.length, usage: latestUsage })
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator, usage: latestUsage }
} else {
let reason: ChatStreamChunk['finishReason'] = 'stop'
if (finishReason === 'MAX_TOKENS') {
reason = 'length'
}
yield { content: '', isFinished: true, finishReason: reason, usage: latestUsage }
}
return
}
} catch (e) {
// 记录解析错误
aiLogger.warn('Gemini', 'SSE 数据解析失败', { data: data.slice(0, 200), error: String(e) })
}
if (chunk) {
yield { content: chunk, isFinished: false }
}
}
// 如果循环正常结束,发送完成信号
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator, usage: latestUsage }
} else {
yield { content: '', isFinished: true, finishReason: 'stop', usage: latestUsage }
const finishReason = mapFinishReason(await result.finishReason)
const toolCalls = await result.toolCalls
const usage = mapUsage(await result.totalUsage)
yield {
content: '',
isFinished: true,
finishReason,
tool_calls: toolCalls.length > 0 ? mapGeminiToolCalls(toolCalls) : undefined,
usage,
}
} catch (error) {
// 如果是中止错误,正常返回
if (error instanceof Error && error.name === 'AbortError') {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
aiLogger.error('Gemini', '流式请求失败', { error: String(error) })
throw error
} finally {
reader.releaseLock()
}
}
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
try {
// 使用 models.list API 验证 API Key
const url = `${this.baseUrl.replace(/\/$/, '')}/v1beta/models?key=${this.apiKey}`
const response = await fetch(url, {
method: 'GET',
const model = this.provider.chat(this.model)
await generateText({
model,
prompt: 'Hi',
maxTokens: 1,
})
if (response.ok) {
return { success: true }
}
// 尝试获取错误详情
const errorText = await response.text()
let errorMessage = `HTTP ${response.status}`
try {
const errorJson = JSON.parse(errorText)
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
} catch {
if (errorText) {
errorMessage = errorText.slice(0, 200)
}
}
return { success: false, error: errorMessage }
return { success: true }
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
return { success: false, error: errorMessage }
+80 -10
View File
@@ -6,7 +6,7 @@
import * as fs from 'fs'
import * as path from 'path'
import { randomUUID } from 'crypto'
import { getAiDataDir, ensureDir } from '../../paths'
import { getAiDataDir } from '../../paths'
import type {
LLMConfig,
LLMProvider,
@@ -19,8 +19,6 @@ import type {
AIConfigStore,
} from './types'
import { MAX_CONFIG_COUNT } from './types'
import { DeepSeekService, DEEPSEEK_INFO } from './deepseek'
import { QwenService, QWEN_INFO } from './qwen'
import { GeminiService, GEMINI_INFO } from './gemini'
import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
import { aiLogger } from '../logger'
@@ -30,6 +28,31 @@ export * from './types'
// ==================== 新增提供商信息 ====================
/** DeepSeek 提供商信息 */
const DEEPSEEK_INFO: ProviderInfo = {
id: 'deepseek',
name: 'DeepSeek',
description: 'DeepSeek AI 大语言模型',
defaultBaseUrl: 'https://api.deepseek.com/v1',
models: [
{ id: 'deepseek-chat', name: 'DeepSeek Chat', description: '通用对话模型' },
{ id: 'deepseek-coder', name: 'DeepSeek Coder', description: '代码生成模型' },
],
}
/** 通义千问 (Qwen) 提供商信息 */
const QWEN_INFO: ProviderInfo = {
id: 'qwen',
name: '通义千问',
description: '阿里云通义千问大语言模型',
defaultBaseUrl: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
models: [
{ id: 'qwen-turbo', name: 'Qwen Turbo', description: '通义千问超大规模语言模型,速度快' },
{ id: 'qwen-plus', name: 'Qwen Plus', description: '通义千问超大规模语言模型,效果好' },
{ id: 'qwen-max', name: 'Qwen Max', description: '通义千问千亿级别超大规模语言模型' },
],
}
/** MiniMax 提供商信息 */
const MINIMAX_INFO: ProviderInfo = {
id: 'minimax',
@@ -324,6 +347,36 @@ interface ExtendedLLMConfig extends LLMConfig {
disableThinking?: boolean
}
/**
 * Validates a user-supplied Base URL for providers whose URL shape is known.
 * The URL is intentionally NOT auto-corrected: a malformed value raises an
 * explicit error so the user fixes the configuration themselves.
 *
 * @param provider provider id the URL belongs to
 * @param baseUrl  user-configured Base URL; skipped when empty/undefined
 * @throws Error with a provider-specific message when the URL is malformed
 */
function validateProviderBaseUrl(provider: LLMProvider, baseUrl?: string): void {
  // No explicit URL configured — the provider default applies, nothing to check.
  if (!baseUrl) return

  const normalized = baseUrl.replace(/\/+$/, '')

  // DeepSeek and Qwen both expose OpenAI-compatible endpoints rooted at /v1.
  const requireV1Root = (label: string): void => {
    if (normalized.endsWith('/chat/completions')) {
      throw new Error(`${label} Base URL 请填写到 /v1 层级,不要包含 /chat/completions`)
    }
    if (!normalized.endsWith('/v1')) {
      throw new Error(`${label} Base URL 需要以 /v1 结尾`)
    }
  }

  if (provider === 'deepseek') {
    requireV1Root('DeepSeek')
  }
  if (provider === 'qwen') {
    requireV1Root('通义千问')
    // DashScope only serves the OpenAI-compatible API under /compatible-mode.
    if (normalized.includes('dashscope.aliyuncs.com') && !normalized.includes('/compatible-mode/')) {
      throw new Error('通义千问 Base URL 需要包含 /compatible-mode/v1')
    }
  }
}
/**
* 创建 LLM 服务实例
*/
@@ -331,22 +384,39 @@ export function createLLMService(config: ExtendedLLMConfig): ILLMService {
// 获取提供商的默认 baseUrl
const providerInfo = getProviderInfo(config.provider)
const baseUrl = config.baseUrl || providerInfo?.defaultBaseUrl
// 未显式指定时使用提供商的首个模型作为默认模型
const resolvedModel = config.model || providerInfo?.models?.[0]?.id
// 不自动补齐,发现不合法直接抛错给用户
validateProviderBaseUrl(config.provider, baseUrl)
switch (config.provider) {
case 'deepseek':
return new DeepSeekService(config.apiKey, config.model, config.baseUrl)
case 'qwen':
return new QwenService(config.apiKey, config.model, config.baseUrl)
case 'gemini':
return new GeminiService(config.apiKey, config.model, config.baseUrl)
return new GeminiService(config.apiKey, resolvedModel, config.baseUrl)
// 新增的官方API都使用 OpenAI 兼容格式
case 'deepseek':
case 'qwen':
case 'minimax':
case 'glm':
case 'kimi':
case 'doubao':
return new OpenAICompatibleService(config.apiKey, config.model, baseUrl)
// DeepSeek/Qwen 走 OpenAI 兼容实现时,禁用本地思考注入
return new OpenAICompatibleService(
config.apiKey,
resolvedModel,
baseUrl,
config.provider === 'deepseek' || config.provider === 'qwen' ? false : undefined,
config.provider,
providerInfo?.models
)
case 'openai-compatible':
return new OpenAICompatibleService(config.apiKey, config.model, config.baseUrl, config.disableThinking)
return new OpenAICompatibleService(
config.apiKey,
resolvedModel,
config.baseUrl,
config.disableThinking,
config.provider,
providerInfo?.models
)
default:
throw new Error(`Unknown LLM provider: ${config.provider}`)
}
+223 -321
View File
@@ -3,6 +3,9 @@
* 支持任何兼容 OpenAI API 格式的服务(如 Ollama、LocalAI、vLLM 等)
*/
import { createOpenAI } from '@ai-sdk/openai'
import { generateText, streamText } from 'ai'
import type { ModelMessage, ToolSet, TypedToolCall } from 'ai'
import type {
ILLMService,
LLMProvider,
@@ -10,12 +13,14 @@ import type {
ChatOptions,
ChatResponse,
ChatStreamChunk,
ProviderInfo,
ToolCall,
ProviderInfo,
} from './types'
import { aiLogger } from '../logger'
import { buildModelMessages, buildToolSet, mapFinishReason, mapToolCalls, mapUsage } from './sdkUtils'
const DEFAULT_BASE_URL = 'http://localhost:11434/v1'
const DEFAULT_THOUGHT_SIGNATURE = 'context_engineering_is_the_way_to_go'
export const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
id: 'openai-compatible',
@@ -29,364 +34,267 @@ export const OPENAI_COMPATIBLE_INFO: ProviderInfo = {
],
}
/**
 * Normalize a base URL: drop trailing slashes and a trailing
 * `/chat/completions` segment so the service can append paths consistently.
 */
function normalizeBaseUrl(baseUrl?: string): string {
  const stripped = (baseUrl || DEFAULT_BASE_URL).replace(/\/+$/, '')
  const suffix = '/chat/completions'
  return stripped.endsWith(suffix) ? stripped.slice(0, -suffix.length) : stripped
}
/**
 * Some providers (e.g. MiniMax) may stream cumulative full text instead of
 * deltas. Compare each frame against what was already emitted and derive the
 * true incremental part.
 */
function dedupeCumulativeStreamChunk(
  chunk: string,
  previousText: string
): { delta: string; nextText: string } {
  const frame = (delta: string, nextText: string) => ({ delta, nextText })
  // Nothing emitted yet: the whole chunk is new.
  if (!previousText) return frame(chunk, chunk)
  // Cumulative frame: emit only the new suffix.
  if (chunk.startsWith(previousText)) return frame(chunk.slice(previousText.length), chunk)
  // Occasional regressed/duplicate frame: keep what was already emitted.
  if (previousText.startsWith(chunk)) return frame('', previousText)
  // Cannot be identified as cumulative: fall back to plain incremental append.
  return frame(chunk, previousText + chunk)
}
/**
 * Wrap the global fetch to rewrite OpenAI-compatible request bodies in flight:
 * injects a default `thought_signature` on assistant tool calls (required by
 * Gemini-compatible backends) and, when `disableThinking` is set, forces
 * `chat_template_kwargs.enable_thinking = false` (used for local models).
 * Requests whose body cannot be parsed as JSON are forwarded untouched.
 */
function createCompatFetch(disableThinking: boolean): typeof fetch {
  return async (input, init) => {
    // Only JSON string bodies can be rewritten; pass everything else through.
    if (!init?.body || typeof init.body !== 'string') {
      return fetch(input, init)
    }
    let parsedBody: Record<string, unknown> | null = null
    try {
      parsedBody = JSON.parse(init.body) as Record<string, unknown>
    } catch {
      // Not JSON — forward the request unchanged.
      return fetch(input, init)
    }
    if (!parsedBody) {
      return fetch(input, init)
    }
    let changed = false
    if (Array.isArray(parsedBody.messages)) {
      const messages = parsedBody.messages as Array<Record<string, unknown>>
      // Add a thought_signature on assistant tool calls for Gemini-compatible
      // backends; existing signatures (either field spelling) are kept as-is.
      for (const message of messages) {
        if (
          message &&
          typeof message === 'object' &&
          (message as { role?: string }).role === 'assistant' &&
          Array.isArray((message as { tool_calls?: unknown[] }).tool_calls)
        ) {
          const toolCalls = (message as { tool_calls: Array<Record<string, unknown>> }).tool_calls
          for (const toolCall of toolCalls) {
            const typedCall = toolCall as Record<string, unknown> & {
              thought_signature?: string
              thoughtSignature?: string
            }
            if (!typedCall.thought_signature && !typedCall.thoughtSignature) {
              typedCall.thought_signature = DEFAULT_THOUGHT_SIGNATURE
              changed = true
            }
          }
        }
      }
      // Disable thinking mode (for local models).
      // NOTE(review): this runs only when the body has a `messages` array, so
      // non-chat bodies never get thinking disabled — confirm that is intended.
      // An explicit enable_thinking === true set by the caller is preserved.
      if (disableThinking) {
        const chatTemplate = parsedBody.chat_template_kwargs
        if (!chatTemplate || typeof chatTemplate !== 'object') {
          parsedBody.chat_template_kwargs = { enable_thinking: false }
          changed = true
        } else if (!(chatTemplate as { enable_thinking?: boolean }).enable_thinking) {
          parsedBody.chat_template_kwargs = {
            ...(chatTemplate as Record<string, unknown>),
            enable_thinking: false,
          }
          changed = true
        }
      }
    }
    if (!changed) {
      return fetch(input, init)
    }
    // Re-serialize the mutated body; every other init field is preserved.
    const nextInit: RequestInit = {
      ...init,
      body: JSON.stringify(parsedBody),
    }
    return fetch(input, nextInit)
  }
}
export class OpenAICompatibleService implements ILLMService {
private apiKey: string
private baseUrl: string
private model: string
private disableThinking: boolean
private providerId: LLMProvider
private models: ProviderInfo['models']
private defaultModel: string
private provider = createOpenAI()
constructor(apiKey: string, model?: string, baseUrl?: string, disableThinking?: boolean) {
this.apiKey = apiKey || 'sk-no-key-required' // 本地服务可能不需要 API Key
// 智能处理 baseUrl:如果用户已经包含 /chat/completions,则去掉它
let processedBaseUrl = baseUrl || DEFAULT_BASE_URL
processedBaseUrl = processedBaseUrl.replace(/\/+$/, '') // 去掉尾部斜杠
if (processedBaseUrl.endsWith('/chat/completions')) {
processedBaseUrl = processedBaseUrl.slice(0, -'/chat/completions'.length)
}
this.baseUrl = processedBaseUrl
this.model = model || 'llama3.2'
this.disableThinking = disableThinking ?? true // 默认禁用思考模式
}
constructor(
apiKey: string,
model?: string,
baseUrl?: string,
disableThinking?: boolean,
providerId?: LLMProvider,
models?: ProviderInfo['models']
) {
const normalizedBaseUrl = normalizeBaseUrl(baseUrl)
const resolvedApiKey = apiKey || 'sk-no-key-required'
const resolvedDisableThinking = disableThinking ?? true
const resolvedProviderId = providerId ?? 'openai-compatible'
const resolvedModels = models && models.length > 0 ? models : OPENAI_COMPATIBLE_INFO.models
const defaultModel = resolvedModels[0]?.id || 'llama3.2'
const resolvedModel = model || defaultModel
/**
* 设置 Bearer Token 认证头
*/
private setAuthHeaders(headers: Record<string, string>): void {
if (this.apiKey && this.apiKey !== 'sk-no-key-required') {
headers['Authorization'] = `Bearer ${this.apiKey}`
}
this.apiKey = resolvedApiKey
this.baseUrl = normalizedBaseUrl
this.providerId = resolvedProviderId
this.models = resolvedModels
this.defaultModel = defaultModel
this.model = resolvedModel
this.provider = createOpenAI({
apiKey: resolvedApiKey,
baseURL: normalizedBaseUrl,
name: 'openai-compatible',
fetch: createCompatFetch(resolvedDisableThinking),
})
}
getProvider(): LLMProvider {
return 'openai-compatible'
return this.providerId
}
getModels(): string[] {
return OPENAI_COMPATIBLE_INFO.models.map((m) => m.id)
return this.models.map((m) => m.id)
}
getDefaultModel(): string {
return 'llama3.2'
return this.defaultModel
}
// 统一处理消息映射,保持与旧实现一致
private buildMessages(messages: ChatMessage[]): ModelMessage[] {
return buildModelMessages(messages)
}
// 统一映射工具调用结果
private mapToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
return mapToolCalls(toolCalls)
}
// 仅 MiniMax 需要累计去重,避免其他模型缓存全文
private shouldTrackStreamText(): boolean {
return this.providerId === 'minimax'
}
// MiniMax 流式返回可能是累计文本,按前缀增量去重
private getStreamChunkDelta(chunk: string, previousText: string): { delta: string; nextText: string } {
if (this.providerId !== 'minimax') {
return { delta: chunk, nextText: previousText + chunk }
}
return dedupeCumulativeStreamChunk(chunk, previousText)
}
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
const requestBody: Record<string, unknown> = {
model: this.model,
messages: messages.map((m) => {
const msg: Record<string, unknown> = { role: m.role, content: m.content }
if (m.role === 'tool' && m.tool_call_id) {
msg.tool_call_id = m.tool_call_id
}
if (m.role === 'assistant' && m.tool_calls) {
// 确保 thoughtSignature 被传递(Gemini 3+ 通过 OpenAI 兼容 API 需要)
msg.tool_calls = m.tool_calls.map((tc) => ({
...tc,
// 如果没有签名,使用虚拟签名(用于 Vertex AI/Gemini 后端)
thought_signature: tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
}))
}
return msg
}),
const model = this.provider.chat(this.model)
const toolSet = buildToolSet(options?.tools)
const result = await generateText({
model,
messages: this.buildMessages(messages),
tools: toolSet,
temperature: options?.temperature ?? 0.7,
max_tokens: options?.maxTokens ?? 2048,
stream: false,
}
if (options?.tools && options.tools.length > 0) {
requestBody.tools = options.tools
}
// 禁用思考模式(用于 Qwen3、DeepSeek-R1 等本地模型)
if (this.disableThinking) {
requestBody.chat_template_kwargs = { enable_thinking: false }
}
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
this.setAuthHeaders(headers)
const response = await fetch(`${this.baseUrl}/chat/completions`, {
method: 'POST',
headers,
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
maxTokens: options?.maxTokens ?? 2048,
abortSignal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
throw new Error(`OpenAI Compatible API error: ${response.status} - ${error}`)
}
const data = await response.json()
const choice = data.choices?.[0]
const message = choice?.message
let finishReason: ChatResponse['finishReason'] = 'error'
if (choice?.finish_reason === 'stop') {
finishReason = 'stop'
} else if (choice?.finish_reason === 'length') {
finishReason = 'length'
} else if (choice?.finish_reason === 'tool_calls') {
finishReason = 'tool_calls'
}
let toolCalls: ToolCall[] | undefined
if (message?.tool_calls && Array.isArray(message.tool_calls)) {
toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
id: tc.id as string,
type: 'function' as const,
function: {
name: (tc.function as Record<string, unknown>)?.name as string,
arguments: (tc.function as Record<string, unknown>)?.arguments as string,
},
// 提取 thoughtSignatureGemini 3+ 通过 OpenAI 兼容 API 可能返回此字段)
thoughtSignature: (tc.thought_signature || tc.thoughtSignature) as string | undefined,
}))
}
const toolCalls = result.toolCalls.length > 0 ? this.mapToolCalls(result.toolCalls) : undefined
return {
content: message?.content || '',
finishReason,
content: result.text,
finishReason: mapFinishReason(result.finishReason),
tool_calls: toolCalls,
usage: data.usage
? {
promptTokens: data.usage.prompt_tokens,
completionTokens: data.usage.completion_tokens,
totalTokens: data.usage.total_tokens,
}
: undefined,
usage: mapUsage(result.usage),
}
}
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
const requestBody: Record<string, unknown> = {
model: this.model,
messages: messages.map((m) => {
const msg: Record<string, unknown> = { role: m.role, content: m.content }
if (m.role === 'tool' && m.tool_call_id) {
msg.tool_call_id = m.tool_call_id
}
if (m.role === 'assistant' && m.tool_calls) {
// 确保 thoughtSignature 被传递(Gemini 3+ 通过 OpenAI 兼容 API 需要)
msg.tool_calls = m.tool_calls.map((tc) => ({
...tc,
// 如果没有签名,使用虚拟签名(用于 Vertex AI/Gemini 后端)
thought_signature: tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
}))
}
return msg
}),
const model = this.provider.chat(this.model)
const toolSet = buildToolSet(options?.tools)
const shouldTrack = this.shouldTrackStreamText()
let streamedText = ''
const result = streamText({
model,
messages: this.buildMessages(messages),
tools: toolSet,
temperature: options?.temperature ?? 0.7,
max_tokens: options?.maxTokens ?? 2048,
stream: true,
// 启用流式响应中的 usage 统计(OpenAI API 兼容)
stream_options: { include_usage: true },
}
if (options?.tools && options.tools.length > 0) {
requestBody.tools = options.tools
}
// 禁用思考模式(用于 Qwen3、DeepSeek-R1 等本地模型)
if (this.disableThinking) {
requestBody.chat_template_kwargs = { enable_thinking: false }
}
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
this.setAuthHeaders(headers)
const response = await fetch(`${this.baseUrl}/chat/completions`, {
method: 'POST',
headers,
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
maxTokens: options?.maxTokens ?? 2048,
abortSignal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
throw new Error(`OpenAI Compatible API error: ${response.status} - ${error}`)
}
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get response reader')
}
const decoder = new TextDecoder()
let buffer = ''
const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
let totalChunks = 0
let totalContent = ''
try {
while (true) {
// 检查是否已中止
for await (const chunk of result.textStream) {
if (options?.abortSignal?.aborted) {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
buffer = lines.pop() || ''
for (const line of lines) {
const trimmed = line.trim()
if (!trimmed || !trimmed.startsWith('data: ')) continue
const data = trimmed.slice(6)
if (data === '[DONE]') {
if (toolCallsAccumulator.size > 0) {
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
id: tc.id,
type: 'function' as const,
function: {
name: tc.name,
arguments: tc.arguments,
},
// 传递 thoughtSignature(如果存在)
thoughtSignature: tc.thoughtSignature,
}))
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
} else {
yield { content: '', isFinished: true, finishReason: 'stop' }
}
return
if (chunk) {
const { delta, nextText } = this.getStreamChunkDelta(chunk, streamedText)
if (shouldTrack) {
streamedText = nextText
}
try {
const parsed = JSON.parse(data)
const delta = parsed.choices?.[0]?.delta
const finishReason = parsed.choices?.[0]?.finish_reason
// 调试:如果有 delta 但没有 content,记录其他可能的内容字段(只写日志文件,不输出控制台)
if (delta && !delta.content && !delta.tool_calls && !finishReason) {
const deltaKeys = Object.keys(delta)
if (deltaKeys.length > 0 && !deltaKeys.every((k) => ['role', 'name', 'audio_content'].includes(k))) {
aiLogger.debug('OpenAI-Compatible', '检测到未处理的 delta 字段', { deltaKeys, delta })
}
}
if (delta?.content) {
totalChunks++
totalContent += delta.content
yield {
content: delta.content,
isFinished: false,
}
}
if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
for (const tc of delta.tool_calls) {
const index = tc.index ?? 0
const existing = toolCallsAccumulator.get(index)
if (existing) {
if (tc.function?.arguments) {
existing.arguments += tc.function.arguments
}
// 更新 thoughtSignature(如果存在)
if (tc.thought_signature || tc.thoughtSignature) {
existing.thoughtSignature = tc.thought_signature || tc.thoughtSignature
}
} else {
toolCallsAccumulator.set(index, {
id: tc.id || '',
name: tc.function?.name || '',
arguments: tc.function?.arguments || '',
// 提取 thoughtSignatureGemini 3+ 通过 OpenAI 兼容 API 可能返回此字段)
// 如果 API 不返回,使用 Gemini 文档提供的虚拟签名绕过验证
thoughtSignature: tc.thought_signature || tc.thoughtSignature || 'context_engineering_is_the_way_to_go',
})
}
}
}
if (finishReason) {
let reason: ChatStreamChunk['finishReason'] = 'error'
if (finishReason === 'stop') {
reason = 'stop'
} else if (finishReason === 'length') {
reason = 'length'
} else if (finishReason === 'tool_calls') {
reason = 'tool_calls'
}
// 解析 usage 信息
const usage = parsed.usage
? {
promptTokens: parsed.usage.prompt_tokens,
completionTokens: parsed.usage.completion_tokens,
totalTokens: parsed.usage.total_tokens,
}
: undefined
if (toolCallsAccumulator.size > 0) {
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
id: tc.id,
type: 'function' as const,
function: {
name: tc.name,
arguments: tc.arguments,
},
// 传递 thoughtSignature(如果存在)
thoughtSignature: tc.thoughtSignature,
}))
yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
} else {
yield { content: '', isFinished: true, finishReason: reason, usage }
}
return
}
} catch {
// 忽略解析错误,继续处理下一行
if (delta) {
yield { content: delta, isFinished: false }
}
}
}
// 流读取完成后的处理(如果没有收到 [DONE] 或 finish_reason
// 这种情况可能发生在某些 API 不发送标准结束标记时
aiLogger.info('OpenAI-Compatible', '流循环结束,执行兜底处理', {
totalChunks,
totalContentLength: totalContent.length,
toolCallsCount: toolCallsAccumulator.size,
bufferRemaining: buffer.length,
})
const finishReason = mapFinishReason(await result.finishReason)
const toolCalls = await result.toolCalls
const usage = mapUsage(await result.totalUsage)
// 如果有累积的 tool_calls,发送它们
if (toolCallsAccumulator.size > 0) {
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
id: tc.id,
type: 'function' as const,
function: {
name: tc.name,
arguments: tc.arguments,
},
// 传递 thoughtSignature(如果存在)
thoughtSignature: tc.thoughtSignature,
}))
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
} else {
// 没有 tool_calls,发送普通完成信号
yield { content: '', isFinished: true, finishReason: 'stop' }
yield {
content: '',
isFinished: true,
finishReason,
tool_calls: toolCalls.length > 0 ? this.mapToolCalls(toolCalls) : undefined,
usage,
}
} catch (error) {
// 如果是中止错误,正常返回
if (error instanceof Error && error.name === 'AbortError') {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
aiLogger.error('OpenAI-Compatible', '流处理异常', { error: String(error) })
// 使用 providerId 作为日志前缀,便于区分兼容服务来源
aiLogger.error(this.providerId, '流式请求失败', { error: String(error) })
throw error
} finally {
reader.releaseLock()
}
}
@@ -395,12 +303,11 @@ export class OpenAICompatibleService implements ILLMService {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
this.setAuthHeaders(headers)
if (this.apiKey && this.apiKey !== 'sk-no-key-required') {
headers['Authorization'] = `Bearer ${this.apiKey}`
}
const url = `${this.baseUrl}/chat/completions`
// 发送一个简单的测试请求来验证连接和认证
const response = await fetch(url, {
const response = await fetch(`${this.baseUrl}/chat/completions`, {
method: 'POST',
headers,
body: JSON.stringify({
@@ -410,16 +317,14 @@ export class OpenAICompatibleService implements ILLMService {
}),
})
// 200 表示成功,401/403 表示认证失败,其他状态可能是参数问题但服务可达
if (response.ok) {
return { success: true }
}
// 尝试获取错误详情
const errorText = await response.text()
let errorMessage = `HTTP ${response.status}`
try {
const errorJson = JSON.parse(errorText)
const errorJson = JSON.parse(errorText) as { error?: { message?: string }; message?: string }
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
} catch {
if (errorText) {
@@ -427,13 +332,10 @@ export class OpenAICompatibleService implements ILLMService {
}
}
// 认证失败
if (response.status === 401 || response.status === 403) {
return { success: false, error: errorMessage }
}
// 其他错误(如 400 参数错误)但服务可达,认为验证通过
// 因为这说明认证成功了,只是请求参数有问题
return { success: true }
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
-319
View File
@@ -1,319 +0,0 @@
/**
* 通义千问 (Qwen) LLM Provider
* 使用阿里云 DashScope 兼容 OpenAI 格式的 API,支持 Function Calling
*/
import type {
ILLMService,
LLMProvider,
ChatMessage,
ChatOptions,
ChatResponse,
ChatStreamChunk,
ProviderInfo,
ToolCall,
} from './types'
const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
const MODELS = [
{ id: 'qwen-turbo', name: 'Qwen Turbo', description: '通义千问超大规模语言模型,速度快' },
{ id: 'qwen-plus', name: 'Qwen Plus', description: '通义千问超大规模语言模型,效果好' },
{ id: 'qwen-max', name: 'Qwen Max', description: '通义千问千亿级别超大规模语言模型' },
]
export const QWEN_INFO: ProviderInfo = {
id: 'qwen',
name: '通义千问',
description: '阿里云通义千问大语言模型',
defaultBaseUrl: DEFAULT_BASE_URL,
models: MODELS,
}
/**
 * Qwen (Tongyi Qianwen) chat service backed by Alibaba Cloud DashScope's
 * OpenAI-compatible endpoint. Supports non-streaming chat, SSE streaming
 * and OpenAI-style function calling (tools).
 */
export class QwenService implements ILLMService {
  private apiKey: string
  private baseUrl: string
  private model: string
  constructor(apiKey: string, model?: string, baseUrl?: string) {
    this.apiKey = apiKey
    this.baseUrl = baseUrl || DEFAULT_BASE_URL
    this.model = model || 'qwen-turbo'
  }
  /** Provider identifier for this service. */
  getProvider(): LLMProvider {
    return 'qwen'
  }
  /** IDs of the models this provider exposes. */
  getModels(): string[] {
    return MODELS.map((m) => m.id)
  }
  /** Model used when the caller does not specify one. */
  getDefaultModel(): string {
    return 'qwen-turbo'
  }
  /**
   * Non-streaming chat completion.
   * Maps tool/assistant messages into the OpenAI wire format, forwards
   * optional tools, and translates finish_reason/tool_calls/usage back into
   * the local ChatResponse shape. Throws on any non-2xx HTTP response.
   */
  async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
    // Build the request body
    const requestBody: Record<string, unknown> = {
      model: this.model,
      messages: messages.map((m) => {
        const msg: Record<string, unknown> = { role: m.role, content: m.content }
        if (m.role === 'tool' && m.tool_call_id) {
          msg.tool_call_id = m.tool_call_id
        }
        if (m.role === 'assistant' && m.tool_calls) {
          msg.tool_calls = m.tool_calls
        }
        return msg
      }),
      temperature: options?.temperature ?? 0.7,
      max_tokens: options?.maxTokens ?? 2048,
      stream: false,
    }
    if (options?.tools && options.tools.length > 0) {
      requestBody.tools = options.tools
    }
    const response = await fetch(`${this.baseUrl}/chat/completions`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(requestBody),
      signal: options?.abortSignal,
    })
    if (!response.ok) {
      const error = await response.text()
      throw new Error(`Qwen API error: ${response.status} - ${error}`)
    }
    const data = await response.json()
    const choice = data.choices?.[0]
    const message = choice?.message
    // Parse finish_reason; anything unrecognized is reported as 'error'
    let finishReason: ChatResponse['finishReason'] = 'error'
    if (choice?.finish_reason === 'stop') {
      finishReason = 'stop'
    } else if (choice?.finish_reason === 'length') {
      finishReason = 'length'
    } else if (choice?.finish_reason === 'tool_calls') {
      finishReason = 'tool_calls'
    }
    // Parse tool_calls returned by the model, if any
    let toolCalls: ToolCall[] | undefined
    if (message?.tool_calls && Array.isArray(message.tool_calls)) {
      toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
        id: tc.id as string,
        type: 'function' as const,
        function: {
          name: (tc.function as Record<string, unknown>)?.name as string,
          arguments: (tc.function as Record<string, unknown>)?.arguments as string,
        },
      }))
    }
    return {
      content: message?.content || '',
      finishReason,
      tool_calls: toolCalls,
      usage: data.usage
        ? {
            promptTokens: data.usage.prompt_tokens,
            completionTokens: data.usage.completion_tokens,
            totalTokens: data.usage.total_tokens,
          }
        : undefined,
    }
  }
  /**
   * Streaming chat completion over SSE.
   * Yields incremental content chunks, accumulates fragmented tool_calls by
   * index, and emits a final chunk with finishReason (plus usage when
   * `stream_options.include_usage` data arrives). Abort is treated as a
   * normal 'stop' finish rather than an error.
   */
  async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
    // Build the request body
    const requestBody: Record<string, unknown> = {
      model: this.model,
      messages: messages.map((m) => {
        const msg: Record<string, unknown> = { role: m.role, content: m.content }
        if (m.role === 'tool' && m.tool_call_id) {
          msg.tool_call_id = m.tool_call_id
        }
        if (m.role === 'assistant' && m.tool_calls) {
          msg.tool_calls = m.tool_calls
        }
        return msg
      }),
      temperature: options?.temperature ?? 0.7,
      max_tokens: options?.maxTokens ?? 2048,
      stream: true,
      // Ask for usage statistics in the streaming response
      stream_options: { include_usage: true },
    }
    if (options?.tools && options.tools.length > 0) {
      requestBody.tools = options.tools
    }
    const response = await fetch(`${this.baseUrl}/chat/completions`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify(requestBody),
      signal: options?.abortSignal,
    })
    if (!response.ok) {
      const error = await response.text()
      throw new Error(`Qwen API error: ${response.status} - ${error}`)
    }
    const reader = response.body?.getReader()
    if (!reader) {
      throw new Error('Failed to get response reader')
    }
    const decoder = new TextDecoder()
    let buffer = ''
    // Accumulates streamed tool-call fragments keyed by their choice index
    const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
    try {
      while (true) {
        // Check whether the caller has aborted
        if (options?.abortSignal?.aborted) {
          yield { content: '', isFinished: true, finishReason: 'stop' }
          return
        }
        const { done, value } = await reader.read()
        if (done) break
        buffer += decoder.decode(value, { stream: true })
        const lines = buffer.split('\n')
        // Keep the trailing partial line in the buffer for the next read
        buffer = lines.pop() || ''
        for (const line of lines) {
          const trimmed = line.trim()
          if (!trimmed || !trimmed.startsWith('data: ')) continue
          const data = trimmed.slice(6)
          if (data === '[DONE]') {
            if (toolCallsAccumulator.size > 0) {
              const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
                id: tc.id,
                type: 'function' as const,
                function: { name: tc.name, arguments: tc.arguments },
              }))
              yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
            } else {
              yield { content: '', isFinished: true, finishReason: 'stop' }
            }
            return
          }
          try {
            const parsed = JSON.parse(data)
            const delta = parsed.choices?.[0]?.delta
            const finishReason = parsed.choices?.[0]?.finish_reason
            if (delta?.content) {
              yield { content: delta.content, isFinished: false }
            }
            // Handle streamed tool_calls fragments: arguments arrive in pieces
            if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
              for (const tc of delta.tool_calls) {
                const index = tc.index ?? 0
                const existing = toolCallsAccumulator.get(index)
                if (existing) {
                  if (tc.function?.arguments) {
                    existing.arguments += tc.function.arguments
                  }
                } else {
                  toolCallsAccumulator.set(index, {
                    id: tc.id || '',
                    name: tc.function?.name || '',
                    arguments: tc.function?.arguments || '',
                  })
                }
              }
            }
            if (finishReason) {
              let reason: ChatStreamChunk['finishReason'] = 'error'
              if (finishReason === 'stop') reason = 'stop'
              else if (finishReason === 'length') reason = 'length'
              else if (finishReason === 'tool_calls') reason = 'tool_calls'
              // Parse usage info, present only when include_usage was honored
              const usage = parsed.usage
                ? {
                    promptTokens: parsed.usage.prompt_tokens,
                    completionTokens: parsed.usage.completion_tokens,
                    totalTokens: parsed.usage.total_tokens,
                  }
                : undefined
              if (toolCallsAccumulator.size > 0) {
                const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
                  id: tc.id,
                  type: 'function' as const,
                  function: { name: tc.name, arguments: tc.arguments },
                }))
                yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
              } else {
                yield { content: '', isFinished: true, finishReason: reason, usage }
              }
              return
            }
          } catch {
            // Ignore parse errors and continue with the next line
          }
        }
      }
    } catch (error) {
      // An abort ends the stream normally rather than raising
      if (error instanceof Error && error.name === 'AbortError') {
        yield { content: '', isFinished: true, finishReason: 'stop' }
        return
      }
      throw error
    } finally {
      reader.releaseLock()
    }
  }
  /**
   * Validate the configured API key with a lightweight GET /models request.
   * Never throws: network and HTTP failures are reported via the returned
   * `{ success, error }` object.
   */
  async validateApiKey(): Promise<{ success: boolean; error?: string }> {
    try {
      // Send a simple request to verify the API key
      const response = await fetch(`${this.baseUrl}/models`, {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
        },
      })
      if (response.ok) {
        return { success: true }
      }
      // Try to extract error details from the response body
      const errorText = await response.text()
      let errorMessage = `HTTP ${response.status}`
      try {
        const errorJson = JSON.parse(errorText)
        errorMessage = errorJson.error?.message || errorJson.message || errorMessage
      } catch {
        if (errorText) {
          errorMessage = errorText.slice(0, 200)
        }
      }
      return { success: false, error: errorMessage }
    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : String(error)
      return { success: false, error: errorMessage }
    }
  }
}
+158
View File
@@ -0,0 +1,158 @@
/**
* AI SDK 适配通用工具
*/
import { jsonSchema } from 'ai'
import type {
ContentPart,
FinishReason,
LanguageModelUsage,
ModelMessage,
ToolSet,
TypedToolCall,
JSONValue,
} from 'ai'
import type { ChatMessage, ChatResponse, ToolCall, ToolDefinition } from './types'
const UNKNOWN_TOOL_NAME = 'unknown_tool'
/**
 * Convert OpenAI-style tool definitions into an AI SDK ToolSet.
 * Returns undefined when no tools are supplied so callers can omit the field.
 */
export function buildToolSet(tools?: ToolDefinition[]): ToolSet | undefined {
  if (!tools?.length) return undefined
  return tools.reduce<ToolSet>((acc, tool) => {
    acc[tool.function.name] = {
      description: tool.function.description,
      inputSchema: jsonSchema(tool.function.parameters),
    }
    return acc
  }, {})
}
/**
 * Parse a JSON string, returning null instead of throwing on invalid input.
 */
export function safeParseJson(value: string): JSONValue | null {
  let parsed: JSONValue | null = null
  try {
    parsed = JSON.parse(value) as JSONValue
  } catch {
    // Malformed JSON is an expected case (e.g. plain-text tool output).
  }
  return parsed
}
/**
 * Map the AI SDK finish reason onto the legacy LLM finishReason values.
 * Anything unrecognized is reported as 'error'.
 */
export function mapFinishReason(reason: FinishReason): ChatResponse['finishReason'] {
  switch (reason) {
    case 'stop':
      return 'stop'
    case 'length':
      return 'length'
    case 'tool-calls':
      return 'tool_calls'
    default:
      return 'error'
  }
}
/**
 * Map AI SDK usage onto the legacy usage structure.
 * Returns undefined when no token counts were reported at all; individual
 * missing counts default to 0.
 */
export function mapUsage(usage?: LanguageModelUsage): ChatResponse['usage'] | undefined {
  if (!usage) return undefined
  const { inputTokens, outputTokens, totalTokens } = usage
  if (inputTokens == null && outputTokens == null && totalTokens == null) {
    return undefined
  }
  return {
    promptTokens: inputTokens ?? 0,
    completionTokens: outputTokens ?? 0,
    totalTokens: totalTokens ?? 0,
  }
}
/**
 * Map AI SDK tool calls onto the legacy ToolCall structure.
 * Tool input is re-serialized to a JSON string, defaulting to '{}'.
 */
export function mapToolCalls(toolCalls: TypedToolCall<ToolSet>[]): ToolCall[] {
  const result: ToolCall[] = []
  for (const call of toolCalls) {
    result.push({
      id: call.toolCallId,
      type: 'function',
      function: {
        name: call.toolName,
        arguments: JSON.stringify(call.input ?? {}),
      },
    })
  }
  return result
}
/**
 * Convert legacy ChatMessage records into AI SDK ModelMessage records.
 *
 * Assistant tool calls become 'tool-call' content parts and tool results
 * become 'tool-result' parts. Tool messages need the originating tool name,
 * which is recovered from the assistant message that issued the call.
 */
export function buildModelMessages(messages: ChatMessage[]): ModelMessage[] {
  // Map tool_call_id -> tool name so tool-result parts can be attributed.
  const nameByCallId = new Map<string, string>()
  for (const msg of messages) {
    if (msg.role === 'assistant' && msg.tool_calls) {
      for (const call of msg.tool_calls) {
        nameByCallId.set(call.id, call.function.name)
      }
    }
  }

  const convertAssistant = (msg: ChatMessage): ModelMessage => {
    if (!msg.tool_calls || msg.tool_calls.length === 0) {
      return { role: 'assistant', content: msg.content }
    }
    const parts: Array<ContentPart<ToolSet>> = []
    if (msg.content) {
      parts.push({ type: 'text', text: msg.content })
    }
    for (const call of msg.tool_calls) {
      parts.push({
        type: 'tool-call',
        toolCallId: call.id,
        toolName: call.function.name,
        input: safeParseJson(call.function.arguments || '{}') ?? {},
      })
    }
    return { role: 'assistant', content: parts }
  }

  const convertTool = (msg: ChatMessage): ModelMessage => {
    const callId = msg.tool_call_id || ''
    const parsed = safeParseJson(msg.content)
    return {
      role: 'tool',
      content: [
        {
          type: 'tool-result',
          toolCallId: callId,
          toolName: nameByCallId.get(callId) || UNKNOWN_TOOL_NAME,
          // Tool results only allow JSON/text, matching the old implementation.
          output: parsed ? { type: 'json', value: parsed } : { type: 'text', value: msg.content },
        },
      ],
    }
  }

  return messages.map((msg) => {
    if (msg.role === 'assistant') return convertAssistant(msg)
    if (msg.role === 'tool') return convertTool(msg)
    return { role: msg.role, content: msg.content }
  })
}
+99 -3
View File
@@ -11,6 +11,98 @@ import type { IpcContext } from './types'
// 用于跟踪活跃的 Agent 请求,支持中止操作
const activeAgentRequests = new Map<string, AbortController>()
/**
 * Format an AI error into a friendlier, user-facing message.
 * Digs through the error itself, its `lastError` and its `errors` array to
 * find a status code and the first usable message, then maps quota/overload
 * failures to localized hints and truncates anything over 300 characters.
 */
function formatAIError(error: unknown): string {
  const source = error as { lastError?: unknown; errors?: unknown[] } | null | undefined
  const candidates: unknown[] = error ? [error] : []
  if (source?.lastError) candidates.push(source.lastError)
  if (Array.isArray(source?.errors)) candidates.push(...source.errors)

  let message = ''
  let status: number | undefined

  for (const candidate of candidates) {
    // Plain strings can carry the message; everything else non-object is noise.
    if (!candidate || typeof candidate !== 'object') {
      if (!message && typeof candidate === 'string') message = candidate
      continue
    }
    const rec = candidate as Record<string, unknown>
    // Last numeric statusCode wins, mirroring the original scan order.
    if (typeof rec.statusCode === 'number') status = rec.statusCode
    if (!message && typeof rec.message === 'string') message = rec.message
    if (!message && rec.data && typeof rec.data === 'object') {
      const nested = (rec.data as { error?: { message?: string } }).error?.message
      if (nested) message = nested
    }
    if (rec.responseBody && typeof rec.responseBody === 'string') {
      try {
        const body = JSON.parse(rec.responseBody) as { error?: { message?: string } }
        if (!message && body.error?.message) message = body.error.message
      } catch {
        if (!message) message = rec.responseBody
      }
    }
  }

  // Extract a "retry in Ns" hint once the message is settled.
  let retryAfter: number | undefined
  const retryMatch = message.match(/retry in ([0-9.]+)s/i)
  if (retryMatch) retryAfter = Math.ceil(Number(retryMatch[1]))

  const finalMessage = message || String(error)
  const lowered = finalMessage.toLowerCase()
  if (status === 429 || lowered.includes('quota') || lowered.includes('resource_exhausted')) {
    return retryAfter
      ? `Gemini 配额已用尽,请等待 ${retryAfter} 秒后重试,或更换/升级配额。`
      : 'Gemini 配额已用尽,请稍后重试或更换/升级配额。'
  }
  if (status === 503 || lowered.includes('overloaded') || lowered.includes('unavailable')) {
    return 'Gemini 模型繁忙,请稍后重试。'
  }
  return finalMessage.length > 300 ? `${finalMessage.slice(0, 300)}...` : finalMessage
}
export function registerAIHandlers({ win }: IpcContext): void {
console.log('[IPC] Registering AI handlers...')
@@ -464,11 +556,15 @@ export function registerAIHandlers({ win }: IpcContext): void {
aiLogger.info('IPC', `Agent 请求已中止: ${requestId}`)
return
}
aiLogger.error('IPC', `Agent 执行出错: ${requestId}`, { error: String(error) })
const friendlyError = formatAIError(error)
aiLogger.error('IPC', `Agent 执行出错: ${requestId}`, {
error: String(error),
friendlyError,
})
// 发送错误 chunk
win.webContents.send('agent:streamChunk', {
requestId,
chunk: { type: 'error', error: String(error), isFinished: true },
chunk: { type: 'error', error: friendlyError, isFinished: true },
})
// 发送完成事件(带错误信息),确保前端 promise 能 resolve
win.webContents.send('agent:complete', {
@@ -478,7 +574,7 @@ export function registerAIHandlers({ win }: IpcContext): void {
toolsUsed: [],
toolRounds: 0,
totalUsage: { promptTokens: 0, completionTokens: 0, totalTokens: 0 },
error: String(error),
error: friendlyError,
},
})
} finally {