| | const { google } = require('googleapis'); |
| | const { sleep } = require('@librechat/agents'); |
| | const { logger } = require('@librechat/data-schemas'); |
| | const { getModelMaxTokens } = require('@librechat/api'); |
| | const { concat } = require('@langchain/core/utils/stream'); |
| | const { ChatVertexAI } = require('@langchain/google-vertexai'); |
| | const { Tokenizer, getSafetySettings } = require('@librechat/api'); |
| | const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); |
| | const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai'); |
| | const { HumanMessage, SystemMessage } = require('@langchain/core/messages'); |
| | const { |
| | googleGenConfigSchema, |
| | validateVisionModel, |
| | getResponseSender, |
| | endpointSettings, |
| | parseTextParts, |
| | EModelEndpoint, |
| | googleSettings, |
| | ContentTypes, |
| | VisionModes, |
| | ErrorTypes, |
| | Constants, |
| | AuthKeys, |
| | } = require('librechat-data-provider'); |
| | const { encodeAndFormat } = require('~/server/services/Files/images'); |
| | const { spendTokens } = require('~/models/spendTokens'); |
| | const { |
| | formatMessage, |
| | createContextHandlers, |
| | titleInstruction, |
| | truncateText, |
| | } = require('./prompts'); |
| | const BaseClient = require('./BaseClient'); |
| |
|
// Vertex AI region; `global` uses the region-less API host.
const loc = process.env.GOOGLE_LOC || 'us-central1';
const publisher = 'google';
// Regional hosts are prefixed with the location, e.g. `us-central1-aiplatform.googleapis.com`.
const endpointPrefix =
  loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`;

// Default parameter settings (e.g. maxOutputTokens) for the Google endpoint.
const settings = endpointSettings[EModelEndpoint.google];
// Models that must NOT use the @google/generative-ai (GenAI) client path.
const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/;
| |
|
| | class GoogleClient extends BaseClient { |
  /**
   * @param {string|object} credentials - Either a JSON string or an object that may contain
   *   a Google service-account key (`AuthKeys.GOOGLE_SERVICE_KEY`) and/or an API key
   *   (`AuthKeys.GOOGLE_API_KEY`). NOTE: a malformed JSON string will throw from `JSON.parse`.
   * @param {object} [options] - Client options forwarded to `BaseClient` and `setOptions`.
   */
  constructor(credentials, options = {}) {
    super('apiKey', options);
    let creds = {};

    if (typeof credentials === 'string') {
      creds = JSON.parse(credentials);
    } else if (credentials) {
      creds = credentials;
    }

    // The service key may itself arrive JSON-encoded; normalize to an object.
    const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {};
    this.serviceKey =
      serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : (serviceKey ?? {});
    // A present `project_id` is what routes requests through Vertex AI elsewhere in this class.
    this.project_id = this.serviceKey.project_id;
    this.client_email = this.serviceKey.client_email;
    this.private_key = this.serviceKey.private_key;
    this.access_token = null;

    this.apiKey = creds[AuthKeys.GOOGLE_API_KEY];

    this.reverseProxyUrl = options.reverseProxyUrl;

    this.authHeader = options.authHeader;

    // Populated by `getCompletion` from stream usage metadata; read via `getStreamUsage`.
    this.usage;
    // Keys used when reading token counts out of `this.usage`.
    this.inputTokensKey = 'input_tokens';
    this.outputTokensKey = 'output_tokens';
    this.visionMode = VisionModes.generative;
    // Set by `setOptions` from promptPrefix/artifactsPrompt.
    this.systemMessage;
    if (options.skipSetOptions) {
      return;
    }
    this.setOptions(options);
  }
| |
|
| | |
| | constructUrl() { |
| | return `https://${endpointPrefix}/v1/projects/${this.project_id}/locations/${loc}/publishers/${publisher}/models/${this.modelOptions.model}:serverStreamingPredict`; |
| | } |
| |
|
| | async getClient() { |
| | const scopes = ['https://www.googleapis.com/auth/cloud-platform']; |
| | const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes); |
| |
|
| | jwtClient.authorize((err) => { |
| | if (err) { |
| | logger.error('jwtClient failed to authorize', err); |
| | throw err; |
| | } |
| | }); |
| |
|
| | return jwtClient; |
| | } |
| |
|
| | async getAccessToken() { |
| | const scopes = ['https://www.googleapis.com/auth/cloud-platform']; |
| | const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes); |
| |
|
| | return new Promise((resolve, reject) => { |
| | jwtClient.authorize((err, tokens) => { |
| | if (err) { |
| | logger.error('jwtClient failed to authorize', err); |
| | reject(err); |
| | } else { |
| | resolve(tokens.access_token); |
| | } |
| | }); |
| | }); |
| | } |
| |
|
| | |
  /**
   * Merges/replaces client options, derives token budgets and labels, builds the
   * system message, and (re)initializes the underlying LLM client.
   * Mutates `this.modelOptions` (moves `thinking`/`thinkingBudget` into
   * `thinkingConfig`) and may mutate the passed `options` (deletes `modelOptions`).
   * @param {object} options
   * @returns {this}
   * @throws When maxPromptTokens + maxResponseTokens exceed maxContextTokens.
   */
  setOptions(options) {
    if (this.options && !this.options.replaceOptions) {
      // Merge new modelOptions over existing ones, then shallow-merge the rest.
      this.options.modelOptions = {
        ...this.options.modelOptions,
        ...options.modelOptions,
      };
      delete options.modelOptions;
      this.options = {
        ...this.options,
        ...options,
      };
    } else {
      this.options = options;
    }

    this.modelOptions = this.options.modelOptions || {};

    // Fire-and-forget: vision-model switching happens whenever attachments resolve.
    this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));

    this.isGenerativeModel = /gemini|learnlm|gemma/.test(this.modelOptions.model);

    this.maxContextTokens =
      this.options.maxContextTokens ??
      getModelMaxTokens(this.modelOptions.model, EModelEndpoint.google);

    this.maxResponseTokens = this.modelOptions.maxOutputTokens || settings.maxOutputTokens.default;

    // For large-context models, reserve room for the response up front.
    if (this.maxContextTokens > 32000) {
      this.maxContextTokens = this.maxContextTokens - this.maxResponseTokens;
    }

    this.maxPromptTokens =
      this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;

    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
      throw new Error(
        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
          this.maxPromptTokens + this.maxResponseTokens
        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
      );
    }

    // Fold the flat `thinking`/`thinkingBudget` options into the provider's
    // `thinkingConfig` shape; a disabled `thinking` flag zeroes the budget.
    this.modelOptions.thinkingConfig = {
      thinkingBudget:
        (this.modelOptions.thinking ?? googleSettings.thinking.default)
          ? this.modelOptions.thinkingBudget
          : 0,
    };
    delete this.modelOptions.thinking;
    delete this.modelOptions.thinkingBudget;

    this.sender =
      this.options.sender ??
      getResponseSender({
        model: this.modelOptions.model,
        endpoint: EModelEndpoint.google,
        modelLabel: this.options.modelLabel,
      });

    this.userLabel = this.options.userLabel || 'User';
    this.modelLabel = this.options.modelLabel || 'Assistant';

    if (this.options.reverseProxyUrl) {
      this.completionsUrl = this.options.reverseProxyUrl;
    } else {
      this.completionsUrl = this.constructUrl();
    }

    let promptPrefix = (this.options.promptPrefix ?? '').trim();
    if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
      promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim();
    }
    this.systemMessage = promptPrefix;
    this.initializeClient();
    return this;
  }
