| | import { useCallback } from 'react' |
| |
|
| | import { APIRoutes } from '@/api/routes' |
| |
|
| | import useChatActions from '@/hooks/useChatActions' |
| | import { useStore } from '../store' |
| | import { RunEvent, RunResponseContent, type RunResponse } from '@/types/os' |
| | import { constructEndpointUrl } from '@/lib/constructEndpointUrl' |
| | import useAIResponseStream from './useAIResponseStream' |
| | import { ToolCall } from '@/types/os' |
| | import { useQueryState } from 'nuqs' |
| | import { getJsonMarkdown } from '@/lib/utils' |
| |
|
| | const useAIChatStreamHandler = () => { |
// Store setters and values read by the callbacks below.
const setMessages = useStore((s) => s.setMessages)
const { addMessage, focusChatInput } = useChatActions()

// Query-state values keyed 'agent', 'team' and 'session' (via useQueryState).
const [agentId] = useQueryState('agent')
const [teamId] = useQueryState('team')
const [sessionId, setSessionId] = useQueryState('session')

const selectedEndpoint = useStore((s) => s.selectedEndpoint)
const mode = useStore((s) => s.mode)
const setStreamingErrorMessage = useStore((s) => s.setStreamingErrorMessage)
const setIsStreaming = useStore((s) => s.setIsStreaming)
const setSessionsData = useStore((s) => s.setSessionsData)

// Streaming transport used to issue the run request and receive chunks.
const { streamResponse } = useAIResponseStream()
| |
|
/**
 * Flags the most recent message as failed when it is an agent message.
 *
 * The flagged message is replaced by a copy carrying `streamingError: true`
 * rather than being mutated in place, so the previous messages array and the
 * message objects it holds are left untouched. When the list is empty or the
 * last message is not an agent message, the previous array is returned as-is.
 */
const updateMessagesWithErrorState = useCallback(() => {
  setMessages((prevMessages) => {
    const lastIndex = prevMessages.length - 1
    const lastMessage = prevMessages[lastIndex]
    if (!lastMessage || lastMessage.role !== 'agent') {
      return prevMessages
    }
    return [
      ...prevMessages.slice(0, lastIndex),
      { ...lastMessage, streamingError: true }
    ]
  })
}, [setMessages])
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Merges one tool call into a list of tool calls and returns a new list.
 *
 * An entry counts as "the same call" when it has a `tool_call_id` equal to the
 * incoming one, or — for entries without an id — when its
 * `${tool_name}-${created_at}` key equals the incoming call's key. A matching
 * entry is shallow-merged with the incoming call (incoming fields win);
 * otherwise the incoming call is appended.
 *
 * @param toolCall - The incoming tool call.
 * @param prevToolCalls - The tool calls collected so far (defaults to empty).
 * @returns A new array; `prevToolCalls` itself is not modified.
 */
const processToolCall = useCallback(
  (toolCall: ToolCall, prevToolCalls: ToolCall[] = []) => {
    // Identity of the incoming call: its id, or a name/timestamp key as fallback.
    const incomingKey =
      toolCall.tool_call_id || `${toolCall.tool_name}-${toolCall.created_at}`

    const isSameCall = (candidate: ToolCall) => {
      if (candidate.tool_call_id) {
        return candidate.tool_call_id === toolCall.tool_call_id
      }
      // Entries without an id are matched on the composite key, and only when
      // the incoming call actually carries both parts of that key.
      return (
        toolCall.tool_name &&
        toolCall.created_at &&
        `${candidate.tool_name}-${candidate.created_at}` === incomingKey
      )
    }

    const matchIndex = prevToolCalls.findIndex(isSameCall)
    if (matchIndex < 0) {
      return [...prevToolCalls, toolCall]
    }

    return prevToolCalls.map((existing, index) =>
      index === matchIndex ? { ...existing, ...toolCall } : existing
    )
  },
  []
)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Folds every tool call carried by a chunk into an existing tool-call list.
 *
 * The chunk's single `tool` (if present) is applied first, followed by each
 * entry of `tools` in order; each one is merged through `processToolCall`.
 *
 * @param chunk - The streamed chunk that may carry `tool` and/or `tools`.
 * @param existingToolCalls - The tool calls collected so far (defaults to empty).
 * @returns A new array; `existingToolCalls` itself is not modified.
 */
const processChunkToolCalls = useCallback(
  (
    chunk: RunResponseContent | RunResponse,
    existingToolCalls: ToolCall[] = []
  ) => {
    const incomingCalls: ToolCall[] = []
    if (chunk.tool) {
      incomingCalls.push(chunk.tool)
    }
    if (chunk.tools && chunk.tools.length > 0) {
      incomingCalls.push(...chunk.tools)
    }

    // Start from a copy so a chunk without tool calls still yields a new array.
    return incomingCalls.reduce<ToolCall[]>(
      (merged, call) => processToolCall(call, merged),
      [...existingToolCalls]
    )
  },
  [processToolCall]
)
| |
|
/**
 * Sends `input` to the run endpoint of the selected agent or team and applies
 * the streamed run events to the chat messages.
 *
 * Steps:
 * 1. A string input is wrapped in a `FormData` under the `message` field; a
 *    `FormData` input is used as-is. `stream` and `session_id` are appended
 *    before the request is sent.
 * 2. If the last two messages are a user message followed by an agent message
 *    flagged with `streamingError`, both are dropped before the new attempt.
 * 3. The user message and an empty agent placeholder message are added.
 * 4. Each streamed chunk updates the trailing agent message (content, tool
 *    calls, reasoning steps, references, media) according to `chunk.event`.
 *
 * Failures (no agent/team selected, error/cancel events, transport errors,
 * thrown exceptions) flag the trailing agent message, publish an error message
 * and remove the current session id from the sessions list. Streaming state is
 * always reset and the chat input refocused when the call finishes.
 *
 * @param input - The user message text, or a `FormData` that carries it in
 *   its `message` field.
 */
const handleStreamResponse = useCallback(
  async (input: string | FormData) => {
    /**
     * Returns a new messages array in which the trailing agent message has
     * been replaced by a patched shallow copy. The previous array and message
     * object are never mutated; when there is no trailing agent message the
     * input array is returned unchanged.
     */
    const patchLastAgentMessage = <T extends { role: string }>(
      messages: T[],
      patch: (draft: T) => void
    ): T[] => {
      const lastIndex = messages.length - 1
      const last = messages[lastIndex]
      if (!last || last.role !== 'agent') {
        return messages
      }
      const draft = { ...last }
      patch(draft)
      return [...messages.slice(0, lastIndex), draft]
    }

    setIsStreaming(true)

    const formData = input instanceof FormData ? input : new FormData()
    if (typeof input === 'string') {
      formData.append('message', input)
    }

    // Drop a previous failed exchange (user message + errored agent reply) so
    // it is not kept alongside the new attempt.
    setMessages((prevMessages) => {
      if (prevMessages.length >= 2) {
        const lastMessage = prevMessages[prevMessages.length - 1]
        const secondLastMessage = prevMessages[prevMessages.length - 2]
        if (
          lastMessage.role === 'agent' &&
          lastMessage.streamingError &&
          secondLastMessage.role === 'user'
        ) {
          return prevMessages.slice(0, -2)
        }
      }
      return prevMessages
    })

    addMessage({
      role: 'user',
      content: formData.get('message') as string,
      created_at: Math.floor(Date.now() / 1000)
    })

    // Placeholder that the streamed chunks fill in; timestamped one second
    // after the user message.
    addMessage({
      role: 'agent',
      content: '',
      tool_calls: [],
      streamingError: false,
      created_at: Math.floor(Date.now() / 1000) + 1
    })

    // Content of the previously handled content chunk, used to derive deltas.
    let lastContent = ''
    let newSessionId = sessionId

    /** Removes the session tracked by `newSessionId` from the sessions list. */
    const removeFailedSession = () => {
      const failedSessionId = newSessionId
      if (!failedSessionId) {
        return
      }
      setSessionsData(
        (prevSessionsData) =>
          prevSessionsData?.filter(
            (session) => session.session_id !== failedSessionId
          ) ?? null
      )
    }

    try {
      const endpointUrl = constructEndpointUrl(selectedEndpoint)

      let runUrl: string | null = null
      if (mode === 'team' && teamId) {
        runUrl = APIRoutes.TeamRun(endpointUrl, teamId)
      } else if (mode === 'agent' && agentId) {
        runUrl = APIRoutes.AgentRun(endpointUrl).replace('{agent_id}', agentId)
      }

      if (!runUrl) {
        updateMessagesWithErrorState()
        setStreamingErrorMessage('Please select an agent or team first.')
        // The `finally` block below resets the streaming flag.
        return
      }

      formData.append('stream', 'true')
      formData.append('session_id', sessionId ?? '')

      await streamResponse({
        apiUrl: runUrl,
        requestBody: formData,
        onChunk: (chunk: RunResponse) => {
          switch (chunk.event) {
            case RunEvent.RunStarted:
            case RunEvent.TeamRunStarted:
            case RunEvent.ReasoningStarted:
            case RunEvent.TeamReasoningStarted: {
              const startedSessionId = chunk.session_id as string
              newSessionId = startedSessionId
              setSessionId(startedSessionId)

              // Register the session in the list when the run reports an id
              // that differs from the one this request started with.
              if (chunk.session_id && sessionId !== chunk.session_id) {
                const sessionData = {
                  session_id: startedSessionId,
                  session_name: formData.get('message') as string,
                  created_at: chunk.created_at
                }
                setSessionsData((prevSessionsData) => {
                  const sessionExists = prevSessionsData?.some(
                    (session) => session.session_id === chunk.session_id
                  )
                  return sessionExists
                    ? prevSessionsData
                    : [sessionData, ...(prevSessionsData ?? [])]
                })
              }
              break
            }

            case RunEvent.ToolCallStarted:
            case RunEvent.TeamToolCallStarted:
            case RunEvent.ToolCallCompleted:
            case RunEvent.TeamToolCallCompleted: {
              setMessages((prevMessages) =>
                patchLastAgentMessage(prevMessages, (draft) => {
                  draft.tool_calls = processChunkToolCalls(
                    chunk,
                    draft.tool_calls
                  )
                })
              )
              break
            }

            case RunEvent.RunContent:
            case RunEvent.TeamRunContent: {
              if (typeof chunk.content === 'string') {
                // Strip the first occurrence of the previous chunk's content
                // and append only the remainder. Computed here, outside the
                // state updater, so the updater has no side effects.
                // NOTE(review): presumably guards against chunks that repeat
                // earlier text — confirm against the backend's streaming format.
                const uniqueContent = chunk.content.replace(lastContent, '')
                lastContent = chunk.content

                setMessages((prevMessages) =>
                  patchLastAgentMessage(prevMessages, (draft) => {
                    draft.content += uniqueContent
                    draft.tool_calls = processChunkToolCalls(
                      chunk,
                      draft.tool_calls
                    )
                    if (chunk.extra_data?.reasoning_steps) {
                      draft.extra_data = {
                        ...draft.extra_data,
                        reasoning_steps: chunk.extra_data.reasoning_steps
                      }
                    }
                    if (chunk.extra_data?.references) {
                      draft.extra_data = {
                        ...draft.extra_data,
                        references: chunk.extra_data.references
                      }
                    }
                    draft.created_at = chunk.created_at ?? draft.created_at
                    if (chunk.images) {
                      draft.images = chunk.images
                    }
                    if (chunk.videos) {
                      draft.videos = chunk.videos
                    }
                    if (chunk.audio) {
                      draft.audio = chunk.audio
                    }
                  })
                )
              } else if (chunk.content !== null) {
                // Non-string content is rendered as a JSON markdown block.
                const jsonBlock = getJsonMarkdown(chunk.content)
                lastContent = jsonBlock

                setMessages((prevMessages) =>
                  patchLastAgentMessage(prevMessages, (draft) => {
                    draft.content += jsonBlock
                  })
                )
              } else if (
                chunk.response_audio?.transcript &&
                typeof chunk.response_audio?.transcript === 'string'
              ) {
                const transcript = chunk.response_audio.transcript

                setMessages((prevMessages) =>
                  patchLastAgentMessage(prevMessages, (draft) => {
                    draft.response_audio = {
                      ...draft.response_audio,
                      // Start from '' so the first transcript chunk is not
                      // prefixed with the string "undefined".
                      transcript:
                        (draft.response_audio?.transcript ?? '') + transcript
                    }
                  })
                )
              }
              break
            }

            case RunEvent.ReasoningStep:
            case RunEvent.TeamReasoningStep: {
              const incomingSteps = chunk.extra_data?.reasoning_steps ?? []
              setMessages((prevMessages) =>
                patchLastAgentMessage(prevMessages, (draft) => {
                  const existingSteps = draft.extra_data?.reasoning_steps ?? []
                  draft.extra_data = {
                    ...draft.extra_data,
                    reasoning_steps: [...existingSteps, ...incomingSteps]
                  }
                })
              )
              break
            }

            case RunEvent.ReasoningCompleted:
            case RunEvent.TeamReasoningCompleted: {
              // The completed event replaces the accumulated steps, when given.
              setMessages((prevMessages) =>
                patchLastAgentMessage(prevMessages, (draft) => {
                  if (chunk.extra_data?.reasoning_steps) {
                    draft.extra_data = {
                      ...draft.extra_data,
                      reasoning_steps: chunk.extra_data.reasoning_steps
                    }
                  }
                })
              )
              break
            }

            case RunEvent.RunError:
            case RunEvent.TeamRunError:
            case RunEvent.TeamRunCancelled: {
              updateMessagesWithErrorState()
              const errorContent =
                (chunk.content as string) ||
                (chunk.event === RunEvent.TeamRunCancelled
                  ? 'Run cancelled'
                  : 'Error during run')
              setStreamingErrorMessage(errorContent)
              removeFailedSession()
              break
            }

            case RunEvent.UpdatingMemory:
            case RunEvent.TeamMemoryUpdateStarted:
            case RunEvent.TeamMemoryUpdateCompleted: {
              // Memory update events are intentionally ignored.
              break
            }

            case RunEvent.RunCompleted:
            case RunEvent.TeamRunCompleted: {
              // The completed event carries the final content, which replaces
              // whatever was accumulated while streaming.
              let updatedContent: string
              if (typeof chunk.content === 'string') {
                updatedContent = chunk.content
              } else {
                try {
                  updatedContent = JSON.stringify(chunk.content)
                } catch {
                  updatedContent = 'Error parsing response'
                }
              }

              setMessages((prevMessages) =>
                patchLastAgentMessage(prevMessages, (draft) => {
                  draft.content = updatedContent
                  draft.tool_calls = processChunkToolCalls(
                    chunk,
                    draft.tool_calls
                  )
                  draft.images = chunk.images ?? draft.images
                  draft.videos = chunk.videos ?? draft.videos
                  draft.response_audio = chunk.response_audio
                  draft.created_at = chunk.created_at ?? draft.created_at
                  draft.extra_data = {
                    reasoning_steps:
                      chunk.extra_data?.reasoning_steps ??
                      draft.extra_data?.reasoning_steps,
                    references:
                      chunk.extra_data?.references ??
                      draft.extra_data?.references
                  }
                })
              )
              break
            }
          }
        },
        onError: (error) => {
          updateMessagesWithErrorState()
          setStreamingErrorMessage(error.message)
          removeFailedSession()
        },
        onComplete: () => {}
      })
    } catch (error) {
      updateMessagesWithErrorState()
      setStreamingErrorMessage(
        error instanceof Error ? error.message : String(error)
      )
      removeFailedSession()
    } finally {
      focusChatInput()
      setIsStreaming(false)
    }
  },
  [
    setMessages,
    addMessage,
    updateMessagesWithErrorState,
    selectedEndpoint,
    streamResponse,
    agentId,
    teamId,
    mode,
    setStreamingErrorMessage,
    setIsStreaming,
    focusChatInput,
    setSessionsData,
    sessionId,
    setSessionId,
    processChunkToolCalls
  ]
)
| |
|
| | return { handleStreamResponse } |
| | } |
| |
|
| | export default useAIChatStreamHandler |
| |
|