ml-agent / backend /websocket.py
tfrere's picture
tfrere HF Staff
Comprehensive frontend/backend overhaul: unified tool system, responsive UI, code quality
e77f678
"""WebSocket connection manager for real-time communication."""
import logging
from typing import Any
from fastapi import WebSocket
logger = logging.getLogger(__name__)
class ConnectionManager:
"""Manages WebSocket connections for multiple sessions."""
def __init__(self) -> None:
# session_id -> WebSocket
self.active_connections: dict[str, WebSocket] = {}
async def connect(self, websocket: WebSocket, session_id: str) -> None:
"""Accept a WebSocket connection and register it."""
logger.info(f"Attempting to accept WebSocket for session {session_id}")
await websocket.accept()
self.active_connections[session_id] = websocket
logger.info(f"WebSocket connected and registered for session {session_id}")
def disconnect(self, session_id: str) -> None:
"""Remove a WebSocket connection."""
if session_id in self.active_connections:
del self.active_connections[session_id]
logger.info(f"WebSocket disconnected for session {session_id}")
async def send_event(
self, session_id: str, event_type: str, data: dict[str, Any] | None = None
) -> None:
"""Send an event to a specific session's WebSocket."""
if session_id not in self.active_connections:
logger.warning(f"No active connection for session {session_id}")
return
message = {"event_type": event_type}
if data is not None:
message["data"] = data
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"Error sending to session {session_id}: {e}")
self.disconnect(session_id)
async def broadcast(
self, event_type: str, data: dict[str, Any] | None = None
) -> None:
"""Broadcast an event to all connected sessions."""
for session_id in list(self.active_connections.keys()):
await self.send_event(session_id, event_type, data)
def is_connected(self, session_id: str) -> bool:
"""Check if a session has an active WebSocket connection."""
return session_id in self.active_connections
# Global connection manager instance
manager = ConnectionManager()