| |
|
| | |
| | |
| | |
| | |
| | |
  /**
   * Switches the active model to a vision-capable one when image attachments
   * are present, and reverts gemini-pro-vision to gemini-pro when no
   * attachments remain (vision variant is not needed for text-only turns).
   * @param {Array|undefined} attachments - Resolved attachment file objects.
   */
  checkVisionRequest(attachments) {
    // Non-excluded GenAI models can already handle vision; otherwise fall back
    // to the dedicated gemini-pro-vision model.
    this.defaultVisionModel =
      this.options.visionModel ??
      (!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)
        ? this.modelOptions.model
        : 'gemini-pro-vision');
    const availableModels = this.options.modelsConfig?.[EModelEndpoint.google];
    this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });

    // Only switch when there is at least one image attachment and the vision
    // model is actually available for this endpoint.
    if (
      attachments &&
      attachments.some((file) => file?.type && file?.type?.includes('image')) &&
      availableModels?.includes(this.defaultVisionModel) &&
      !this.isVisionModel
    ) {
      this.modelOptions.model = this.defaultVisionModel;
      this.isVisionModel = true;
    }

    if (this.isVisionModel && !attachments && this.modelOptions.model.includes('gemini-pro')) {
      this.modelOptions.model = 'gemini-pro';
      this.isVisionModel = false;
    }
  }
