feat: support configuring Gemini as an AI provider

digua
2025-12-21 00:01:55 +08:00
parent af11f39f36
commit 78a328fb19
6 changed files with 530 additions and 14 deletions
+519
@@ -0,0 +1,519 @@
/**
* Google Gemini LLM Provider
* Uses the Gemini REST API format and supports Function Calling
*/
import type {
ILLMService,
LLMProvider,
ChatMessage,
ChatOptions,
ChatResponse,
ChatStreamChunk,
ProviderInfo,
ToolCall,
ToolDefinition,
} from './types'
import { aiLogger } from '../logger'
const DEFAULT_BASE_URL = 'https://generativelanguage.googleapis.com'
const MODELS = [
{ id: 'gemini-3-flash-preview', name: 'Gemini 3 Flash Preview', description: 'Fast preview version' },
{ id: 'gemini-3-pro-preview', name: 'Gemini 3 Pro Preview', description: 'Pro preview version' },
]
export const GEMINI_INFO: ProviderInfo = {
id: 'gemini',
name: 'Gemini',
description: 'Google Gemini large language models',
defaultBaseUrl: DEFAULT_BASE_URL,
models: MODELS,
}
// ==================== Gemini API type definitions ====================
/** Gemini message part (supports multiple kinds of content) */
interface GeminiPart {
text?: string
functionCall?: {
name: string
args: Record<string, unknown>
}
functionResponse?: {
name: string
response: unknown
}
/** Thought signature for Gemini 3+ models */
thoughtSignature?: string
}
/** Gemini message content */
interface GeminiContent {
role: 'user' | 'model'
parts: GeminiPart[]
}
/** Gemini function declaration (corresponds to OpenAI's ToolDefinition) */
interface GeminiFunctionDeclaration {
name: string
description: string
parameters: {
type: string
properties: Record<string, unknown>
required?: string[]
}
}
/** Gemini request body */
interface GeminiRequest {
contents: GeminiContent[]
generationConfig?: {
temperature?: number
maxOutputTokens?: number
}
systemInstruction?: {
parts: Array<{ text: string }>
}
tools?: Array<{
functionDeclarations: GeminiFunctionDeclaration[]
}>
}
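// Example wire-format request matching GeminiRequest above (illustrative values only):
// {
//   "contents": [{ "role": "user", "parts": [{ "text": "What is the weather in Tokyo?" }] }],
//   "systemInstruction": { "parts": [{ "text": "You are a helpful assistant." }] },
//   "generationConfig": { "temperature": 0.7, "maxOutputTokens": 2048 },
//   "tools": [{ "functionDeclarations": [{ "name": "get_weather", "description": "Look up the weather",
//     "parameters": { "type": "object", "properties": { "city": { "type": "string" } }, "required": ["city"] } }] }]
// }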
/** Gemini response candidate */
interface GeminiCandidate {
content?: {
parts?: GeminiPart[]
role?: string
}
finishReason?: string
}
/** Gemini API response */
interface GeminiResponse {
candidates?: GeminiCandidate[]
usageMetadata?: {
promptTokenCount?: number
candidatesTokenCount?: number
totalTokenCount?: number
}
}
// ==================== GeminiService class ====================
export class GeminiService implements ILLMService {
private apiKey: string
private baseUrl: string
private model: string
constructor(apiKey: string, model?: string, baseUrl?: string) {
this.apiKey = apiKey
this.baseUrl = baseUrl || DEFAULT_BASE_URL
this.model = model || 'gemini-3-flash-preview'
}
getProvider(): LLMProvider {
return 'gemini'
}
getModels(): string[] {
return MODELS.map((m) => m.id)
}
getDefaultModel(): string {
return 'gemini-3-flash-preview'
}
/**
* Convert OpenAI-format tools to the Gemini format
*/
private convertTools(tools?: ToolDefinition[]): Array<{ functionDeclarations: GeminiFunctionDeclaration[] }> | undefined {
if (!tools || tools.length === 0) return undefined
const functionDeclarations: GeminiFunctionDeclaration[] = tools.map((tool) => ({
name: tool.function.name,
description: tool.function.description,
parameters: tool.function.parameters,
}))
return [{ functionDeclarations }]
}
/**
* Convert OpenAI-format messages to the Gemini format
*/
private convertMessages(messages: ChatMessage[]): {
contents: GeminiContent[]
systemInstruction?: { parts: Array<{ text: string }> }
} {
const contents: GeminiContent[] = []
let systemInstruction: { parts: Array<{ text: string }> } | undefined
for (const msg of messages) {
if (msg.role === 'system') {
// Gemini handles system prompts via systemInstruction
systemInstruction = {
parts: [{ text: msg.content }],
}
} else if (msg.role === 'user') {
contents.push({
role: 'user',
parts: [{ text: msg.content }],
})
} else if (msg.role === 'assistant') {
// Handle assistant messages (which may contain tool_calls)
if (msg.tool_calls && msg.tool_calls.length > 0) {
// Assistant message that carries tool calls
const parts: GeminiPart[] = []
if (msg.content) {
parts.push({ text: msg.content })
}
for (const tc of msg.tool_calls) {
const part: GeminiPart = {
functionCall: {
name: tc.function.name,
args: JSON.parse(tc.function.arguments),
},
}
// Gemini 3+ requires echoing back the thoughtSignature with the function call
if (tc.thoughtSignature) {
part.thoughtSignature = tc.thoughtSignature
}
parts.push(part)
}
contents.push({ role: 'model', parts })
} else {
// Plain text message
contents.push({
role: 'model',
parts: [{ text: msg.content }],
})
}
} else if (msg.role === 'tool') {
// Tool result messages become a functionResponse part under the user role in Gemini
// Note: the tool name and result have to be recovered from the message itself
// tool_call_id follows the call_<name>_<timestamp> format produced by extractToolCalls below
// Simplified handling: assume the content is a JSON-encoded result
try {
const result = JSON.parse(msg.content)
// Gemini requires the function name, so infer it from tool_call_id or fall back to a default
// IDs generated by extractToolCalls look like call_<name>_<timestamp>; strip the prefix and timestamp suffix
contents.push({
role: 'user',
parts: [
{
functionResponse: {
name: msg.tool_call_id?.replace(/^call_/, '').replace(/_\d+$/, '') || 'unknown',
response: result,
},
},
],
})
} catch {
// If the content is not JSON, wrap it as a plain-text result
contents.push({
role: 'user',
parts: [
{
functionResponse: {
name: msg.tool_call_id?.replace(/^call_/, '').replace(/_\d+$/, '') || 'unknown',
response: { result: msg.content },
},
},
],
})
}
}
}
return { contents, systemInstruction }
}
/**
* Build the API URL
*/
private buildUrl(stream: boolean): string {
const action = stream ? 'streamGenerateContent' : 'generateContent'
const base = this.baseUrl.replace(/\/$/, '')
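// e.g. https://generativelanguage.googleapis.com/v1beta/models/gemini-3-flash-preview:streamGenerateContent?key=<API_KEY>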
return `${base}/v1beta/models/${this.model}:${action}?key=${this.apiKey}`
}
/**
* Extract tool calls from Gemini parts
*/
private extractToolCalls(parts?: GeminiPart[]): ToolCall[] | undefined {
if (!parts) return undefined
const toolCalls: ToolCall[] = []
for (const part of parts) {
if (part.functionCall) {
toolCalls.push({
id: `call_${part.functionCall.name}_${Date.now()}`,
type: 'function',
function: {
name: part.functionCall.name,
arguments: JSON.stringify(part.functionCall.args),
},
// Preserve the thought signature returned by Gemini 3+
thoughtSignature: part.thoughtSignature,
})
}
}
return toolCalls.length > 0 ? toolCalls : undefined
}
/**
* Extract text content from Gemini parts
*/
private extractText(parts?: GeminiPart[]): string {
if (!parts) return ''
return parts
.filter((p) => p.text)
.map((p) => p.text)
.join('')
}
async chat(messages: ChatMessage[], options?: ChatOptions): Promise<ChatResponse> {
const { contents, systemInstruction } = this.convertMessages(messages)
const requestBody: GeminiRequest = {
contents,
generationConfig: {
temperature: options?.temperature ?? 0.7,
maxOutputTokens: options?.maxTokens ?? 2048,
},
}
if (systemInstruction) {
requestBody.systemInstruction = systemInstruction
}
// Attach tool definitions
const geminiTools = this.convertTools(options?.tools)
if (geminiTools) {
requestBody.tools = geminiTools
}
const response = await fetch(this.buildUrl(false), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
throw new Error(`Gemini API error: ${response.status} - ${error}`)
}
const data: GeminiResponse = await response.json()
const candidate = data.candidates?.[0]
const parts = candidate?.content?.parts
const content = this.extractText(parts)
const toolCalls = this.extractToolCalls(parts)
// Map the finish reason
let finishReason: ChatResponse['finishReason'] = 'error'
const reason = candidate?.finishReason
if (reason === 'STOP') {
finishReason = toolCalls ? 'tool_calls' : 'stop'
} else if (reason === 'MAX_TOKENS') {
finishReason = 'length'
}
return {
content,
finishReason,
tool_calls: toolCalls,
usage: data.usageMetadata
? {
promptTokens: data.usageMetadata.promptTokenCount || 0,
completionTokens: data.usageMetadata.candidatesTokenCount || 0,
totalTokens: data.usageMetadata.totalTokenCount || 0,
}
: undefined,
}
}
async *chatStream(messages: ChatMessage[], options?: ChatOptions): AsyncGenerator<ChatStreamChunk> {
const { contents, systemInstruction } = this.convertMessages(messages)
const requestBody: GeminiRequest = {
contents,
generationConfig: {
temperature: options?.temperature ?? 0.7,
maxOutputTokens: options?.maxTokens ?? 2048,
},
}
if (systemInstruction) {
requestBody.systemInstruction = systemInstruction
}
// Attach tool definitions
const geminiTools = this.convertTools(options?.tools)
if (geminiTools) {
requestBody.tools = geminiTools
}
// Gemini streaming requires the alt=sse query parameter
const url = this.buildUrl(true) + '&alt=sse'
aiLogger.info('Gemini', 'starting streaming request', {
url: url.replace(/key=[^&]+/, 'key=***'),
model: this.model,
messagesCount: contents.length,
hasSystemInstruction: !!systemInstruction,
hasTools: !!geminiTools,
})
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestBody),
signal: options?.abortSignal,
})
if (!response.ok) {
const error = await response.text()
aiLogger.error('Gemini', 'API request failed', { status: response.status, error: error.slice(0, 500) })
throw new Error(`Gemini API error: ${response.status} - ${error}`)
}
aiLogger.info('Gemini', 'API responded OK, reading stream')
const reader = response.body?.getReader()
if (!reader) {
throw new Error('Failed to get response reader')
}
const decoder = new TextDecoder()
let buffer = ''
// Accumulate tool calls (they may arrive across multiple chunks)
const toolCallsAccumulator: ToolCall[] = []
try {
while (true) {
// Check whether the request has been aborted
if (options?.abortSignal?.aborted) {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
buffer = lines.pop() || ''
for (const line of lines) {
const trimmed = line.trim()
if (!trimmed || !trimmed.startsWith('data: ')) continue
const data = trimmed.slice(6)
if (data === '[DONE]') {
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator }
} else {
yield { content: '', isFinished: true, finishReason: 'stop' }
}
return
}
try {
const parsed: GeminiResponse = JSON.parse(data)
const candidate = parsed.candidates?.[0]
const parts = candidate?.content?.parts
// Handle text content
const text = this.extractText(parts)
if (text) {
yield { content: text, isFinished: false }
}
// Handle tool calls
const toolCalls = this.extractToolCalls(parts)
if (toolCalls) {
aiLogger.info('Gemini', 'tool calls detected', { toolCalls: toolCalls.map((tc) => tc.function.name) })
toolCallsAccumulator.push(...toolCalls)
}
// Check whether the stream has finished
const finishReason = candidate?.finishReason
if (finishReason) {
aiLogger.info('Gemini', 'streaming response finished', { finishReason, toolCallsCount: toolCallsAccumulator.length })
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator }
} else {
let reason: ChatStreamChunk['finishReason'] = 'stop'
if (finishReason === 'MAX_TOKENS') {
reason = 'length'
}
yield { content: '', isFinished: true, finishReason: reason }
}
return
}
} catch (e) {
// Log parse errors
aiLogger.warn('Gemini', 'failed to parse SSE data', { data: data.slice(0, 200), error: String(e) })
}
}
}
// If the loop ends normally, emit the completion signal
if (toolCallsAccumulator.length > 0) {
yield { content: '', isFinished: true, finishReason: 'tool_calls', tool_calls: toolCallsAccumulator }
} else {
yield { content: '', isFinished: true, finishReason: 'stop' }
}
} catch (error) {
// Abort errors are treated as a normal stop
if (error instanceof Error && error.name === 'AbortError') {
yield { content: '', isFinished: true, finishReason: 'stop' }
return
}
throw error
} finally {
reader.releaseLock()
}
}
async validateApiKey(): Promise<{ success: boolean; error?: string }> {
try {
// Validate the API key via the models.list endpoint
const url = `${this.baseUrl.replace(/\/$/, '')}/v1beta/models?key=${this.apiKey}`
const response = await fetch(url, {
method: 'GET',
})
if (response.ok) {
return { success: true }
}
// Try to extract error details
const errorText = await response.text()
let errorMessage = `HTTP ${response.status}`
try {
const errorJson = JSON.parse(errorText)
errorMessage = errorJson.error?.message || errorJson.message || errorMessage
} catch {
if (errorText) {
errorMessage = errorText.slice(0, 200)
}
}
return { success: false, error: errorMessage }
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
return { success: false, error: errorMessage }
}
}
}
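For orientation, a minimal usage sketch of the new provider follows. It is illustrative only: the get_weather tool, the placeholder API key, and the demo wrapper are hypothetical, and the OpenAI-style ToolDefinition literal assumes the shape that convertTools reads.

import { GeminiService } from './gemini'
import type { ToolDefinition } from './types'

// Hypothetical tool definition mirroring the OpenAI-style shape consumed by convertTools
const tools: ToolDefinition[] = [
  {
    type: 'function', // assumed discriminator field
    function: {
      name: 'get_weather',
      description: 'Look up the current weather for a city',
      parameters: {
        type: 'object',
        properties: { city: { type: 'string' } },
        required: ['city'],
      },
    },
  },
]

async function demo(): Promise<void> {
  const gemini = new GeminiService('YOUR_API_KEY') // placeholder key; model defaults to gemini-3-flash-preview
  for await (const chunk of gemini.chatStream(
    [{ role: 'user', content: 'What is the weather in Tokyo?' }],
    { tools, temperature: 0.3 },
  )) {
    if (chunk.content) process.stdout.write(chunk.content)
    if (chunk.isFinished && chunk.finishReason === 'tool_calls') {
      // The caller is expected to run these tools and send the results back as role: 'tool' messages
      console.log(chunk.tool_calls?.map((tc) => tc.function.name))
    }
  }
}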
+4
@@ -21,6 +21,7 @@ import type {
import { MAX_CONFIG_COUNT } from './types'
import { DeepSeekService, DEEPSEEK_INFO } from './deepseek'
import { QwenService, QWEN_INFO } from './qwen'
import { GeminiService, GEMINI_INFO } from './gemini'
import { OpenAICompatibleService, OPENAI_COMPATIBLE_INFO } from './openai-compatible'
import { aiLogger } from '../logger'
@@ -72,6 +73,7 @@ const KIMI_INFO: ProviderInfo = {
export const PROVIDERS: ProviderInfo[] = [
DEEPSEEK_INFO,
QWEN_INFO,
GEMINI_INFO,
MINIMAX_INFO,
GLM_INFO,
KIMI_INFO,
@@ -325,6 +327,8 @@ export function createLLMService(config: ExtendedLLMConfig): ILLMService {
return new DeepSeekService(config.apiKey, config.model, config.baseUrl)
case 'qwen':
return new QwenService(config.apiKey, config.model, config.baseUrl)
case 'gemini':
return new GeminiService(config.apiKey, config.model, config.baseUrl)
// The newly added official APIs all use the OpenAI-compatible format
case 'minimax':
case 'glm':
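With the provider registered, the factory path might be exercised roughly as follows. The import path and the provider/apiKey/model field names are assumptions inferred from the switch and constructor calls above; the key is a placeholder.

import { createLLMService } from './index' // assumed path of the module exporting createLLMService

// Field names follow how createLLMService appears to read its config above
const gemini = createLLMService({
  provider: 'gemini',
  apiKey: 'YOUR_API_KEY', // placeholder
  model: 'gemini-3-pro-preview',
})

async function smokeTest(): Promise<void> {
  const reply = await gemini.chat([{ role: 'user', content: 'Hello' }])
  console.log(reply.content, reply.finishReason)
}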
+3 -1
@@ -5,7 +5,7 @@
/**
* 支持的 LLM 提供商
*/
export type LLMProvider = 'deepseek' | 'qwen' | 'minimax' | 'glm' | 'kimi' | 'openai-compatible'
export type LLMProvider = 'deepseek' | 'qwen' | 'minimax' | 'glm' | 'kimi' | 'gemini' | 'openai-compatible'
/**
* LLM 配置
@@ -105,6 +105,8 @@ export interface ToolCall {
name: string
arguments: string // JSON 字符串
}
/** Thought signature required by Gemini 3+ models (used to validate tool calls) */
thoughtSignature?: string
}
/**
+3
@@ -42,6 +42,9 @@ export function getAvailableYears(sessionId: string): number[] {
* 获取成员活跃度排行
*/
export function getMemberActivity(sessionId: string, filter?: TimeFilter): any[] {
// Make sure the database has the avatar column first (for compatibility with older databases)
ensureAvatarColumn(sessionId)
const db = openDatabase(sessionId)
if (!db) return []