File size: 21,973 Bytes
f0743f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 | const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Constants,
StepTypes,
ContentTypes,
ToolCallTypes,
MessageContentTypes,
AssistantStreamEvents,
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { processMessages } = require('~/server/services/Threads');
const { createOnProgress } = require('~/server/utils');
/**
* Implements the StreamRunManager functionality for managing the streaming
* and processing of run steps, messages, and tool calls within a thread.
* @implements {StreamRunManager}
*/
/**
 * Implements the StreamRunManager functionality for managing the streaming
 * and processing of run steps, messages, and tool calls within a thread.
 *
 * Stream events are dispatched by `handleEvent` through the handler table built
 * in the constructor. Content parts are stored on `finalMessage.content` at the
 * index assigned by `getStepIndex` and (except unedited text) sent to the client via SSE.
 * @implements {StreamRunManager}
 */
class StreamRunManager {
  /**
   * @param {Object} fields
   * @param {ServerRequest} fields.req
   * @param {Express.Response} fields.res
   * @param {OpenAI} fields.openai
   * @param {string} fields.parentMessageId
   * @param {string} fields.thread_id
   * @param {RunCreateAndStreamParams} [fields.runBody]
   * @param {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>} [fields.handlers]
   *   Optional handlers invoked before the internal handler for the same event.
   * @param {OpenAIRequestOptions} [fields.streamOptions]
   * @param {Partial<TMessage>} [fields.responseMessage]
   * @param {Set<string>} [fields.attachedFileIds]
   * @param {Promise<ChatCompletion>} [fields.visionPromise]
   * @param {number} [fields.streamRate] - Delay passed to `sleep` between streamed chunks.
   */
  constructor(fields) {
    /** Next content index handed out by `getStepIndex`. */
    this.index = 0;
    /** @type {Map<string, RunStep>} */
    this.steps = new Map();
    /** Step key (message id, tool call key, tool call id, or file id) -> content index.
     * @type {Map<string, number>} */
    this.mappedOrder = new Map();
    /** Content index -> the step / tool call stored at that index.
     * @type {Map<number, StepToolCall | RunStep>} */
    this.orderedRunSteps = new Map();
    /** @type {Set<string>} */
    this.processedFileIds = new Set();
    /** @type {Map<string, (delta: ToolCallDelta | string) => Promise<void>>} */
    this.progressCallbacks = new Map();
    /** @type {Run | null} */
    this.run = null;
    /** @type {ServerRequest} */
    this.req = fields.req;
    /** @type {Express.Response} */
    this.res = fields.res;
    /** @type {OpenAI} */
    this.openai = fields.openai;
    /** @type {string} */
    this.apiKey = this.openai.apiKey;
    /** @type {string} */
    this.parentMessageId = fields.parentMessageId;
    /** @type {string} */
    this.thread_id = fields.thread_id;
    /** @type {RunCreateAndStreamParams} */
    this.initialRunBody = fields.runBody;
    /**
     * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
     */
    this.clientHandlers = fields.handlers ?? {};
    /** @type {OpenAIRequestOptions} */
    this.streamOptions = fields.streamOptions ?? {};
    /** @type {Partial<TMessage>} */
    this.finalMessage = fields.responseMessage ?? {};
    /** @type {ThreadMessage[]} */
    this.messages = [];
    /** @type {string} */
    this.text = '';
    /** @type {string} */
    this.intermediateText = '';
    /** @type {Set<string>} */
    this.attachedFileIds = fields.attachedFileIds;
    /** @type {undefined | Promise<ChatCompletion>} */
    this.visionPromise = fields.visionPromise;
    /** @type {number} */
    this.streamRate = fields.streamRate ?? Constants.DEFAULT_STREAM_RATE;
    /**
     * Internal handlers; invoked with `this` bound to the manager by `handleEvent`.
     * @type {Object.<AssistantStreamEvents, (event: AssistantStreamEvent) => Promise<void>>}
     */
    this.handlers = {
      [AssistantStreamEvents.ThreadCreated]: this.handleThreadCreated,
      [AssistantStreamEvents.ThreadRunCreated]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunQueued]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunInProgress]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunRequiresAction]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCompleted]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunFailed]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCancelling]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunCancelled]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunExpired]: this.handleRunEvent,
      [AssistantStreamEvents.ThreadRunStepCreated]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepInProgress]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepCompleted]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepFailed]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepCancelled]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepExpired]: this.handleRunStepEvent,
      [AssistantStreamEvents.ThreadRunStepDelta]: this.handleRunStepDeltaEvent,
      [AssistantStreamEvents.ThreadMessageCreated]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageInProgress]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageCompleted]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageIncomplete]: this.handleMessageEvent,
      [AssistantStreamEvents.ThreadMessageDelta]: this.handleMessageDeltaEvent,
      [AssistantStreamEvents.ErrorEvent]: this.handleErrorEvent,
    };
  }

  /**
   * Stores a content part on `finalMessage.content[index]` and sends it to the client via SSE.
   *
   * Unedited text parts are only accumulated into `this.text` and not sent here;
   * presumably they were already streamed through the per-message progress callback.
   *
   * @param {StreamContentData} data
   * @returns {Promise<void>}
   */
  async addContentData(data) {
    const { type, index, edited } = data;
    /** @type {ContentPart} */
    const contentPart = data[type];
    // `responseMessage` is optional in the constructor, so `content` may not exist yet.
    if (this.finalMessage.content == null) {
      this.finalMessage.content = [];
    }
    this.finalMessage.content[index] = { type, [type]: contentPart };
    if (type === ContentTypes.TEXT && !edited) {
      this.text += contentPart.value;
      return;
    }
    const contentData = {
      index,
      type,
      [type]: contentPart,
      thread_id: this.thread_id,
      messageId: this.finalMessage.messageId,
      conversationId: this.finalMessage.conversationId,
    };
    sendEvent(this.res, contentData);
  }

  /* <------------------ Misc. Helpers ------------------> */

  /** Returns the latest intermediate text
   * @returns {string}
   */
  getText() {
    return this.intermediateText;
  }

  /** Returns the current, intermediate (unfinished) message
   * @returns {TMessage}
   */
  getIntermediateMessage() {
    return {
      conversationId: this.finalMessage.conversationId,
      messageId: this.finalMessage.messageId,
      parentMessageId: this.parentMessageId,
      model: this.req.body.assistant_id,
      endpoint: this.req.body.endpoint,
      isCreatedByUser: false,
      user: this.req.user.id,
      text: this.getText(),
      sender: 'Assistant',
      unfinished: true,
      error: false,
    };
  }

  /* <------------------ Main Event Handlers ------------------> */

  /**
   * Run the assistant and handle the events sequentially.
   * @param {Object} params - The parameters for running the assistant.
   * @param {string} params.thread_id - The thread id.
   * @param {RunCreateAndStreamParams} params.body - The body of the run.
   * @returns {Promise<void>}
   */
  async runAssistant({ thread_id, body }) {
    const streamRun = this.openai.beta.threads.runs.createAndStream(
      thread_id,
      body,
      this.streamOptions,
    );
    for await (const event of streamRun) {
      await this.handleEvent(event);
    }
  }

  /**
   * Handle the event: the client handler (if any) runs first, then the internal handler.
   * @param {AssistantStreamEvent} event - The stream event object.
   * @returns {Promise<void>}
   */
  async handleEvent(event) {
    const handler = this.handlers[event.event];
    const clientHandler = this.clientHandlers[event.event];
    if (clientHandler) {
      await clientHandler.call(this, event);
    }
    if (handler) {
      await handler.call(this, event);
    } else {
      logger.warn(`Unhandled event type: ${event.event}`);
    }
  }

  /**
   * Handle thread.created event
   * @param {ThreadCreated} event - The thread.created event object.
   */
  async handleThreadCreated(event) {
    logger.debug('Thread created:', event.data);
  }

  /**
   * Handle Run Events: records the latest run and dispatches `requires_action`.
   * @param {ThreadRunCreated | ThreadRunQueued | ThreadRunInProgress | ThreadRunRequiresAction | ThreadRunCompleted | ThreadRunFailed | ThreadRunCancelling | ThreadRunCancelled | ThreadRunExpired} event -
   * The run event object.
   */
  async handleRunEvent(event) {
    this.run = event.data;
    logger.debug('Run event:', this.run);
    if (event.event === AssistantStreamEvents.ThreadRunRequiresAction) {
      await this.onRunRequiresAction(event);
    } else if (event.event === AssistantStreamEvents.ThreadRunCompleted) {
      logger.debug('Run completed:', this.run);
    }
  }

  /**
   * Handle Run Step Events: records the step and dispatches creation/completion.
   * @param {ThreadRunStepCreated | ThreadRunStepInProgress | ThreadRunStepCompleted | ThreadRunStepFailed | ThreadRunStepCancelled | ThreadRunStepExpired} event -
   * The run step event object.
   */
  async handleRunStepEvent(event) {
    logger.debug('Run step event:', event.data);
    const step = event.data;
    this.steps.set(step.id, step);
    // Awaited so errors reject this handler instead of becoming unhandled rejections.
    if (event.event === AssistantStreamEvents.ThreadRunStepCreated) {
      await this.onRunStepCreated(event);
    } else if (event.event === AssistantStreamEvents.ThreadRunStepCompleted) {
      await this.onRunStepCompleted(event);
    }
  }

  /* <------------------ Delta Events ------------------> */

  /**
   * Retrieves a code-interpreter image output once and adds it as an image content part.
   * @param {CodeImageOutput} output
   */
  async handleCodeImageOutput(output) {
    const file_id = output?.image?.file_id;
    if (!file_id) {
      logger.warn('Code interpreter image output is missing a file_id:', output);
      return;
    }
    if (this.processedFileIds.has(file_id)) {
      return;
    }
    const file = await retrieveAndProcessFile({
      openai: this.openai,
      client: this,
      file_id,
      basename: `${file_id}.png`,
    });
    // Only add the image once every field of the processed file has a (truthy) value.
    const isValidImageFile = Boolean(file) && Object.keys(file).every((key) => file[key]);
    if (!isValidImageFile) {
      return;
    }
    const index = this.getStepIndex(file_id);
    await this.addContentData({
      [ContentTypes.IMAGE_FILE]: file,
      type: ContentTypes.IMAGE_FILE,
      index,
    });
    this.processedFileIds.add(file_id);
  }

  /**
   * Create Tool Call Stream: returns a handler that merges tool call deltas into
   * `toolCall[toolCall.type]` (mutating it) and re-sends the tool call after each key.
   * @param {number} index - The content index of the tool call.
   * @param {StepToolCall} toolCall - The current tool call object.
   * @returns {(delta: ToolCallDelta) => Promise<void>}
   */
  createToolCallStream(index, toolCall) {
    /** @type {StepToolCall} */
    const state = toolCall;
    const type = state.type;
    const data = state[type];
    /** @param {ToolCallDelta} delta */
    const deltaHandler = async (delta) => {
      for (const key in delta) {
        if (!Object.prototype.hasOwnProperty.call(data, key)) {
          logger.warn(`Unhandled tool call key "${key}", delta: `, delta);
          continue;
        }
        const value = delta[key];
        if (Array.isArray(value)) {
          if (!Array.isArray(data[key])) {
            data[key] = [];
          }
          for (const d of value) {
            // `typeof null === 'object'`, so null must be rejected before the own-property check.
            if (
              typeof d === 'object' &&
              (d === null || !Object.prototype.hasOwnProperty.call(d, 'index'))
            ) {
              logger.warn("Expected an object with an 'index' for array updates but got:", d);
              continue;
            }
            const imageOutput = type === ToolCallTypes.CODE_INTERPRETER && d?.type === 'image';
            if (imageOutput) {
              await this.handleCodeImageOutput(d);
              continue;
            }
            const { index: outputIndex, ...updateData } = d;
            // Ensure the data at outputIndex is an object before merging into it
            if (typeof data[key][outputIndex] !== 'object' || data[key][outputIndex] === null) {
              data[key][outputIndex] = {};
            }
            Object.assign(data[key][outputIndex], updateData);
          }
        } else if (typeof value === 'string' && typeof data[key] === 'string') {
          // NOTE(review): string concatenation is intentionally disabled, so string deltas
          // (e.g. code input / function arguments) are not accumulated here; the complete
          // tool call is sent again by `handleCompletedToolCall`.
        } else if (typeof value === 'object' && value !== null) {
          // Merge objects (arrays were handled above)
          data[key] = { ...data[key], ...value };
        } else {
          // Directly set the value for other types
          data[key] = value;
        }
        state[type] = data;
        await this.addContentData({
          [ContentTypes.TOOL_CALL]: toolCall,
          type: ContentTypes.TOOL_CALL,
          index,
        });
        await sleep(this.streamRate);
      }
    };
    return deltaHandler;
  }

  /**
   * Registers a newly seen tool call: assigns its content index, aliases its id to
   * that index, creates its delta stream, and sends its initial state.
   * @param {string} stepId - The id of the step the tool_call is part of.
   * @param {StepToolCall} toolCall - The tool call object.
   */
  handleNewToolCall(stepId, toolCall) {
    const stepKey = this.generateToolCallKey(stepId, toolCall);
    const index = this.getStepIndex(stepKey);
    // Alias the tool call id to the same index (ids are unavailable in delta events).
    this.getStepIndex(toolCall.id, index);
    toolCall.progress = 0.01;
    this.orderedRunSteps.set(index, toolCall);
    const progressCallback = this.createToolCallStream(index, toolCall);
    this.progressCallbacks.set(stepKey, progressCallback);
    this.addContentData({
      [ContentTypes.TOOL_CALL]: toolCall,
      type: ContentTypes.TOOL_CALL,
      index,
    });
  }

  /**
   * Handle Completed Tool Call: marks non-function tool calls complete and re-sends them.
   * @param {string} stepId - The id of the step the tool_call is part of.
   * @param {StepToolCall} toolCall - The tool call object.
   */
  handleCompletedToolCall(stepId, toolCall) {
    if (toolCall.type === ToolCallTypes.FUNCTION) {
      return;
    }
    const stepKey = this.generateToolCallKey(stepId, toolCall);
    const index = this.getStepIndex(stepKey);
    toolCall.progress = 1;
    this.orderedRunSteps.set(index, toolCall);
    this.addContentData({
      [ContentTypes.TOOL_CALL]: toolCall,
      type: ContentTypes.TOOL_CALL,
      index,
    });
  }

  /**
   * Handle Run Step Delta Event: registers new tool calls or forwards deltas to their stream.
   * @param {ThreadRunStepDelta} event - The run step delta event object.
   */
  async handleRunStepDeltaEvent(event) {
    const { delta, id: stepId } = event.data;
    if (!delta.step_details) {
      logger.warn('Undefined or unhandled run step delta:', delta);
      return;
    }
    /** @type {{ tool_calls: Array<ToolCallDeltaObject> }} */
    const { tool_calls } = delta.step_details;
    if (!tool_calls) {
      logger.warn('Unhandled run step details', delta.step_details);
      return;
    }
    for (const toolCall of tool_calls) {
      const stepKey = this.generateToolCallKey(stepId, toolCall);
      if (!this.mappedOrder.has(stepKey)) {
        this.handleNewToolCall(stepId, toolCall);
        continue;
      }
      const progressCallback = this.progressCallbacks.get(stepKey);
      if (!progressCallback) {
        // The key can be indexed (e.g. by `handleCompletedToolCall`) without a stream existing.
        logger.warn(`No progress callback registered for tool call "${stepKey}"`);
        continue;
      }
      // Awaited so deltas are applied in order and failures reject this handler
      // instead of surfacing as unhandled promise rejections.
      await progressCallback(toolCall[toolCall.type]);
    }
  }

  /**
   * Handle Message Delta Event: appends text deltas and forwards them to the message's
   * progress callback.
   * @param {ThreadMessageDelta} event - The Message Delta event object.
   */
  async handleMessageDeltaEvent(event) {
    const message = event.data;
    const onProgress = this.progressCallbacks.get(message.id);
    const content = message.delta.content?.[0];
    if (content?.type !== MessageContentTypes.TEXT) {
      return;
    }
    // A text delta may omit `value`; appending it would insert the string "undefined".
    const value = content.text?.value;
    if (value == null) {
      return;
    }
    this.intermediateText += value;
    if (onProgress) {
      onProgress(value);
    } else {
      logger.warn(`No progress callback registered for message "${message.id}"`);
    }
    await sleep(this.streamRate);
  }

  /**
   * Handle Error Event
   * @param {ErrorEvent} event - The Error event object.
   */
  async handleErrorEvent(event) {
    logger.error('Error event:', event.data);
  }

  /* <------------------ Misc. Helpers ------------------> */

  /**
   * Gets the step index for a given step key, creating a new index if it doesn't exist.
   * @param {string} stepKey -
   * The access key for the step. Either a message.id, tool_call key, or file_id.
   * @param {number | undefined} [overrideIndex] - An override index to use an alternative stepKey.
   * This is necessary due to the toolCall Id being unavailable in delta stream events.
   * @returns {number | undefined} index - The index of the step; `undefined` if invalid key or using overrideIndex.
   */
  getStepIndex(stepKey, overrideIndex) {
    if (!stepKey) {
      return;
    }
    // Number.isFinite rejects undefined/null/NaN (the global isNaN treats null as 0).
    if (Number.isFinite(overrideIndex)) {
      this.mappedOrder.set(stepKey, overrideIndex);
      return;
    }
    let index = this.mappedOrder.get(stepKey);
    if (index === undefined) {
      index = this.index;
      this.mappedOrder.set(stepKey, index);
      this.index++;
    }
    return index;
  }

  /**
   * Generate Tool Call Key
   * @param {string} stepId - The id of the step the tool_call is part of.
   * @param {StepToolCall} toolCall - The tool call object.
   * @returns {string} key - The generated key for the tool call.
   */
  generateToolCallKey(stepId, toolCall) {
    return `${stepId}_tool_call_${toolCall.index}_${toolCall.type}`;
  }

  /**
   * Check Missing Outputs: drops undefined outputs, substitutes a fallback message for
   * outputs without an `output` value, and appends a fallback output for every action
   * whose tool call id has no output.
   * @param {ToolOutput[]} tool_outputs - The tool outputs.
   * @param {RequiredAction[]} actions - The required actions.
   * @returns {ToolOutput[]} completeOutputs - The complete outputs (never contains undefined).
   */
  checkMissingOutputs(tool_outputs = [], actions = []) {
    const MISSING_OUTPUT_MESSAGE =
      'The tool failed to produce an output. The tool may not be currently available or experienced an unhandled error.';
    const outputIds = new Set();
    const validatedOutputs = [];
    for (const output of tool_outputs ?? []) {
      if (!output) {
        // Skipped entirely: submitting `undefined` as a tool output would be invalid.
        logger.warn('Tool output is undefined');
        continue;
      }
      outputIds.add(output.tool_call_id);
      if (!output.output) {
        logger.warn(`Tool output exists but has no output property (ID: ${output.tool_call_id})`);
        validatedOutputs.push({
          ...output,
          output: MISSING_OUTPUT_MESSAGE,
        });
        continue;
      }
      validatedOutputs.push(output);
    }
    const missingOutputs = [];
    for (const item of actions ?? []) {
      const { tool, toolCallId, run_id, thread_id } = item;
      if (outputIds.has(toolCallId)) {
        continue;
      }
      logger.warn(
        `The "${tool}" tool (ID: ${toolCallId}) failed to produce an output. run_id: ${run_id} thread_id: ${thread_id}`,
      );
      missingOutputs.push({
        tool_call_id: toolCallId,
        output: MISSING_OUTPUT_MESSAGE,
      });
    }
    return [...validatedOutputs, ...missingOutputs];
  }

  /* <------------------ Run Event handlers ------------------> */

  /**
   * Handle Run Events Requiring Action: executes the required function calls, submits
   * their outputs, and processes the resulting stream.
   * @param {ThreadRunRequiresAction} event - The run event object requiring action.
   */
  async onRunRequiresAction(event) {
    const run = event.data;
    const { submit_tool_outputs } = run.required_action;
    const actions = submit_tool_outputs.tool_calls.map((item) => {
      const functionCall = item.function;
      // NOTE(review): malformed `arguments` JSON throws here and aborts the run handling.
      const args = JSON.parse(functionCall.arguments);
      return {
        tool: functionCall.name,
        toolInput: args,
        toolCallId: item.id,
        run_id: run.id,
        thread_id: this.thread_id,
      };
    });
    const { tool_outputs: preliminaryOutputs } = await processRequiredActions(this, actions);
    const tool_outputs = this.checkMissingOutputs(preliminaryOutputs, actions);
    /** @type {AssistantStream | undefined} */
    let toolRun;
    // Only guards stream creation; errors raised while iterating propagate to the caller.
    try {
      toolRun = this.openai.beta.threads.runs.submitToolOutputsStream(
        run.id,
        {
          thread_id: run.thread_id,
          tool_outputs,
          stream: true,
        },
        this.streamOptions,
      );
    } catch (error) {
      logger.error('Error submitting tool outputs:', error);
      throw error;
    }
    for await (const toolEvent of toolRun) {
      await this.handleEvent(toolEvent);
    }
  }

  /* <------------------ RunStep Event handlers ------------------> */

  /**
   * Handle Run Step Created Events: registers a text progress callback for message
   * creation steps, or registers each tool call for tool call steps.
   * @param {ThreadRunStepCreated} event - The created run step event object.
   */
  async onRunStepCreated(event) {
    const step = event.data;
    if (step.type === StepTypes.MESSAGE_CREATION) {
      /** @type {MessageCreationStepDetails} */
      const { message_creation } = step.step_details;
      const stepKey = message_creation.message_id;
      const index = this.getStepIndex(stepKey);
      const { onProgress: progressCallback } = createOnProgress();
      const onProgress = progressCallback({
        index,
        res: this.res,
        messageId: this.finalMessage.messageId,
        conversationId: this.finalMessage.conversationId,
        thread_id: this.thread_id,
        type: ContentTypes.TEXT,
      });
      this.progressCallbacks.set(stepKey, onProgress);
      this.orderedRunSteps.set(index, step);
      return;
    }
    if (step.type !== StepTypes.TOOL_CALLS) {
      logger.warn('Unhandled step creation type:', step.type);
      return;
    }
    /** @type {{ tool_calls: StepToolCall[] }} */
    const { tool_calls } = step.step_details;
    for (const toolCall of tool_calls) {
      this.handleNewToolCall(step.id, toolCall);
    }
  }

  /**
   * Handle Run Step Completed Events: marks each tool call of a tool call step complete.
   * Message creation steps are finalized by the message completed event instead.
   * @param {ThreadRunStepCompleted} event - The completed run step event object.
   */
  async onRunStepCompleted(event) {
    const step = event.data;
    if (step.type === StepTypes.MESSAGE_CREATION) {
      logger.debug('RunStep Message completion: to be handled by Message Event.', step);
      return;
    }
    if (step.type !== StepTypes.TOOL_CALLS) {
      logger.warn('Unhandled step completion type:', step.type);
      return;
    }
    /** @type {{ tool_calls: StepToolCall[] }} */
    const { tool_calls } = step.step_details;
    tool_calls.forEach((toolCall, i) => {
      // Completed tool calls are keyed by position, matching the delta `index`.
      toolCall.index = i;
      this.handleCompletedToolCall(step.id, toolCall);
    });
  }

  /* <------------------ Message Event handlers ------------------> */

  /**
   * Handle Message Event (only `thread.message.completed` is acted upon).
   * @param {ThreadMessageCreated | ThreadMessageInProgress | ThreadMessageCompleted | ThreadMessageIncomplete} event -
   * The Message event object.
   */
  async handleMessageEvent(event) {
    if (event.event === AssistantStreamEvents.ThreadMessageCompleted) {
      await this.messageCompleted(event);
    }
  }

  /**
   * Handle Message Completed Events: processes the message text and stores it as content.
   * @param {ThreadMessageCompleted} event - The Completed Message event object.
   */
  async messageCompleted(event) {
    const message = event.data;
    const result = await processMessages({
      openai: this.openai,
      client: this,
      messages: [message],
    });
    // Allocates an index if the message was never registered, instead of writing to content[undefined].
    const index = this.getStepIndex(message.id);
    await this.addContentData({
      [ContentTypes.TEXT]: { value: result.text },
      type: ContentTypes.TEXT,
      edited: result.edited,
      index,
    });
    this.messages.push(message);
  }
}
// Sole (CommonJS) export of this module.
module.exports = StreamRunManager;
|