| |
|
| | formatMessages() { |
| | return ((message) => { |
| | const msg = { |
| | author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel), |
| | content: message?.content ?? message.text, |
| | }; |
| |
|
| | if (!message.image_urls?.length) { |
| | return msg; |
| | } |
| |
|
| | msg.content = ( |
| | !Array.isArray(msg.content) |
| | ? [ |
| | { |
| | type: ContentTypes.TEXT, |
| | [ContentTypes.TEXT]: msg.content, |
| | }, |
| | ] |
| | : msg.content |
| | ).concat(message.image_urls); |
| |
|
| | return msg; |
| | }).bind(this); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
  /**
   * Formats messages for the GenAI (`@google/generative-ai`) content API:
   * `{ role, parts: [{ text }, { inlineData }...] }`.
   * Side effects: attaches image URLs to the latest message, and replaces
   * `this.options.attachments` with the resolved file list.
   * @param {Array<object>} messages - Ordered conversation messages.
   * @returns {Promise<Array<{ role: string, parts: Array }>>}
   */
  async formatGenerativeMessages(messages) {
    const formattedMessages = [];
    const attachments = await this.options.attachments;
    // Only the latest message receives image data.
    const latestMessage = { ...messages[messages.length - 1] };
    const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
    this.options.attachments = files;
    messages[messages.length - 1] = latestMessage;

    for (const _message of messages) {
      // userLabel/modelLabel are set to 'user'/'model' by buildGenerativeMessages.
      const role = _message.isCreatedByUser ? this.userLabel : this.modelLabel;
      const parts = [];
      parts.push({ text: _message.text });
      if (!_message.image_urls?.length) {
        formattedMessages.push({ role, parts });
        continue;
      }

      for (const images of _message.image_urls) {
        if (images.inlineData) {
          parts.push({ inlineData: images.inlineData });
        }
      }

      formattedMessages.push({ role, parts });
    }

    return formattedMessages;
  }
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
  /**
   * Encodes attachments for the Google endpoint and attaches the resulting
   * image URLs to `message.image_urls` (left undefined when there are none).
   * @param {object} message - Message to mutate with `image_urls`.
   * @param {Array} attachments - Raw attachment files to encode.
   * @param {string} [mode=''] - Vision mode passed through to `encodeAndFormat`.
   * @returns {Promise<Array>} The processed file objects.
   */
  async addImageURLs(message, attachments, mode = '') {
    const { files, image_urls } = await encodeAndFormat(
      this.options.req,
      attachments,
      {
        endpoint: EModelEndpoint.google,
      },
      mode,
    );
    message.image_urls = image_urls.length ? image_urls : undefined;
    return files;
  }
| |
|
| | |
| | |
| | |
| | |
| | |
  /**
   * Builds RAG context from embedded file attachments and prepends it to the
   * system message. No-op when `createContextHandlers` returns a falsy handler.
   * @param {Array<object>} [messages] - Conversation messages; only the latest
   *   message's text seeds the context handlers.
   */
  async buildAugmentedPrompt(messages = []) {
    const attachments = await this.options.attachments;
    const latestMessage = { ...messages[messages.length - 1] };
    this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);

    if (this.contextHandlers) {
      for (const file of attachments) {
        if (file.embedded) {
          this.contextHandlers?.processFile(file);
          continue;
        }
        // NOTE(review): files with a `fileIdentifier` are skipped, and files in
        // neither category fall through without any processing — confirm this
        // loop is not missing a handling branch for plain attachments.
        if (file.metadata?.fileIdentifier) {
          continue;
        }
      }

      this.augmentedPrompt = await this.contextHandlers.createContext();
      this.systemMessage = this.augmentedPrompt + this.systemMessage;
    }
  }
