mirror of
https://github.com/hellodigua/ChatLab.git
synced 2026-05-06 21:16:10 +08:00
320 lines
10 KiB
TypeScript
320 lines
10 KiB
TypeScript
/**
|
|
* 通义千问 (Qwen) LLM Provider
|
|
* 使用阿里云 DashScope 兼容 OpenAI 格式的 API,支持 Function Calling
|
|
*/
|
|
|
|
import type {
|
|
ILLMService,
|
|
LLMProvider,
|
|
ChatMessage,
|
|
ChatOptions,
|
|
ChatResponse,
|
|
ChatStreamChunk,
|
|
ProviderInfo,
|
|
ToolCall,
|
|
} from './types'
|
|
|
|
const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
|
|
|
|
const MODELS = [
|
|
{ id: 'qwen-turbo', name: 'Qwen Turbo', description: '通义千问超大规模语言模型,速度快' },
|
|
{ id: 'qwen-plus', name: 'Qwen Plus', description: '通义千问超大规模语言模型,效果好' },
|
|
{ id: 'qwen-max', name: 'Qwen Max', description: '通义千问千亿级别超大规模语言模型' },
|
|
]
|
|
|
|
export const QWEN_INFO: ProviderInfo = {
|
|
id: 'qwen',
|
|
name: '通义千问',
|
|
description: '阿里云通义千问大语言模型',
|
|
defaultBaseUrl: DEFAULT_BASE_URL,
|
|
models: MODELS,
|
|
}
|
|
|
|
export class QwenService implements ILLMService {
|
|
private apiKey: string
|
|
private baseUrl: string
|
|
private model: string
|
|
|
|
constructor(apiKey: string, model?: string, baseUrl?: string) {
|
|
this.apiKey = apiKey
|
|
this.baseUrl = baseUrl || DEFAULT_BASE_URL
|
|
this.model = model || 'qwen-turbo'
|
|
}
|
|
|
|
getProvider(): LLMProvider {
|
|
return 'qwen'
|
|
}
|
|
|
|
getModels(): string[] {
|
|
return MODELS.map((m) => m.id)
|
|
}
|
|
|
|
getDefaultModel(): string {
|
|
return 'qwen-turbo'
|
|
}
|
|
|
|
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
|
|
// 构建请求体
|
|
const requestBody: Record<string, unknown> = {
|
|
model: this.model,
|
|
messages: messages.map((m) => {
|
|
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
|
if (m.role === 'tool' && m.tool_call_id) {
|
|
msg.tool_call_id = m.tool_call_id
|
|
}
|
|
if (m.role === 'assistant' && m.tool_calls) {
|
|
msg.tool_calls = m.tool_calls
|
|
}
|
|
return msg
|
|
}),
|
|
temperature: options?.temperature ?? 0.7,
|
|
max_tokens: options?.maxTokens ?? 2048,
|
|
stream: false,
|
|
}
|
|
|
|
if (options?.tools && options.tools.length > 0) {
|
|
requestBody.tools = options.tools
|
|
}
|
|
|
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
|
method: 'POST',
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
Authorization: `Bearer ${this.apiKey}`,
|
|
},
|
|
body: JSON.stringify(requestBody),
|
|
signal: options?.abortSignal,
|
|
})
|
|
|
|
if (!response.ok) {
|
|
const error = await response.text()
|
|
throw new Error(`Qwen API error: ${response.status} - ${error}`)
|
|
}
|
|
|
|
const data = await response.json()
|
|
const choice = data.choices?.[0]
|
|
const message = choice?.message
|
|
|
|
// 解析 finish_reason
|
|
let finishReason: ChatResponse['finishReason'] = 'error'
|
|
if (choice?.finish_reason === 'stop') {
|
|
finishReason = 'stop'
|
|
} else if (choice?.finish_reason === 'length') {
|
|
finishReason = 'length'
|
|
} else if (choice?.finish_reason === 'tool_calls') {
|
|
finishReason = 'tool_calls'
|
|
}
|
|
|
|
// 解析 tool_calls
|
|
let toolCalls: ToolCall[] | undefined
|
|
if (message?.tool_calls && Array.isArray(message.tool_calls)) {
|
|
toolCalls = message.tool_calls.map((tc: Record<string, unknown>) => ({
|
|
id: tc.id as string,
|
|
type: 'function' as const,
|
|
function: {
|
|
name: (tc.function as Record<string, unknown>)?.name as string,
|
|
arguments: (tc.function as Record<string, unknown>)?.arguments as string,
|
|
},
|
|
}))
|
|
}
|
|
|
|
return {
|
|
content: message?.content || '',
|
|
finishReason,
|
|
tool_calls: toolCalls,
|
|
usage: data.usage
|
|
? {
|
|
promptTokens: data.usage.prompt_tokens,
|
|
completionTokens: data.usage.completion_tokens,
|
|
totalTokens: data.usage.total_tokens,
|
|
}
|
|
: undefined,
|
|
}
|
|
}
|
|
|
|
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
|
|
// 构建请求体
|
|
const requestBody: Record<string, unknown> = {
|
|
model: this.model,
|
|
messages: messages.map((m) => {
|
|
const msg: Record<string, unknown> = { role: m.role, content: m.content }
|
|
if (m.role === 'tool' && m.tool_call_id) {
|
|
msg.tool_call_id = m.tool_call_id
|
|
}
|
|
if (m.role === 'assistant' && m.tool_calls) {
|
|
msg.tool_calls = m.tool_calls
|
|
}
|
|
return msg
|
|
}),
|
|
temperature: options?.temperature ?? 0.7,
|
|
max_tokens: options?.maxTokens ?? 2048,
|
|
stream: true,
|
|
// 启用流式响应中的 usage 统计
|
|
stream_options: { include_usage: true },
|
|
}
|
|
|
|
if (options?.tools && options.tools.length > 0) {
|
|
requestBody.tools = options.tools
|
|
}
|
|
|
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
|
method: 'POST',
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
Authorization: `Bearer ${this.apiKey}`,
|
|
},
|
|
body: JSON.stringify(requestBody),
|
|
signal: options?.abortSignal,
|
|
})
|
|
|
|
if (!response.ok) {
|
|
const error = await response.text()
|
|
throw new Error(`Qwen API error: ${response.status} - ${error}`)
|
|
}
|
|
|
|
const reader = response.body?.getReader()
|
|
if (!reader) {
|
|
throw new Error('Failed to get response reader')
|
|
}
|
|
|
|
const decoder = new TextDecoder()
|
|
let buffer = ''
|
|
const toolCallsAccumulator: Map<number, { id: string; name: string; arguments: string }> = new Map()
|
|
|
|
try {
|
|
while (true) {
|
|
// 检查是否已中止
|
|
if (options?.abortSignal?.aborted) {
|
|
yield { content: '', isFinished: true, finishReason: 'stop' }
|
|
return
|
|
}
|
|
|
|
const { done, value } = await reader.read()
|
|
if (done) break
|
|
|
|
buffer += decoder.decode(value, { stream: true })
|
|
const lines = buffer.split('\n')
|
|
buffer = lines.pop() || ''
|
|
|
|
for (const line of lines) {
|
|
const trimmed = line.trim()
|
|
if (!trimmed || !trimmed.startsWith('data: ')) continue
|
|
|
|
const data = trimmed.slice(6)
|
|
if (data === '[DONE]') {
|
|
if (toolCallsAccumulator.size > 0) {
|
|
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
|
id: tc.id,
|
|
type: 'function' as const,
|
|
function: { name: tc.name, arguments: tc.arguments },
|
|
}))
|
|
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCalls }
|
|
} else {
|
|
yield { content: '', isFinished: true, finishReason: 'stop' }
|
|
}
|
|
return
|
|
}
|
|
|
|
try {
|
|
const parsed = JSON.parse(data)
|
|
const delta = parsed.choices?.[0]?.delta
|
|
const finishReason = parsed.choices?.[0]?.finish_reason
|
|
|
|
if (delta?.content) {
|
|
yield { content: delta.content, isFinished: false }
|
|
}
|
|
|
|
// 处理流式 tool_calls
|
|
if (delta?.tool_calls && Array.isArray(delta.tool_calls)) {
|
|
for (const tc of delta.tool_calls) {
|
|
const index = tc.index ?? 0
|
|
const existing = toolCallsAccumulator.get(index)
|
|
if (existing) {
|
|
if (tc.function?.arguments) {
|
|
existing.arguments += tc.function.arguments
|
|
}
|
|
} else {
|
|
toolCallsAccumulator.set(index, {
|
|
id: tc.id || '',
|
|
name: tc.function?.name || '',
|
|
arguments: tc.function?.arguments || '',
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
if (finishReason) {
|
|
let reason: ChatStreamChunk['finishReason'] = 'error'
|
|
if (finishReason === 'stop') reason = 'stop'
|
|
else if (finishReason === 'length') reason = 'length'
|
|
else if (finishReason === 'tool_calls') reason = 'tool_calls'
|
|
|
|
// 解析 usage 信息
|
|
const usage = parsed.usage
|
|
? {
|
|
promptTokens: parsed.usage.prompt_tokens,
|
|
completionTokens: parsed.usage.completion_tokens,
|
|
totalTokens: parsed.usage.total_tokens,
|
|
}
|
|
: undefined
|
|
|
|
if (toolCallsAccumulator.size > 0) {
|
|
const toolCalls: ToolCall[] = Array.from(toolCallsAccumulator.values()).map((tc) => ({
|
|
id: tc.id,
|
|
type: 'function' as const,
|
|
function: { name: tc.name, arguments: tc.arguments },
|
|
}))
|
|
yield { content: '', isFinished: true, finishReason: reason, tool_calls: toolCalls, usage }
|
|
} else {
|
|
yield { content: '', isFinished: true, finishReason: reason, usage }
|
|
}
|
|
return
|
|
}
|
|
} catch {
|
|
// 忽略解析错误,继续处理下一行
|
|
}
|
|
}
|
|
}
|
|
} catch (error) {
|
|
// 如果是中止错误,正常返回
|
|
if (error instanceof Error && error.name === 'AbortError') {
|
|
yield { content: '', isFinished: true, finishReason: 'stop' }
|
|
return
|
|
}
|
|
throw error
|
|
} finally {
|
|
reader.releaseLock()
|
|
}
|
|
}
|
|
|
|
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
|
|
try {
|
|
// 发送一个简单请求验证 API Key
|
|
const response = await fetch(`${this.baseUrl}/models`, {
|
|
method: 'GET',
|
|
headers: {
|
|
Authorization: `Bearer ${this.apiKey}`,
|
|
},
|
|
})
|
|
if (response.ok) {
|
|
return { success: true }
|
|
}
|
|
// 尝试获取错误详情
|
|
const errorText = await response.text()
|
|
let errorMessage = `HTTP ${response.status}`
|
|
try {
|
|
const errorJson = JSON.parse(errorText)
|
|
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
|
|
} catch {
|
|
if (errorText) {
|
|
errorMessage = errorText.slice(0, 200)
|
|
}
|
|
}
|
|
return { success: false, error: errorMessage }
|
|
} catch (error) {
|
|
const errorMessage = error instanceof Error ? error.message : String(error)
|
|
return { success: false, error: errorMessage }
|
|
}
|
|
}
|
|
}
|