| | const { sleep } = require('@librechat/agents'); |
| | const { sendEvent } = require('@librechat/api'); |
| | const { logger } = require('@librechat/data-schemas'); |
| | const { |
| | Constants, |
| | StepTypes, |
| | ContentTypes, |
| | ToolCallTypes, |
| | MessageContentTypes, |
| | AssistantStreamEvents, |
| | } = require('librechat-data-provider'); |
| | const { retrieveAndProcessFile } = require('~/server/services/Files/process'); |
| | const { processRequiredActions } = require('~/server/services/ToolService'); |
| | const { processMessages } = require('~/server/services/Threads'); |
| | const { createOnProgress } = require('~/server/utils'); |
| |
|
| | |
| | |
| | |
| | |
| | |
/**
 * Drives a streamed OpenAI Assistants run: consumes stream events, tracks run steps,
 * and forwards text / tool-call / image content to the client via `sendEvent`
 * and per-message progress callbacks.
 */
class StreamRunManager {
  /**
   * @param {Object} fields
   * @param {Object} fields.req - Request; `req.body.assistant_id`, `req.body.endpoint` and `req.user.id` are read.
   * @param {Object} fields.res - Response handed to `sendEvent` and to text progress callbacks.
   * @param {Object} fields.openai - OpenAI client; `apiKey` and `beta.threads.runs.*` are used.
   * @param {string} fields.parentMessageId - Parent message id reported by `getIntermediateMessage`.
   * @param {string} fields.thread_id - Thread the run belongs to.
   * @param {Object} [fields.runBody] - Stored as `initialRunBody`; not read within this class.
   * @param {Object} [fields.handlers] - Optional client handlers keyed by event name; run before the built-in handler.
   * @param {Object} [fields.streamOptions] - Request options forwarded to the streaming API calls.
   * @param {Object} [fields.responseMessage] - Message whose `content` array is filled as content arrives.
   * @param {*} [fields.attachedFileIds] - Stored on the instance; not read within this class.
   * @param {*} [fields.visionPromise] - Stored on the instance; not read within this class.
   * @param {number} [fields.streamRate] - Delay passed to `sleep` after each streamed chunk;
   *   defaults to `Constants.DEFAULT_STREAM_RATE`.
   */
  constructor(fields) {
    /** Next content index handed out by `getStepIndex`. */
    this.index = 0;
    /** Latest payload of every run step, keyed by step id. */
    this.steps = new Map();
    /** Maps tool-call keys, tool-call ids, message ids and file ids to content indices. */
    this.mappedOrder = new Map();
    /** Run-step / tool-call objects keyed by content index. */
    this.orderedRunSteps = new Map();
    /** File ids of code-interpreter images that were already forwarded. */
    this.processedFileIds = new Set();
    /** Streaming callbacks keyed by tool-call key or message id. */
    this.progressCallbacks = new Map();
    /** Most recent run object received from a run event. */
    this.run = null;

    this.req = fields.req;
    this.res = fields.res;
    this.openai = fields.openai;
    this.apiKey = this.openai.apiKey;
    this.parentMessageId = fields.parentMessageId;
    this.thread_id = fields.thread_id;
    this.initialRunBody = fields.runBody;
    this.clientHandlers = fields.handlers ?? {};
    this.streamOptions = fields.streamOptions ?? {};
    this.finalMessage = fields.responseMessage ?? {};
    /** Completed thread messages, appended by `messageCompleted`. */
    this.messages = [];
    /** Accumulated text of completed, unedited text parts. */
    this.text = '';
    /** Text accumulated from streamed message deltas. */
    this.intermediateText = '';
    this.attachedFileIds = fields.attachedFileIds;
    this.visionPromise = fields.visionPromise;
    this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;

    // Built-in handlers. They are invoked via `.call(this, event)` in handleEvent,
    // so the unbound method references are safe.
    this.handlers = {
      [AssistantStreamEvents.ThreadCreated]: this.handleThreadCreated,
      [AssistantStreamEvents.ThreadRunCreated]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunQueued]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunInProgress]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunRequiresAction]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCompleted]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunFailed]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCancelling]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCancelled]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunExpired]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunStepCreated]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepInProgress]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepCompleted]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepFailed]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepCancelled]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepExpired]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepDelta]: this.handleRunStepDeltaEvent,
      [AssistantStreamEvents.ThreadMessageCreated]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageInProgress]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageCompleted]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageIncomplete]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageDelta]: this.handleMessageDeltaEvent,
      [AssistantStreamEvents.ErrorEvent]: this.handleErrorEvent,
    };
  }

  /**
   * Records a content part on `finalMessage.content[index]`. Unedited text is only
   * accumulated into `this.text`; every other part is sent to the client.
   * @param {Object} data - `{ type, index, edited?, [type]: contentPart }`.
   */
  async addContentData(data) {
    const { type, index, edited } = data;
    const contentPart = data[type];

    // The `{}` fallback for responseMessage has no content array yet.
    if (!this.finalMessage.content) {
      this.finalMessage.content = [];
    }
    this.finalMessage.content[index] = { type, [type]: contentPart };

    // Unedited text is not re-sent here; it is presumably already streamed through
    // the per-message progress callback registered in onRunStepCreated.
    if (type === ContentTypes.TEXT && !edited) {
      this.text += contentPart.value;
      return;
    }

    sendEvent(this.res, {
      index,
      type,
      [type]: contentPart,
      thread_id: this.thread_id,
      messageId: this.finalMessage.messageId,
      conversationId: this.finalMessage.conversationId,
    });
  }

  /**
   * @returns {string} Text accumulated so far from streamed message deltas.
   */
  getText() {
    return this.intermediateText;
  }

  /**
   * Builds a snapshot of the in-progress assistant message (flagged `unfinished`).
   * @returns {Object}
   */
  getIntermediateMessage() {
    return {
      conversationId: this.finalMessage.conversationId,
      messageId: this.finalMessage.messageId,
      parentMessageId: this.parentMessageId,
      model: this.req.body.assistant_id,
      endpoint: this.req.body.endpoint,
      isCreatedByUser: false,
      user: this.req.user.id,
      text: this.getText(),
      sender: 'Assistant',
      unfinished: true,
      error: false,
    };
  }

  /**
   * Starts a streamed run and handles each emitted event sequentially.
   * @param {Object} params
   * @param {string} params.thread_id - Thread to run on.
   * @param {Object} params.body - Run creation body.
   */
  async runAssistant({ thread_id, body }) {
    const streamRun = this.openai.beta.threads.runs.createAndStream(
      thread_id,
      body,
      this.streamOptions,
    );
    for await (const event of streamRun) {
      await this.handleEvent(event);
    }
  }

  /**
   * Dispatches one stream event: first to the optional client handler, then to the
   * built-in handler. Unknown event types are logged.
   * @param {Object} event - Stream event with an `event` name and `data` payload.
   */
  async handleEvent(event) {
    const handler = this.handlers[event.event];
    const clientHandler = this.clientHandlers[event.event];

    if (clientHandler) {
      await clientHandler.call(this, event);
    }

    if (handler) {
      await handler.call(this, event);
    } else {
      logger.warn(`Unhandled event type: ${event.event}`);
    }
  }

  /**
   * Logs thread creation; no other action is taken.
   * @param {Object} event
   */
  async handleThreadCreated(event) {
    logger.debug('Thread created:', event.data);
  }

  /**
   * Stores the latest run object. When the run requires action, submits tool outputs.
   * @param {Object} event
   */
  async handleRunEvent(event) {
    this.run = event.data;
    logger.debug('Run event:', this.run);
    if (event.event === AssistantStreamEvents.ThreadRunRequiresAction) {
      await this.onRunRequiresAction(event);
    } else if (event.event === AssistantStreamEvents.ThreadRunCompleted) {
      logger.debug('Run completed:', this.run);
    }
  }

  /**
   * Records the step and reacts to step creation / completion.
   * @param {Object} event
   */
  async handleRunStepEvent(event) {
    logger.debug('Run step event:', event.data);

    const step = event.data;
    this.steps.set(step.id, step);

    // Awaited so failures propagate to the caller instead of becoming unhandled rejections.
    if (event.event === AssistantStreamEvents.ThreadRunStepCreated) {
      await this.onRunStepCreated(event);
    } else if (event.event === AssistantStreamEvents.ThreadRunStepCompleted) {
      await this.onRunStepCompleted(event);
    }
  }

  /**
   * Retrieves a code-interpreter image output once per file id and forwards it as an
   * `IMAGE_FILE` content part. Skips outputs without a file id, and files whose
   * retrieval yields nothing or any falsy field.
   * @param {Object} output - Output item, expected to carry `image.file_id`.
   */
  async handleCodeImageOutput(output) {
    const file_id = output?.image?.file_id;
    if (!file_id) {
      logger.warn('Code interpreter image output is missing a file_id:', output);
      return;
    }

    if (this.processedFileIds.has(file_id)) {
      return;
    }

    const file = await retrieveAndProcessFile({
      openai: this.openai,
      client: this,
      file_id,
      basename: `${file_id}.png`,
    });

    const validImageFile = file != null && Object.keys(file).every((key) => file[key]);
    if (!validImageFile) {
      return;
    }

    const index = this.getStepIndex(file_id);
    this.addContentData({
      [ContentTypes.IMAGE_FILE]: file,
      type: ContentTypes.IMAGE_FILE,
      index,
    });
    this.processedFileIds.add(file_id);
  }

  /**
   * Creates an async handler that merges tool-call deltas into `toolCall[toolCall.type]`
   * and re-sends the whole tool call after every merged key.
   * @param {number} index - Content index of the tool call.
   * @param {Object} toolCall - Tool call object; mutated in place.
   * @returns {(delta: Object) => Promise<void>}
   */
  createToolCallStream(index, toolCall) {
    const state = toolCall;
    const type = state.type;
    const data = state[type];
    const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);

    const deltaHandler = async (delta) => {
      for (const key in delta) {
        if (!hasOwn(data, key)) {
          logger.warn(`Unhandled tool call key "${key}", delta: `, delta);
          continue;
        }

        if (Array.isArray(delta[key])) {
          if (!Array.isArray(data[key])) {
            data[key] = [];
          }

          for (const d of delta[key]) {
            // Array items must carry their own `index`. `null` is rejected explicitly
            // because hasOwnProperty.call(null, ...) would throw.
            if (d === null || (typeof d === 'object' && !hasOwn(d, 'index'))) {
              logger.warn("Expected an object with an 'index' for array updates but got:", d);
              continue;
            }

            const imageOutput = type === ToolCallTypes.CODE_INTERPRETER && d?.type === 'image';
            if (imageOutput) {
              await this.handleCodeImageOutput(d);
              continue;
            }

            const { index: itemIndex, ...updateData } = d;
            if (typeof data[key][itemIndex] !== 'object' || data[key][itemIndex] === null) {
              data[key][itemIndex] = {};
            }
            for (const updateKey in updateData) {
              data[key][itemIndex][updateKey] = updateData[updateKey];
            }
          }
        } else if (typeof delta[key] === 'string' && typeof data[key] === 'string') {
          // Intentionally a no-op.
          // NOTE(review): the OpenAI SDK's AssistantStream appears to accumulate string
          // deltas into its run-step snapshot, which may be this same `data` object when
          // the tool call was registered from a step-created event. Concatenating here
          // could then duplicate text — confirm before enabling `data[key] += delta[key]`.
        } else if (
          typeof delta[key] === 'object' &&
          delta[key] !== null &&
          !Array.isArray(delta[key])
        ) {
          data[key] = { ...data[key], ...delta[key] };
        } else {
          data[key] = delta[key];
        }

        state[type] = data;

        this.addContentData({
          [ContentTypes.TOOL_CALL]: toolCall,
          type: ContentTypes.TOOL_CALL,
          index,
        });

        await sleep(this.streamRate);
      }
    };

    return deltaHandler;
  }

  /**
   * Registers a newly seen tool call. It allocates a content index (also mapped to the
   * tool call's id), marks the initial progress, creates its delta stream, and sends it.
   * @param {string} stepId
   * @param {Object} toolCall
   */
  handleNewToolCall(stepId, toolCall) {
    const stepKey = this.generateToolCallKey(stepId, toolCall);
    const index = this.getStepIndex(stepKey);
    this.getStepIndex(toolCall.id, index);
    toolCall.progress = 0.01;
    this.orderedRunSteps.set(index, toolCall);
    const progressCallback = this.createToolCallStream(index, toolCall);
    this.progressCallbacks.set(stepKey, progressCallback);

    this.addContentData({
      [ContentTypes.TOOL_CALL]: toolCall,
      type: ContentTypes.TOOL_CALL,
      index,
    });
  }

  /**
   * Marks a non-function tool call as complete (`progress = 1`) and sends it.
   * Function tool calls are skipped here.
   * @param {string} stepId
   * @param {Object} toolCall
   */
  handleCompletedToolCall(stepId, toolCall) {
    if (toolCall.type === ToolCallTypes.FUNCTION) {
      return;
    }

    const stepKey = this.generateToolCallKey(stepId, toolCall);
    const index = this.getStepIndex(stepKey);
    toolCall.progress = 1;
    this.orderedRunSteps.set(index, toolCall);
    this.addContentData({
      [ContentTypes.TOOL_CALL]: toolCall,
      type: ContentTypes.TOOL_CALL,
      index,
    });
  }

  /**
   * Routes run-step deltas. Unseen tool calls are registered; known ones are passed to
   * their delta stream, which is awaited so deltas apply in order and errors propagate.
   * @param {Object} event
   */
  async handleRunStepDeltaEvent(event) {
    const { delta, id: stepId } = event.data;

    if (!delta.step_details) {
      logger.warn('Undefined or unhandled run step delta:', delta);
      return;
    }

    const { tool_calls } = delta.step_details;

    if (!tool_calls) {
      logger.warn('Unhandled run step details', delta.step_details);
      return;
    }

    for (const toolCall of tool_calls) {
      const stepKey = this.generateToolCallKey(stepId, toolCall);

      if (!this.mappedOrder.has(stepKey)) {
        this.handleNewToolCall(stepId, toolCall);
        continue;
      }

      const progressCallback = this.progressCallbacks.get(stepKey);
      if (!progressCallback) {
        logger.warn(`No progress callback registered for tool call "${stepKey}"`);
        continue;
      }

      await progressCallback(toolCall[toolCall.type]);
    }
  }

  /**
   * Accumulates streamed text from the first content part of a message delta. The text
   * is forwarded to the message's progress callback when one is registered.
   * @param {Object} event
   */
  async handleMessageDeltaEvent(event) {
    const message = event.data;
    const content = message.delta.content?.[0];

    if (!content || content.type !== MessageContentTypes.TEXT) {
      return;
    }

    const { value } = content.text;
    this.intermediateText += value;

    const onProgress = this.progressCallbacks.get(message.id);
    if (!onProgress) {
      logger.warn(`No progress callback registered for message "${message.id}"`);
      return;
    }

    onProgress(value);
    await sleep(this.streamRate);
  }

  /**
   * Logs stream error events; the error is not rethrown.
   * @param {Object} event
   */
  async handleErrorEvent(event) {
    logger.error('Error event:', event.data);
  }

  /**
   * Returns the content index for `stepKey`, allocating the next index on first use.
   * When `overrideIndex` is a number, maps `stepKey` to it and returns `undefined`.
   * A falsy `stepKey` is ignored and yields `undefined`.
   * @param {string} stepKey
   * @param {number} [overrideIndex]
   * @returns {number|undefined}
   */
  getStepIndex(stepKey, overrideIndex) {
    if (!stepKey) {
      return;
    }

    // Only real numbers override; the global isNaN would also accept null/'' by coercion.
    if (typeof overrideIndex === 'number' && !Number.isNaN(overrideIndex)) {
      this.mappedOrder.set(stepKey, overrideIndex);
      return;
    }

    let index = this.mappedOrder.get(stepKey);

    if (index === undefined) {
      index = this.index;
      this.mappedOrder.set(stepKey, this.index);
      this.index++;
    }

    return index;
  }

  /**
   * Builds the key identifying a tool call within a run step.
   * @param {string} stepId
   * @param {Object} toolCall - Needs `index` and `type`.
   * @returns {string}
   */
  generateToolCallKey(stepId, toolCall) {
    return `${stepId}_tool_call_${toolCall.index}_${toolCall.type}`;
  }

  /**
   * Normalizes tool outputs before submission:
   * - falsy entries are dropped (previously they leaked into the result as `undefined`);
   * - outputs without an `output` value get a placeholder message;
   * - every required action lacking an output gets a placeholder entry.
   * @param {Object[]} [tool_outputs]
   * @param {Object[]} [actions] - Items with `tool`, `toolCallId`, `run_id`, `thread_id`.
   * @returns {Object[]}
   */
  checkMissingOutputs(tool_outputs = [], actions = []) {
    const MISSING_OUTPUT_MESSAGE =
      'The tool failed to produce an output. The tool may not be currently available or experienced an unhandled error.';
    const outputIds = new Set();
    const validatedOutputs = [];

    for (const output of tool_outputs ?? []) {
      if (!output) {
        logger.warn('Tool output is undefined');
        continue;
      }
      outputIds.add(output.tool_call_id);
      if (!output.output) {
        logger.warn(`Tool output exists but has no output property (ID: ${output.tool_call_id})`);
        validatedOutputs.push({
          ...output,
          output: MISSING_OUTPUT_MESSAGE,
        });
        continue;
      }
      validatedOutputs.push(output);
    }

    const missingOutputs = [];
    for (const item of actions ?? []) {
      const { tool, toolCallId, run_id, thread_id } = item;

      if (!outputIds.has(toolCallId)) {
        logger.warn(
          `The "${tool}" tool (ID: ${toolCallId}) failed to produce an output. run_id: ${run_id} thread_id: ${thread_id}`,
        );
        missingOutputs.push({
          tool_call_id: toolCallId,
          output: MISSING_OUTPUT_MESSAGE,
        });
      }
    }

    return [...validatedOutputs, ...missingOutputs];
  }

  /**
   * Executes the run's required function calls and submits their outputs. The resulting
   * stream is then handled event by event.
   * @param {Object} event - Event whose `data` is the run with `required_action`.
   * @throws Rethrows errors raised while creating the submission stream.
   */
  async onRunRequiresAction(event) {
    const run = event.data;
    const { submit_tool_outputs } = run.required_action;
    const actions = submit_tool_outputs.tool_calls.map((item) => {
      const functionCall = item.function;
      const args = JSON.parse(functionCall.arguments);
      return {
        tool: functionCall.name,
        toolInput: args,
        toolCallId: item.id,
        run_id: run.id,
        thread_id: this.thread_id,
      };
    });

    const { tool_outputs: preliminaryOutputs } = await processRequiredActions(this, actions);
    const tool_outputs = this.checkMissingOutputs(preliminaryOutputs, actions);

    let toolRun;
    try {
      toolRun = this.openai.beta.threads.runs.submitToolOutputsStream(
        run.id,
        {
          thread_id: run.thread_id,
          tool_outputs,
          stream: true,
        },
        this.streamOptions,
      );
    } catch (error) {
      logger.error('Error submitting tool outputs:', error);
      throw error;
    }

    for await (const toolRunEvent of toolRun) {
      await this.handleEvent(toolRunEvent);
    }
  }

  /**
   * Handles a newly created run step.
   * - Message-creation steps get a content index and a text progress callback keyed by message id.
   * - Tool-call steps register each of their tool calls.
   * - Other step types are logged.
   * @param {Object} event
   */
  async onRunStepCreated(event) {
    const step = event.data;

    if (step.type === StepTypes.MESSAGE_CREATION) {
      const { message_creation } = step.step_details;
      const stepKey = message_creation.message_id;
      const index = this.getStepIndex(stepKey);

      const { onProgress: progressCallback } = createOnProgress();
      const onProgress = progressCallback({
        index,
        res: this.res,
        messageId: this.finalMessage.messageId,
        conversationId: this.finalMessage.conversationId,
        thread_id: this.thread_id,
        type: ContentTypes.TEXT,
      });

      this.progressCallbacks.set(stepKey, onProgress);
      this.orderedRunSteps.set(index, step);
      return;
    }

    if (step.type !== StepTypes.TOOL_CALLS) {
      logger.warn('Unhandled step creation type:', step.type);
      return;
    }

    const { tool_calls } = step.step_details;
    for (const toolCall of tool_calls) {
      this.handleNewToolCall(step.id, toolCall);
    }
  }

  /**
   * Handles a completed run step. Message completion is left to the message events;
   * each tool call of a tool-call step is marked complete with its positional index.
   * @param {Object} event
   */
  async onRunStepCompleted(event) {
    const step = event.data;

    if (step.type === StepTypes.MESSAGE_CREATION) {
      logger.debug('RunStep Message completion: to be handled by Message Event.', step);
      return;
    }

    // Mirrors onRunStepCreated: only tool-call steps carry `step_details.tool_calls`.
    if (step.type !== StepTypes.TOOL_CALLS) {
      logger.warn('Unhandled step completion type:', step.type);
      return;
    }

    const { tool_calls } = step.step_details;
    for (let i = 0; i < tool_calls.length; i++) {
      const toolCall = tool_calls[i];
      toolCall.index = i;
      this.handleCompletedToolCall(step.id, toolCall);
    }
  }

  /**
   * Handles message events; only completion triggers processing.
   * @param {Object} event
   */
  async handleMessageEvent(event) {
    if (event.event === AssistantStreamEvents.ThreadMessageCompleted) {
      await this.messageCompleted(event);
    }
  }

  /**
   * Processes a completed message, records its text content part, and stores the message.
   * Allocates a fresh index if the message was never registered, instead of writing to
   * `content[undefined]`.
   * @param {Object} event
   */
  async messageCompleted(event) {
    const message = event.data;
    const result = await processMessages({
      openai: this.openai,
      client: this,
      messages: [message],
    });
    const index = this.getStepIndex(message.id);
    this.addContentData({
      [ContentTypes.TEXT]: { value: result.text },
      type: ContentTypes.TEXT,
      edited: result.edited,
      index,
    });
    this.messages.push(message);
  }
}
| |
|
| | module.exports = StreamRunManager; |
| |
|