| |
|
  /**
   * Builds a single-instance vision payload: augments the prompt with file
   * context, collapses the conversation into one prompt string, and attaches
   * encoded images to the latest message.
   * Side effect: replaces `this.options.attachments` with the processed files.
   * @param {Array<object>} [messages]
   * @param {string} parentMessageId
   * @returns {Promise<{ prompt: object }>} Payload with a single HumanMessage instance.
   */
  async buildVisionMessages(messages = [], parentMessageId) {
    const attachments = await this.options.attachments;
    const latestMessage = { ...messages[messages.length - 1] };
    await this.buildAugmentedPrompt(messages);

    // Collapse the whole conversation into a single prompt string.
    const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);

    const files = await this.addImageURLs(latestMessage, attachments);

    this.options.attachments = files;

    latestMessage.text = prompt;

    const payload = {
      instances: [
        {
          messages: [new HumanMessage(formatMessage({ message: latestMessage }))],
        },
      ],
    };
    return { prompt: payload };
  }
| |
|
| | |
| | async buildGenerativeMessages(messages = []) { |
| | this.userLabel = 'user'; |
| | this.modelLabel = 'model'; |
| | const promises = []; |
| | promises.push(await this.formatGenerativeMessages(messages)); |
| | promises.push(this.buildAugmentedPrompt(messages)); |
| | const [formattedMessages] = await Promise.all(promises); |
| | return { prompt: formattedMessages }; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | async buildMessages(_messages = [], parentMessageId) { |
| | if (!this.isGenerativeModel && !this.project_id) { |
| | throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.'); |
| | } |
| |
|
| | if (this.systemMessage) { |
| | const instructionsTokenCount = this.getTokenCount(this.systemMessage); |
| |
|
| | this.maxContextTokens = this.maxContextTokens - instructionsTokenCount; |
| | if (this.maxContextTokens < 0) { |
| | const info = `${instructionsTokenCount} / ${this.maxContextTokens}`; |
| | const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`; |
| | logger.warn(`Instructions token count exceeds max context (${info}).`); |
| | throw new Error(errorMessage); |
| | } |
| | } |
| |
|
| | for (let i = 0; i < _messages.length; i++) { |
| | const message = _messages[i]; |
| | if (!message.tokenCount) { |
| | _messages[i].tokenCount = this.getTokenCountForMessage({ |
| | role: message.isCreatedByUser ? 'user' : 'assistant', |
| | content: message.content ?? message.text, |
| | }); |
| | } |
| | } |
| |
|
| | const { |
| | payload: messages, |
| | tokenCountMap, |
| | promptTokens, |
| | } = await this.handleContextStrategy({ |
| | orderedMessages: _messages, |
| | formattedMessages: _messages, |
| | }); |
| |
|
| | if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) { |
| | const result = await this.buildGenerativeMessages(messages); |
| | result.tokenCountMap = tokenCountMap; |
| | result.promptTokens = promptTokens; |
| | return result; |
| | } |
| |
|
| | if (this.options.attachments && this.isGenerativeModel) { |
| | const result = this.buildVisionMessages(messages, parentMessageId); |
| | result.tokenCountMap = tokenCountMap; |
| | result.promptTokens = promptTokens; |
| | return result; |
| | } |
| |
|
| | let payload = { |
| | instances: [ |
| | { |
| | messages: messages |
| | .map(this.formatMessages()) |
| | .map((msg) => ({ ...msg, role: msg.author === 'User' ? 'user' : 'assistant' })) |
| | .map((message) => formatMessage({ message, langChain: true })), |
| | }, |
| | ], |
| | }; |
| |
|
| | if (this.systemMessage) { |
| | payload.instances[0].context = this.systemMessage; |
| | } |
| |
|
| | logger.debug('[GoogleClient] buildMessages', payload); |
| | return { prompt: payload, tokenCountMap, promptTokens }; |
| | } |
| |
|
  /**
   * Flattens the conversation thread into a single labeled prompt string,
   * fitting as many recent messages as possible within `maxPromptTokens`
   * (messages are added newest-first and older ones dropped when over budget).
   * Side effect: shrinks `this.modelOptions.maxOutputTokens` to fit the
   * remaining context.
   * @param {Array<object>} messages - All conversation messages.
   * @param {string} parentMessageId - Tip of the thread to reconstruct.
   * @returns {Promise<{ prompt: string, context: Array }>} The prompt string and
   *   the messages actually included (oldest first).
   * @throws When even a single message exceeds the token budget.
   */
  async buildMessagesPrompt(messages, parentMessageId) {
    const orderedMessages = this.constructor.getMessagesForConversation({
      messages,
      parentMessageId,
    });

    logger.debug('[GoogleClient]', {
      orderedMessages,
      parentMessageId,
    });

    const formattedMessages = orderedMessages.map(this.formatMessages());

    // Merge consecutive messages by the same author into one grouped entry.
    let lastAuthor = '';
    let groupedMessages = [];

    for (let message of formattedMessages) {
      if (lastAuthor !== message.author) {
        groupedMessages.push({
          author: message.author,
          content: [message.content],
        });
        lastAuthor = message.author;
      } else {
        groupedMessages[groupedMessages.length - 1].content.push(message.content);
      }
    }

    // Optional identity lines prepended to the system prompt.
    let identityPrefix = '';
    if (this.options.userLabel) {
      identityPrefix = `\nHuman's name: ${this.options.userLabel}`;
    }

    if (this.options.modelLabel) {
      identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`;
    }

    let promptPrefix = (this.systemMessage ?? '').trim();

    if (identityPrefix) {
      promptPrefix = `${identityPrefix}${promptPrefix}`;
    }

    // "Edited" means the thread already ends with a model message; in that case
    // the prefix is attached to that last model message instead of a suffix.
    let isEdited = lastAuthor === this.modelLabel;
    const promptSuffix = isEdited ? '' : `${promptPrefix}\n\n${this.modelLabel}:\n`;
    let currentTokenCount = isEdited
      ? this.getTokenCount(promptPrefix)
      : this.getTokenCount(promptSuffix);

    let promptBody = '';
    const maxTokenCount = this.maxPromptTokens;

    const context = [];

    // Tracks the most recent model message added, so it can be removed if the
    // budget is exceeded right after a user message (avoids a dangling reply).
    const nextMessage = {
      remove: false,
      tokenCount: 0,
      messageString: '',
    };

    // Recursively prepend messages (newest first) until the budget is hit.
    const buildPromptBody = async () => {
      if (currentTokenCount < maxTokenCount && groupedMessages.length > 0) {
        const message = groupedMessages.pop();
        const isCreatedByUser = message.author === this.userLabel;
        // For the edited case, the prompt prefix rides on the last model message.
        const messagePrefix =
          isCreatedByUser || !isEdited
            ? `\n\n${message.author}:`
            : `${promptPrefix}\n\n${message.author}:`;
        const messageString = `${messagePrefix}\n${message.content}\n`;
        let newPromptBody = `${messageString}${promptBody}`;

        context.unshift(message);

        const tokenCountForMessage = this.getTokenCount(messageString);
        const newTokenCount = currentTokenCount + tokenCountForMessage;

        if (!isCreatedByUser) {
          nextMessage.messageString = messageString;
          nextMessage.tokenCount = tokenCountForMessage;
        }

        if (newTokenCount > maxTokenCount) {
          if (!promptBody) {
            // Not even one message fits.
            throw new Error(
              `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
            );
          }

          // Stop here; if we stopped on a user message, also remove the model
          // message that was prepended after it.
          if (isCreatedByUser) {
            nextMessage.remove = true;
          }

          return false;
        }
        promptBody = newPromptBody;
        currentTokenCount = newTokenCount;

        // The edited prefix only applies to the very last model message.
        if (isEdited) {
          isEdited = false;
        }

        // Yield to the event loop between iterations to avoid blocking.
        await new Promise((resolve) => setImmediate(resolve));
        return buildPromptBody();
      }
      return true;
    };

    await buildPromptBody();

    if (nextMessage.remove) {
      promptBody = promptBody.replace(nextMessage.messageString, '');
      currentTokenCount -= nextMessage.tokenCount;
      context.shift();
    }

    let prompt = `${promptBody}${promptSuffix}`.trim();

    // Padding for the trailing label tokens.
    currentTokenCount += 2;

    // Shrink the response budget to whatever context remains.
    this.modelOptions.maxOutputTokens = Math.min(
      this.maxContextTokens - currentTokenCount,
      this.maxResponseTokens,
    );

    return { prompt, context };
  }
