| | import { useCallback, useRef } from 'react'; |
| | import { |
| | Constants, |
| | StepTypes, |
| | ContentTypes, |
| | ToolCallTypes, |
| | getNonEmptyValue, |
| | } from 'librechat-data-provider'; |
| | import type { |
| | Agents, |
| | TMessage, |
| | PartMetadata, |
| | EventSubmission, |
| | TMessageContentParts, |
| | } from 'librechat-data-provider'; |
| | import type { SetterOrUpdater } from 'recoil'; |
| | import type { AnnounceOptions } from '~/common'; |
| | import { MESSAGE_UPDATE_INTERVAL } from '~/common'; |
| |
|
/**
 * Collaborators injected into `useStepHandler`.
 */
interface TUseStepHandler {
  /** Announcer invoked with a `'composing'` status message (throttled via `lastAnnouncementTimeRef`). */
  announcePolite: (options: AnnounceOptions) => void;
  /** Replaces the whole message list with the given array. */
  setMessages: (messages: TMessage[]) => void;
  /** Reads the current message list; may yield `undefined` when none exists yet. */
  getMessages: () => TMessage[] | undefined;
  /** Submission flag setter; the step handler sets it to `true` on every event. */
  setIsSubmitting: SetterOrUpdater<boolean>;
  /** Timestamp (ms since epoch) of the most recent "composing" announcement. */
  lastAnnouncementTimeRef: React.MutableRefObject<number>;
}
| |
|
/**
 * One streamed agent event: its name plus a payload whose concrete shape the
 * handler selects (by casting) according to `event`.
 */
interface TStepEvent {
  event: string;
  data:
    | Agents.MessageDeltaEvent
    | Agents.AgentUpdate
    | Agents.RunStep
    | Agents.ToolEndEvent
    | {
        runId?: string;
        message: string;
      };
}
| |
|
/** Accumulated shape of a text content part while text deltas stream in. */
interface MessageDeltaUpdate {
  type: ContentTypes.TEXT;
  text: string;
  tool_call_ids?: string[];
}

/** Accumulated shape of a "think" content part while reasoning deltas stream in. */
interface ReasoningDeltaUpdate {
  type: ContentTypes.THINK;
  think: string;
}

/** Content-part types a freshly created placeholder part may be cast to. */
type AllContentTypes =
  | ContentTypes.TEXT
  | ContentTypes.THINK
  | ContentTypes.TOOL_CALL
  | ContentTypes.IMAGE_FILE
  | ContentTypes.IMAGE_URL
  | ContentTypes.ERROR;
| |
|
/**
 * Hook that folds streamed agent events into the response message being built.
 *
 * `stepHandler` is called once per event (`on_run_step`, `on_agent_update`,
 * `on_message_delta`, `on_reasoning_delta`, `on_run_step_delta`,
 * `on_run_step_completed`). It resolves which response message the event belongs
 * to, merges the event payload into that message's `content` parts, and pushes the
 * updated list through `setMessages`. Every call also sets `isSubmitting` to `true`
 * and announces "composing" at most once per `MESSAGE_UPDATE_INTERVAL` ms.
 *
 * Run steps, tool-call ids and in-progress responses are cached in refs so that
 * later delta events can be routed to the right message and content index;
 * `clearStepMaps` empties those caches.
 *
 * @returns `stepHandler` (the per-event reducer) and `clearStepMaps` (cache reset).
 */
