Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from shared import RealtimeSpeakerDiarization | |
| import numpy as np | |
| import uvicorn | |
| import logging | |
| import asyncio | |
# Logging: configure the root logger at INFO and take a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application every handler in this module is attached to.
app = FastAPI()

# CORS: any origin, method and header is accepted, with credentials enabled,
# so browser-based clients can talk to the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Diarization engine: built and initialized once, at import time.
logger.info("Initializing diarization system...")
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
logger.info(f"Models initialized: {success}")

# Recording is only started when model initialization reported success.
if success:
    diart.start_recording()

# Registry of currently open WebSocket connections, shared by the endpoint
# handler and the periodic broadcaster.
active_connections = set()
# Periodic broadcast of the conversation to connected clients
async def _broadcast_conversation():
    """Send the current conversation HTML to every registered WebSocket.

    A socket whose send fails is logged and dropped from
    ``active_connections``; any other failure in the update is logged and
    swallowed so the caller's loop keeps running.
    """
    try:
        html = diart.get_formatted_conversation()
        # Iterate over a snapshot: the registry may be mutated while we await
        # (here on failure, or by the endpoint handler).
        for client in tuple(active_connections):
            try:
                await client.send_text(html)
            except Exception as send_err:
                logger.error(f"Error sending to WebSocket: {send_err}")
                active_connections.discard(client)
    except Exception as update_err:
        logger.error(f"Error in conversation update: {update_err}")


async def send_conversation_updates():
    """Push the conversation to all connected clients every 500 ms, forever.

    Nothing is fetched or sent while no client is connected; the loop still
    sleeps between iterations in that case.
    """
    while True:
        if active_connections:
            await _broadcast_conversation()
        # Pause between updates (500 ms interval).
        await asyncio.sleep(0.5)
# Strong references to running background tasks. The event loop only keeps
# weak references to tasks, so a task whose result is discarded can be
# garbage-collected before it finishes.
_background_tasks = set()


async def startup_event():
    """Launch the periodic conversation broadcaster as a background task.

    The task is kept in ``_background_tasks`` for as long as it runs and is
    removed automatically once it completes or is cancelled.

    NOTE(review): no startup-hook decorator is visible on this function in this
    view — confirm how it is registered with ``app``.
    """
    task = asyncio.create_task(send_conversation_updates())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
async def health_check():
    """Report service health.

    Returns a dict with a fixed ``"status"`` of ``"healthy"``, the diarization
    system's ``is_running`` flag, and the number of open WebSocket connections.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    report = {"status": "healthy"}
    report["system_running"] = diart.is_running
    report["active_connections"] = len(active_connections)
    return report
async def ws_inference(ws: WebSocket):
    """Handle one WebSocket client streaming audio for diarization.

    The connection is accepted and registered in ``active_connections``, the
    current conversation is sent once, and every non-empty binary message is
    then passed to ``diart.process_audio_chunk``. A failure on a single chunk
    is logged without ending the session. The socket is always unregistered
    on exit.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    await ws.accept()
    active_connections.add(ws)
    total = len(active_connections)
    logger.info(f"WebSocket connection established. Total connections: {total}")

    try:
        # Give the new client the conversation as it stands right now.
        await ws.send_text(diart.get_formatted_conversation())

        async for audio_bytes in ws.iter_bytes():
            # Empty frames carry no audio; skip them.
            if not audio_bytes:
                continue
            try:
                # Updates the diarizer's internal conversation state.
                diart.process_audio_chunk(audio_bytes)
            except Exception as chunk_err:
                logger.error(f"Error processing audio chunk: {chunk_err}")
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as ws_err:
        logger.error(f"WebSocket error: {ws_err}")
    finally:
        active_connections.discard(ws)
        remaining = len(active_connections)
        logger.info(f"WebSocket connection closed. Remaining connections: {remaining}")
async def get_conversation():
    """Return the current formatted conversation under the ``"conversation"`` key.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    formatted = diart.get_formatted_conversation()
    return {"conversation": formatted}
async def get_status():
    """Return the diarization system's status info under the ``"status"`` key.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    info = diart.get_status_info()
    return {"status": info}
async def update_settings(threshold: float, max_speakers: int):
    """Forward new speaker-detection settings to the diarization system.

    Both arguments are passed straight to ``diart.update_settings``; whatever
    it returns is sent back under the ``"result"`` key.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    return {"result": diart.update_settings(threshold, max_speakers)}
async def clear_conversation():
    """Clear the conversation and return the outcome under the ``"result"`` key.

    NOTE(review): no route decorator is visible on this function in this view —
    confirm how it is registered with ``app``.
    """
    return {"result": diart.clear_conversation()}
# Optional Gradio front-end: mounted onto the app when the `ui` module can be
# imported; otherwise the service runs API-only.
try:
    import ui

    ui.mount_ui(app)
    logger.info("Gradio UI mounted successfully")
except ImportError:
    logger.warning("UI module not found, running in API-only mode")


# Direct execution: serve on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)