| |
|
  /**
   * Instantiates the underlying LLM client based on available credentials:
   * - Vertex AI (`ChatVertexAI`) when a service-account `project_id` exists;
   * - GenAI (`@google/generative-ai`) for non-excluded models with an API key;
   * - `ChatGoogleGenerativeAI` as the fallback for excluded models.
   * @param {object} clientOptions - Model options (mutated: location/endpoint added).
   * @returns {object} The instantiated client.
   */
  createLLM(clientOptions) {
    const model = clientOptions.modelName ?? clientOptions.model;
    clientOptions.location = loc;
    clientOptions.endpoint = endpointPrefix;

    // Reverse-proxy routing only applies to the GenAI client path below.
    let requestOptions = null;
    if (this.reverseProxyUrl) {
      requestOptions = {
        baseUrl: this.reverseProxyUrl,
      };

      if (this.authHeader) {
        requestOptions.customHeaders = {
          Authorization: `Bearer ${this.apiKey}`,
        };
      }
    }

    if (this.project_id != null) {
      logger.debug('Creating VertexAI client');
      this.visionMode = undefined;
      clientOptions.streaming = true;
      const client = new ChatVertexAI(clientOptions);
      // Explicitly copy sampling params onto the instance.
      client.temperature = clientOptions.temperature;
      client.topP = clientOptions.topP;
      client.topK = clientOptions.topK;
      client.topLogprobs = clientOptions.topLogprobs;
      client.frequencyPenalty = clientOptions.frequencyPenalty;
      client.presencePenalty = clientOptions.presencePenalty;
      client.maxOutputTokens = clientOptions.maxOutputTokens;
      return client;
    } else if (!EXCLUDED_GENAI_MODELS.test(model)) {
      logger.debug('Creating GenAI client');
      return new GenAI(this.apiKey).getGenerativeModel({ model }, requestOptions);
    }

    logger.debug('Creating Chat Google Generative AI client');
    return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
  }