export default function useStepHandler({
  setMessages,
  getMessages,
  setIsSubmitting,
  announcePolite,
  lastAnnouncementTimeRef,
}: TUseStepHandler) {
  /** Run step id -> id of a tool call announced in that step (the last one seen wins). */
  const toolCallIdMap = useRef(new Map<string, string | undefined>());
  /** Response message id -> most recently built version of that response. */
  const messageMap = useRef(new Map<string, TMessage>());
  /** Run step id -> run step; delta events reference their step by id. */
  const stepMap = useRef(new Map<string, Agents.RunStep>());

  /**
   * Computes the content-part index a streamed text/think/other part is written to.
   *
   * Parts normally land at `baseIndex + initialContent.length`, i.e. after any
   * content the response already carried before this run. When such initial
   * content exists and the incoming part is text or "think", the part is merged
   * into the slot just before that position if the part already there has the
   * same type.
   */
  const calculateContentIndex = (
    baseIndex: number,
    initialContent: TMessageContentParts[],
    incomingContentType: string,
    existingContent?: TMessageContentParts[],
  ): number => {
    if (
      initialContent.length > 0 &&
      (incomingContentType === ContentTypes.TEXT || incomingContentType === ContentTypes.THINK)
    ) {
      const targetIndex = baseIndex + initialContent.length - 1;
      const existingType = existingContent?.[targetIndex]?.type;
      if (existingType === incomingContentType) {
        return targetIndex;
      }
    }
    return baseIndex + initialContent.length;
  };

  /**
   * Returns a copy of `message` with `contentPart` merged into `content[index]`.
   *
   * - text / think deltas are appended to the string already stored there;
   * - agent updates replace the part;
   * - image URLs are copied onto the part;
   * - tool calls keep the incoming id/name, presumably falling back to the existing
   *   ones via `getNonEmptyValue`. String args are concatenated while streaming;
   *   object args, or any args on the final update, replace them (a final update
   *   with `null`/`undefined` args keeps the existing args). When `finalUpdate` is
   *   set, the call is marked `progress = 1` and its `output` is recorded.
   *
   * The original `message` is returned unchanged (with a console warning) when the
   * part has no type, or when its type does not start with the type of the part
   * already at `index`.
   */
  const updateContent = (
    message: TMessage,
    index: number,
    contentPart: Agents.MessageContentComplex,
    finalUpdate = false,
  ) => {
    const contentType = contentPart.type ?? '';
    if (!contentType) {
      console.warn('No content type found in content part');
      return message;
    }

    const updatedContent = [...(message.content || [])] as Array<
      Partial<TMessageContentParts> | undefined
    >;
    if (!updatedContent[index]) {
      // Create a typed placeholder so the merges below have something to extend.
      updatedContent[index] = { type: contentPart.type as AllContentTypes };
    }

    const existingType = (updatedContent[index]?.type as string | undefined) ?? '';
    if (existingType && !contentType.startsWith(existingType)) {
      console.warn('Content type mismatch');
      return message;
    }

    if (
      contentType.startsWith(ContentTypes.TEXT) &&
      ContentTypes.TEXT in contentPart &&
      typeof contentPart.text === 'string'
    ) {
      const currentContent = updatedContent[index] as MessageDeltaUpdate;
      const update: MessageDeltaUpdate = {
        type: ContentTypes.TEXT,
        text: (currentContent.text || '') + contentPart.text,
      };

      if (contentPart.tool_call_ids != null) {
        update.tool_call_ids = contentPart.tool_call_ids;
      }
      updatedContent[index] = update;
    } else if (
      contentType.startsWith(ContentTypes.AGENT_UPDATE) &&
      ContentTypes.AGENT_UPDATE in contentPart &&
      contentPart.agent_update
    ) {
      const update: Agents.AgentUpdate = {
        type: ContentTypes.AGENT_UPDATE,
        agent_update: contentPart.agent_update,
      };

      updatedContent[index] = update;
    } else if (
      contentType.startsWith(ContentTypes.THINK) &&
      ContentTypes.THINK in contentPart &&
      typeof contentPart.think === 'string'
    ) {
      const currentContent = updatedContent[index] as ReasoningDeltaUpdate;
      const update: ReasoningDeltaUpdate = {
        type: ContentTypes.THINK,
        think: (currentContent.think || '') + contentPart.think,
      };

      updatedContent[index] = update;
    } else if (contentType === ContentTypes.IMAGE_URL && 'image_url' in contentPart) {
      const currentContent = updatedContent[index] as {
        type: ContentTypes.IMAGE_URL;
        image_url: string;
      };
      // Carry the incoming URL onto the part; previously only the placeholder
      // (which holds nothing but `type`) was kept, so the URL was lost.
      updatedContent[index] = {
        ...currentContent,
        image_url: contentPart.image_url,
      } as Partial<TMessageContentParts>;
    } else if (contentType === ContentTypes.TOOL_CALL && 'tool_call' in contentPart) {
      const existingContent = updatedContent[index] as Agents.ToolCallContent | undefined;
      const existingToolCall = existingContent?.tool_call;
      const toolCallArgs = (contentPart.tool_call as Agents.ToolCall).args;

      // Streaming string args are concatenated; structured (object) args and the
      // final update's args replace whatever was accumulated.
      let args =
        finalUpdate ||
        typeof existingToolCall?.args === 'object' ||
        typeof toolCallArgs === 'object'
          ? contentPart.tool_call.args
          : (existingToolCall?.args ?? '') + (toolCallArgs ?? '');

      if (finalUpdate && args == null && existingToolCall?.args != null) {
        args = existingToolCall.args;
      }

      const id = getNonEmptyValue([contentPart.tool_call.id, existingToolCall?.id]) ?? '';
      const name = getNonEmptyValue([contentPart.tool_call.name, existingToolCall?.name]) ?? '';

      const newToolCall: Agents.ToolCall & PartMetadata = {
        id,
        name,
        args,
        type: ToolCallTypes.TOOL_CALL,
        auth: contentPart.tool_call.auth,
        expires_at: contentPart.tool_call.expires_at,
      };

      if (finalUpdate) {
        newToolCall.progress = 1;
        newToolCall.output = contentPart.tool_call.output;
      }

      updatedContent[index] = {
        type: ContentTypes.TOOL_CALL,
        tool_call: newToolCall,
      };
    }

    return { ...message, content: updatedContent as TMessageContentParts[] };
  };

  /**
   * Reduces one streamed event into the message list.
   *
   * @param event - event name plus payload (cast per event name below).
   * @param submission - the in-flight submission; supplies the user message, the
   *   preliminary/initial response, and edited-content information.
   * @returns a function that clears the step caches (kept for backward
   *   compatibility), or `undefined` on the early-return paths.
   */
  const stepHandler = useCallback(
    ({ event, data }: TStepEvent, submission: EventSubmission) => {
      const messages = getMessages() || [];
      const { userMessage } = submission;
      setIsSubmitting(true);

      // Announce "composing" at most once per MESSAGE_UPDATE_INTERVAL ms.
      const currentTime = Date.now();
      if (currentTime - lastAnnouncementTimeRef.current > MESSAGE_UPDATE_INTERVAL) {
        announcePolite({ message: 'composing', isStatus: true });
        lastAnnouncementTimeRef.current = currentTime;
      }

      // For edits, streamed parts are offset past the initial response's content.
      let initialContent: TMessageContentParts[] = [];
      if (submission?.editedContent != null) {
        initialContent = submission?.initialResponse?.content ?? initialContent;
      }

      /**
       * Maps an event's run id to the id of the response message it updates. The
       * preliminary placeholder id maps to the submission's initial response.
       * Returns '' when nothing can be resolved.
       */
      const resolveResponseMessageId = (runId: string | null | undefined): string => {
        const candidate = runId ?? '';
        if (candidate === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID) {
          return submission?.initialResponse?.messageId ?? '';
        }
        return candidate;
      };

      /** Caches `updatedResponse` and writes it over the last message of the latest list. */
      const replaceLastMessage = (responseMessageId: string, updatedResponse: TMessage) => {
        messageMap.current.set(responseMessageId, updatedResponse);
        const currentMessages = getMessages() || [];
        setMessages([...currentMessages.slice(0, -1), updatedResponse]);
      };

      /** Caches `updatedResponse` and swaps it for the same-id message in `baseMessages`. */
      const replaceMessageById = (
        baseMessages: TMessage[],
        responseMessageId: string,
        updatedResponse: TMessage,
      ) => {
        messageMap.current.set(responseMessageId, updatedResponse);
        setMessages(
          baseMessages.map((msg) => (msg.messageId === responseMessageId ? updatedResponse : msg)),
        );
      };

      if (event === 'on_run_step') {
        const runStep = data as Agents.RunStep;
        const responseMessageId = resolveResponseMessageId(runStep.runId);
        if (!responseMessageId) {
          console.warn('No message id found in run step event');
          return;
        }

        stepMap.current.set(runStep.id, runStep);
        let response = messageMap.current.get(responseMessageId);
        // List the tool-call update below is applied to. It must include a response
        // created in this same call; mapping over the pre-call snapshot would
        // overwrite the first setMessages and drop that response.
        let baseMessages = messages;

        if (!response) {
          // The preliminary response keeps the parent recorded on the initial response.
          const parentMessageId =
            runStep.runId === Constants.USE_PRELIM_RESPONSE_MESSAGE_ID
              ? (submission?.initialResponse?.parentMessageId ?? '')
              : userMessage.messageId;
          const responseMessage = messages[messages.length - 1] as TMessage;

          response = {
            ...responseMessage,
            parentMessageId,
            conversationId: userMessage.conversationId,
            messageId: responseMessageId,
            content: initialContent,
          };

          messageMap.current.set(responseMessageId, response);
          baseMessages = [...messages.slice(0, -1), response];
          setMessages(baseMessages);
        }

        if (runStep.stepDetails.type === StepTypes.TOOL_CALLS) {
          let updatedResponse = { ...response };
          (runStep.stepDetails.tool_calls as Agents.ToolCall[]).forEach((toolCall) => {
            const toolCallId = toolCall.id ?? '';
            if ('id' in toolCall && toolCallId) {
              // Remember the id so later run-step deltas can be attributed to it.
              toolCallIdMap.current.set(runStep.id, toolCallId);
            }

            const contentPart: Agents.MessageContentComplex = {
              type: ContentTypes.TOOL_CALL,
              tool_call: {
                name: toolCall.name ?? '',
                args: toolCall.args,
                id: toolCallId,
              },
            };

            // Tool-call parts are offset past any pre-existing (edited) content.
            const currentIndex = runStep.index + initialContent.length;
            updatedResponse = updateContent(updatedResponse, currentIndex, contentPart);
          });

          replaceMessageById(baseMessages, responseMessageId, updatedResponse);
        }
      } else if (event === 'on_agent_update') {
        const { agent_update } = data as Agents.AgentUpdate;
        const responseMessageId = resolveResponseMessageId(agent_update.runId);
        if (!responseMessageId) {
          console.warn('No message id found in agent update event');
          return;
        }

        const response = messageMap.current.get(responseMessageId);
        if (response) {
          const currentIndex = agent_update.index + initialContent.length;
          replaceLastMessage(responseMessageId, updateContent(response, currentIndex, data));
        }
      } else if (event === 'on_message_delta') {
        const messageDelta = data as Agents.MessageDeltaEvent;
        const runStep = stepMap.current.get(messageDelta.id);
        const responseMessageId = resolveResponseMessageId(runStep?.runId);

        if (!runStep || !responseMessageId) {
          console.warn('No run step or runId found for message delta event');
          return;
        }

        const response = messageMap.current.get(responseMessageId);
        if (response && messageDelta.delta.content) {
          const contentPart = Array.isArray(messageDelta.delta.content)
            ? messageDelta.delta.content[0]
            : messageDelta.delta.content;

          if (contentPart == null) {
            return;
          }

          const currentIndex = calculateContentIndex(
            runStep.index,
            initialContent,
            contentPart.type || '',
            response.content,
          );
          replaceLastMessage(responseMessageId, updateContent(response, currentIndex, contentPart));
        }
      } else if (event === 'on_reasoning_delta') {
        const reasoningDelta = data as Agents.ReasoningDeltaEvent;
        const runStep = stepMap.current.get(reasoningDelta.id);
        const responseMessageId = resolveResponseMessageId(runStep?.runId);

        if (!runStep || !responseMessageId) {
          console.warn('No run step or runId found for reasoning delta event');
          return;
        }

        const response = messageMap.current.get(responseMessageId);
        if (response && reasoningDelta.delta.content != null) {
          const contentPart = Array.isArray(reasoningDelta.delta.content)
            ? reasoningDelta.delta.content[0]
            : reasoningDelta.delta.content;

          if (contentPart == null) {
            return;
          }

          const currentIndex = calculateContentIndex(
            runStep.index,
            initialContent,
            contentPart.type || '',
            response.content,
          );
          replaceLastMessage(responseMessageId, updateContent(response, currentIndex, contentPart));
        }
      } else if (event === 'on_run_step_delta') {
        const runStepDelta = data as Agents.RunStepDeltaEvent;
        const runStep = stepMap.current.get(runStepDelta.id);
        const responseMessageId = resolveResponseMessageId(runStep?.runId);

        if (!runStep || !responseMessageId) {
          console.warn('No run step or runId found for run step delta event');
          return;
        }

        const response = messageMap.current.get(responseMessageId);
        if (
          response &&
          runStepDelta.delta.type === StepTypes.TOOL_CALLS &&
          runStepDelta.delta.tool_calls
        ) {
          let updatedResponse = { ...response };

          runStepDelta.delta.tool_calls.forEach((toolCallDelta) => {
            // Deltas do not carry the id; reuse the one recorded for this step.
            const toolCallId = toolCallIdMap.current.get(runStepDelta.id) ?? '';

            const contentPart: Agents.MessageContentComplex = {
              type: ContentTypes.TOOL_CALL,
              tool_call: {
                name: toolCallDelta.name ?? '',
                args: toolCallDelta.args ?? '',
                id: toolCallId,
              },
            };

            if (runStepDelta.delta.auth != null) {
              contentPart.tool_call.auth = runStepDelta.delta.auth;
              contentPart.tool_call.expires_at = runStepDelta.delta.expires_at;
            }

            const currentIndex = runStep.index + initialContent.length;
            updatedResponse = updateContent(updatedResponse, currentIndex, contentPart);
          });

          replaceMessageById(messages, responseMessageId, updatedResponse);
        }
      } else if (event === 'on_run_step_completed') {
        // This event's payload arrives wrapped as `{ result }`.
        const { result } = data as unknown as { result: Agents.ToolEndEvent };
        const runStep = stepMap.current.get(result.id);
        const responseMessageId = resolveResponseMessageId(runStep?.runId);

        if (!runStep || !responseMessageId) {
          console.warn('No run step or runId found for completed tool call event');
          return;
        }

        const response = messageMap.current.get(responseMessageId);
        if (response) {
          const contentPart: Agents.MessageContentComplex = {
            type: ContentTypes.TOOL_CALL,
            tool_call: result.tool_call,
          };

          const currentIndex = runStep.index + initialContent.length;
          const updatedResponse = updateContent({ ...response }, currentIndex, contentPart, true);
          replaceMessageById(messages, responseMessageId, updatedResponse);
        }
      }

      // Kept for backward compatibility: clears the same caches as `clearStepMaps`.
      return () => {
        toolCallIdMap.current.clear();
        messageMap.current.clear();
        stepMap.current.clear();
      };
    },
    [getMessages, setIsSubmitting, lastAnnouncementTimeRef, announcePolite, setMessages],
  );

  /** Empties the run-step, tool-call-id and response-message caches. */
  const clearStepMaps = useCallback(() => {
    toolCallIdMap.current.clear();
    messageMap.current.clear();
    stepMap.current.clear();
  }, []);

  return { stepHandler, clearStepMaps };
}
| |
|