Spaces:
Running
Running
| """WebSocket connection manager for real-time communication.""" | |
| import logging | |
| from typing import Any | |
| from fastapi import WebSocket | |
| logger = logging.getLogger(__name__) | |
| class ConnectionManager: | |
| """Manages WebSocket connections for multiple sessions.""" | |
| def __init__(self) -> None: | |
| # session_id -> WebSocket | |
| self.active_connections: dict[str, WebSocket] = {} | |
| async def connect(self, websocket: WebSocket, session_id: str) -> None: | |
| """Accept a WebSocket connection and register it.""" | |
| logger.info(f"Attempting to accept WebSocket for session {session_id}") | |
| await websocket.accept() | |
| self.active_connections[session_id] = websocket | |
| logger.info(f"WebSocket connected and registered for session {session_id}") | |
| def disconnect(self, session_id: str) -> None: | |
| """Remove a WebSocket connection.""" | |
| if session_id in self.active_connections: | |
| del self.active_connections[session_id] | |
| logger.info(f"WebSocket disconnected for session {session_id}") | |
| async def send_event( | |
| self, session_id: str, event_type: str, data: dict[str, Any] | None = None | |
| ) -> None: | |
| """Send an event to a specific session's WebSocket.""" | |
| if session_id not in self.active_connections: | |
| logger.warning(f"No active connection for session {session_id}") | |
| return | |
| message = {"event_type": event_type} | |
| if data is not None: | |
| message["data"] = data | |
| try: | |
| await self.active_connections[session_id].send_json(message) | |
| except Exception as e: | |
| logger.error(f"Error sending to session {session_id}: {e}") | |
| self.disconnect(session_id) | |
| async def broadcast( | |
| self, event_type: str, data: dict[str, Any] | None = None | |
| ) -> None: | |
| """Broadcast an event to all connected sessions.""" | |
| for session_id in list(self.active_connections.keys()): | |
| await self.send_event(session_id, event_type, data) | |
| def is_connected(self, session_id: str) -> bool: | |
| """Check if a session has an active WebSocket connection.""" | |
| return session_id in self.active_connections | |
| # Global connection manager instance | |
| manager = ConnectionManager() | |