| |
|
| | initializeClient() { |
| | let clientOptions = { ...this.modelOptions }; |
| |
|
| | if (this.project_id) { |
| | clientOptions['authOptions'] = { |
| | credentials: { |
| | ...this.serviceKey, |
| | }, |
| | projectId: this.project_id, |
| | }; |
| | } |
| |
|
| | if (this.isGenerativeModel && !this.project_id) { |
| | clientOptions.modelName = clientOptions.model; |
| | delete clientOptions.model; |
| | } |
| |
|
| | this.client = this.createLLM(clientOptions); |
| | return this.client; |
| | } |
| |
|
  /**
   * Streams a completion from either the GenAI client (API-key path) or a
   * LangChain chat client (Vertex/fallback path), forwarding chunks via
   * `onProgress` and accumulating usage metadata into `this.usage`.
   * @param {object|Array} _payload - GenAI `contents` array or Vertex `instances` payload.
   * @param {{ onProgress?: Function, abortController: AbortController }} [options]
   * @returns {Promise<string>} The full generated reply text.
   * @throws A JSON-encoded GoogleError when generation fails before any text arrives.
   */
  async getCompletion(_payload, options = {}) {
    const { onProgress, abortController } = options;
    const safetySettings = getSafetySettings(this.modelOptions.model);
    const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
    const modelName = this.modelOptions.modelName ?? this.modelOptions.model ?? '';

    let reply = '';
    let error;
    try {
      // GenAI client path: API key without a Vertex project.
      if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) {
        const client = this.client;
        const requestOptions = {
          safetySettings,
          contents: _payload,
          generationConfig: googleGenConfigSchema.parse(this.modelOptions),
        };

        const promptPrefix = (this.systemMessage ?? '').trim();
        if (promptPrefix.length) {
          requestOptions.systemInstruction = {
            parts: [
              {
                text: promptPrefix,
              },
            ],
          };
        }

        // Per-chunk typing delay; flash models stream faster.
        const delay = modelName.includes('flash') ? 8 : 15;
        let usageMetadata;

        abortController.signal.addEventListener(
          'abort',
          () => {
            logger.warn('[GoogleClient] Request was aborted', abortController.signal.reason);
          },
          { once: true },
        );

        const result = await client.generateContentStream(requestOptions, {
          signal: abortController.signal,
        });
        for await (const chunk of result.stream) {
          // Later chunks overwrite earlier usage fields (cumulative metadata).
          usageMetadata = !usageMetadata
            ? chunk?.usageMetadata
            : Object.assign(usageMetadata, chunk?.usageMetadata);
          const chunkText = chunk.text();
          await this.generateTextStream(chunkText, onProgress, {
            delay,
          });
          reply += chunkText;
          await sleep(streamRate);
        }

        if (usageMetadata) {
          this.usage = {
            input_tokens: usageMetadata.promptTokenCount,
            output_tokens: usageMetadata.candidatesTokenCount,
          };
        }

        return reply;
      }

      // LangChain chat-client path (Vertex AI or excluded GenAI models).
      const { instances } = _payload;
      const { messages: messages, context } = instances?.[0] ?? {};

      // Prepend system context for non-vision payloads.
      if (!this.isVisionModel && context && messages?.length > 0) {
        messages.unshift(new SystemMessage(context));
      }

      let usageMetadata;
      const client = this.client;
      const stream = await client.stream(messages, {
        signal: abortController.signal,
        streamUsage: true,
        safetySettings,
      });

      let delay = this.options.streamRate || 8;

      if (!this.options.streamRate) {
        if (this.isGenerativeModel) {
          delay = 15;
        }
        if (modelName.includes('flash')) {
          delay = 5;
        }
      }

      for await (const chunk of stream) {
        if (chunk?.usage_metadata) {
          const metadata = chunk.usage_metadata;
          // Drop NaN fields so `concat` aggregation stays numeric.
          for (const key in metadata) {
            if (Number.isNaN(metadata[key])) {
              delete metadata[key];
            }
          }

          usageMetadata = !usageMetadata ? metadata : concat(usageMetadata, metadata);
        }

        const chunkText = chunk?.content ?? '';
        await this.generateTextStream(chunkText, onProgress, {
          delay,
        });
        reply += chunkText;
      }

      if (usageMetadata) {
        this.usage = usageMetadata;
      }
    } catch (e) {
      error = e;
      logger.error('[GoogleClient] There was an issue generating the completion', e);
    }

    // Only surface the error if nothing was generated; partial replies are returned.
    if (error != null && reply === '') {
      const errorMessage = `{ "type": "${ErrorTypes.GoogleError}", "info": "${
        error.message ?? 'The Google provider failed to generate content, please contact the Admin.'
      }" }`;
      throw new Error(errorMessage);
    }
    return reply;
  }
| |
|
| | |
| | |
| | |
| | |
| | getStreamUsage() { |
| | return this.usage; |
| | } |
| |
|
| | getMessageMapMethod() { |
| | |
| | |
| | |
| | return (msg) => { |
| | if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) { |
| | msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim(); |
| | } else if (msg.content != null) { |
| | msg.text = parseTextParts(msg.content, true); |
| | delete msg.content; |
| | } |
| |
|
| | return msg; |
| | }; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) { |
| | const originalEstimate = tokenCountMap[currentMessageId] || 0; |
| |
|
| | if (!usage || typeof usage.input_tokens !== 'number') { |
| | return originalEstimate; |
| | } |
| |
|
| | tokenCountMap[currentMessageId] = 0; |
| | const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => { |
| | const numCount = Number(count); |
| | return sum + (isNaN(numCount) ? 0 : numCount); |
| | }, 0); |
| | const totalInputTokens = usage.input_tokens ?? 0; |
| | const currentMessageTokens = totalInputTokens - totalTokensFromMap; |
| | return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
  /**
   * Records prompt/completion token spend against the current user and conversation.
   * @param {{ promptTokens: number, completionTokens: number, model?: string, context?: string }} params
   *   `context` tags the spend type (e.g. 'message', 'title'); `model` defaults
   *   to the active model.
   */
  async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) {
    await spendTokens(
      {
        context,
        user: this.user ?? this.options.req?.user?.id,
        conversationId: this.conversationId,
        model: model ?? this.modelOptions.model,
        endpointTokenConfig: this.options.endpointTokenConfig,
      },
      { promptTokens, completionTokens },
    );
  }
| |
|
| | |
| | |
| | |
  /**
   * Generates a conversation title with a non-streaming completion, using the
   * GenAI client for API-key models or the LangChain client otherwise (the
   * latter also records token spend under the 'title' context).
   * @param {object|Array} _payload - Payload produced by `buildMessages`.
   * @param {{ abortController?: AbortController }} [options]
   * @returns {Promise<string>} The generated title text.
   */
  async titleChatCompletion(_payload, options = {}) {
    let reply = '';
    const { abortController } = options;

    const model =
      this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? '';
    const safetySettings = getSafetySettings(model);
    if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) {
      logger.debug('Identified titling model as GenAI version');
      const client = this.client;
      const requestOptions = {
        contents: _payload,
        safetySettings,
        generationConfig: {
          // Lower temperature for a more deterministic title.
          temperature: 0.5,
        },
      };

      const result = await client.generateContent(requestOptions);
      reply = result.response?.text();
      return reply;
    } else {
      const { instances } = _payload;
      const { messages } = instances?.[0] ?? {};
      const titleResponse = await this.client.invoke(messages, {
        signal: abortController.signal,
        timeout: 7000,
        safetySettings,
      });

      if (titleResponse.usage_metadata) {
        await this.recordTokenUsage({
          model,
          promptTokens: titleResponse.usage_metadata.input_tokens,
          completionTokens: titleResponse.usage_metadata.output_tokens,
          context: 'title',
        });
      }

      reply = titleResponse.content;
      return reply;
    }
  }
| |
|
  /**
   * Produces a title for the conversation from the first exchange.
   * Errors are logged and swallowed, returning the default title.
   * @param {{ text: string, responseText?: string }} params
   * @returns {Promise<string>} The generated title, or 'New Chat' on failure.
   */
  async titleConvo({ text, responseText = '' }) {
    let title = 'New Chat';
    const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;

    let { prompt: payload } = await this.buildMessages([
      {
        text: `Please generate ${titleInstruction}

${convo}

||>Title:`,
        isCreatedByUser: true,
        author: this.userLabel,
      },
    ]);

    try {
      // Rebuild the client in case buildMessages changed labels/options.
      this.initializeClient();
      title = await this.titleChatCompletion(payload, {
        abortController: new AbortController(),
        onProgress: () => {},
      });
    } catch (e) {
      logger.error('[GoogleClient] There was an issue generating the title', e);
    }
    logger.debug(`Title response: ${title}`);
    return title;
  }
| |
|
| | getSaveOptions() { |
| | return { |
| | endpointType: null, |
| | artifacts: this.options.artifacts, |
| | promptPrefix: this.options.promptPrefix, |
| | maxContextTokens: this.options.maxContextTokens, |
| | modelLabel: this.options.modelLabel, |
| | iconURL: this.options.iconURL, |
| | greeting: this.options.greeting, |
| | spec: this.options.spec, |
| | ...this.modelOptions, |
| | }; |
| | } |
| |
|
  /**
   * Intentional no-op: GoogleClient needs no extra options for buildMessages,
   * but BaseClient expects this hook to exist.
   */
  getBuildMessagesOptions() {
    // no-op
  }
| |
|
| | async sendCompletion(payload, opts = {}) { |
| | let reply = ''; |
| | reply = await this.getCompletion(payload, opts); |
| | return reply.trim(); |
| | } |
| |
|
  /**
   * Tokenizer encoding used for local token counting.
   * NOTE(review): cl100k_base is an OpenAI tokenizer — counts are an
   * approximation for Google models, not exact.
   * @returns {string}
   */
  getEncoding() {
    return 'cl100k_base';
  }
| |
|
| | async getVertexTokenCount(text) { |
| | |
| | const client = this.client ?? this.initializeClient(); |
| | const connection = client.connection; |
| | const gAuthClient = connection.client; |
| | const tokenEndpoint = `https://${connection._endpoint}/${connection.apiVersion}/projects/${this.project_id}/locations/${connection._location}/publishers/google/models/${connection.model}/:countTokens`; |
| | const result = await gAuthClient.request({ |
| | url: tokenEndpoint, |
| | method: 'POST', |
| | data: { |
| | contents: [{ role: 'user', parts: [{ text }] }], |
| | }, |
| | }); |
| | return result; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | getTokenCount(text) { |
| | const encoding = this.getEncoding(); |
| | return Tokenizer.getTokenCount(text, encoding); |
| | } |
| | } |
| |
|
// CommonJS export of the client class.
module.exports = GoogleClient;
| |
|