Spaces:
Paused
Paused
| """ | |
| Main FastAPI application integrating all components with Hugging Face Inference Endpoint. | |
| """ | |
| import gradio as gr | |
| import fastapi | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi import FastAPI, Request, Form, UploadFile, File | |
| import os | |
| import time | |
| import logging | |
| import json | |
| import shutil | |
| import uvicorn | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Any | |
| import io | |
| import numpy as np | |
| from scipy.io.wavfile import write | |
| # Import our modules | |
| from local_llm import run_llm, run_llm_with_memory, clear_memory, get_memory_sessions, get_model_info, test_endpoint | |
| # Setup logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
# Application object plus one-time static-asset bootstrap.
app = FastAPI(title="AGI Telecom POC")

# Make sure ./static exists before it is mounted below.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)

# Seed static/index.html from the template the first time only; never
# overwrite a page that is already in place.
template_page = Path("templates") / "index.html"
served_page = static_dir / "index.html"
if template_page.exists() and not served_page.exists():
    shutil.copy(template_page, served_page)

# Serve everything under ./static at /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
| # Helper functions for mock implementations | |
def mock_transcribe(audio_bytes):
    """Pretend to run speech-to-text.

    The audio payload is ignored; after a short artificial delay a fixed
    placeholder transcription is returned.
    """
    logging.getLogger(__name__).info("Transcribing audio...")
    time.sleep(0.5)  # simulate STT processing latency
    canned_text = "This is a mock transcription of the audio."
    return canned_text
def mock_synthesize_speech(text, *, sample_rate=22050, duration=2.0,
                           frequency=440.0, output_file="temp_audio.wav"):
    """Mock text-to-speech: synthesize a quiet sine tone instead of speech.

    The tone is written to *output_file* as a side effect (the /speak route
    serves that same path afterwards, so the on-disk file is part of the
    contract) and the raw WAV bytes are returned.

    Args:
        text: Ignored; present to mirror a real TTS signature.
        sample_rate: Output sample rate in Hz (default 22050).
        duration: Tone length in seconds (default 2.0).
        frequency: Tone pitch in Hz (default 440.0, concert A).
        output_file: Path the WAV file is written to.

    Returns:
        The WAV file contents as bytes.
    """
    logging.getLogger(__name__).info("Synthesizing speech...")
    time.sleep(0.5)  # simulate TTS processing latency
    t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
    # 0.3 amplitude keeps the tone comfortably below full scale.
    tone = (np.sin(2 * np.pi * frequency * t) * 0.3).astype(np.float32)
    write(output_file, sample_rate, tone)
    return Path(output_file).read_bytes()
| # Routes for the API | |
@app.get("/")
async def root():
    """Serve the main UI page (static/index.html)."""
    # NOTE(review): no route decorator was visible in the reviewed copy, so
    # the handler was never registered; "/" is the inferred path — confirm.
    return FileResponse("static/index.html")
@app.get("/health")
async def health_check():
    """Report service liveness plus inference-endpoint reachability."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/health" path is inferred — confirm.
    endpoint_status = test_endpoint()
    return {
        "status": "ok",
        "endpoint": endpoint_status
    }
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text (mock implementation).

    Returns {"transcription": str} on success, or a 500 JSON error body.
    """
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/transcribe" path is inferred — confirm.
    try:
        audio_bytes = await file.read()
        text = mock_transcribe(audio_bytes)
        return {"transcription": text}
    except Exception as e:
        logger.error(f"Transcription error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to transcribe audio: {str(e)}"}
        )
@app.post("/query")
async def query_agent(input_text: str = Form(...), session_id: str = Form("default")):
    """Send a text query to the LLM, using per-session conversation memory."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/query" path is inferred — confirm.
    try:
        response = run_llm_with_memory(input_text, session_id=session_id)
        logger.info(f"Query: {input_text[:30]}... Response: {response[:30]}...")
        return {"response": response}
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process query: {str(e)}"}
        )
@app.post("/speak")
async def speak(text: str = Form(...)):
    """Convert text to speech and stream the resulting WAV file back.

    mock_synthesize_speech writes its output to temp_audio.wav as a side
    effect; that on-disk file is what FileResponse serves.
    """
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/speak" path is inferred — confirm.
    try:
        # Called for its side effect: it writes temp_audio.wav to disk.
        mock_synthesize_speech(text)
        return FileResponse(
            "temp_audio.wav",
            media_type="audio/wav",
            filename="response.wav"
        )
    except Exception as e:
        logger.error(f"Speech synthesis error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to synthesize speech: {str(e)}"}
        )
@app.post("/sessions")
async def create_session():
    """Create a fresh conversation session and return its generated id."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/sessions" path is inferred — confirm.
    import uuid
    session_id = str(uuid.uuid4())
    clear_memory(session_id)  # ensure the new session starts with empty memory
    return {"session_id": session_id}
@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """Clear a session's memory; respond 404 when the session is unknown."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # path is inferred — confirm.
    success = clear_memory(session_id)
    if success:
        return {"message": f"Session {session_id} cleared"}
    return JSONResponse(
        status_code=404,
        content={"error": f"Session {session_id} not found"}
    )
@app.get("/sessions")
async def list_sessions():
    """List the ids of all active sessions."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/sessions" path is inferred — confirm.
    return {"sessions": get_memory_sessions()}
@app.get("/model")
async def model_info():
    """Return metadata about the backing model/endpoint."""
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/model" path is inferred — confirm.
    return get_model_info()
@app.post("/complete")
async def complete_flow(
    request: Request,
    audio_file: UploadFile = File(None),
    text_input: str = Form(None),
    session_id: str = Form("default")
):
    """Run the full pipeline: audio -> transcription -> agent -> speech.

    Accepts either an uploaded audio file (which is mock-transcribed) or a
    plain text input. Returns the effective input text, the agent response,
    and a URL where the synthesized audio can be fetched.
    """
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # "/complete" path is inferred — confirm. The audio_url built below
    # assumes get_audio is registered at /audio/{filename}.
    try:
        # Prefer the audio upload: transcribe it into text_input.
        if audio_file:
            audio_bytes = await audio_file.read()
            text_input = mock_transcribe(audio_bytes)
            logger.info(f"Transcribed input: {text_input[:30]}...")
        # Neither audio nor text supplied -> client error.
        if not text_input:
            return JSONResponse(
                status_code=400,
                content={"error": "No input provided"}
            )
        response = run_llm_with_memory(text_input, session_id=session_id)
        logger.info(f"Agent response: {response[:30]}...")
        audio_bytes = mock_synthesize_speech(response)
        # Persist the WAV to the system temp dir so /audio/{filename} can
        # serve it on a later request.
        import tempfile
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_file.write(audio_bytes)
        temp_file.close()
        # Build an absolute URL, honouring a reverse proxy's scheme header.
        host = request.headers.get("host", "localhost")
        scheme = request.headers.get("x-forwarded-proto", "http")
        audio_url = f"{scheme}://{host}/audio/{os.path.basename(temp_file.name)}"
        return {
            "input": text_input,
            "response": response,
            "audio_url": audio_url
        }
    except Exception as e:
        logger.error(f"Complete flow error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process: {str(e)}"}
        )
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a synthesized WAV file from the system temp directory.

    Filenames are restricted to a safe character whitelist so a request
    cannot traverse out of the temp directory.
    """
    # NOTE(review): route decorator was missing in the reviewed copy; the
    # path matches the audio_url built by complete_flow — confirm.
    try:
        import re
        # BUG FIX: tempfile was used below without being imported in this
        # scope (it was only imported inside complete_flow), so every
        # request raised NameError and returned a 500.
        import tempfile
        # Ensure filename only contains safe characters.
        if not re.match(r'^[a-zA-Z0-9_.-]+$', filename):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid filename"}
            )
        file_path = os.path.join(tempfile.gettempdir(), filename)
        if not os.path.exists(file_path):
            return JSONResponse(
                status_code=404,
                content={"error": "File not found"}
            )
        return FileResponse(
            file_path,
            media_type="audio/wav",
            filename=filename
        )
    except Exception as e:
        logger.error(f"Audio serving error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to serve audio: {str(e)}"}
        )
| # Gradio interface | |
# Gradio demo UI mounted alongside the REST API.
with gr.Blocks(title="AGI Telecom POC", css="footer {visibility: hidden}") as interface:
    gr.Markdown("# AGI Telecom POC Demo")
    gr.Markdown("This is a demonstration of the AGI Telecom Proof of Concept using a Hugging Face Inference Endpoint.")
    with gr.Row():
        with gr.Column():
            # Input components
            audio_input = gr.Audio(label="Voice Input", type="filepath")
            text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...", lines=2)
            # Session management
            session_id = gr.Textbox(label="Session ID", value="default")
            new_session_btn = gr.Button("New Session")
            # Action buttons
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe Audio")
                query_btn = gr.Button("Send Query")
                speak_btn = gr.Button("Speak Response")
        with gr.Column():
            # Output components
            transcription_output = gr.Textbox(label="Transcription", lines=2)
            response_output = gr.Textbox(label="Agent Response", lines=5)
            audio_output = gr.Audio(label="Voice Response", autoplay=True)
            # Status and info
            status_output = gr.Textbox(label="Status", value="Ready")
            endpoint_status = gr.Textbox(label="Endpoint Status", value="Checking endpoint connection...")

    def update_session():
        """Mint a new session id and reset its memory."""
        import uuid
        new_id = str(uuid.uuid4())
        clear_memory(new_id)
        status = f"Created new session: {new_id}"
        return new_id, status

    new_session_btn.click(
        update_session,
        outputs=[session_id, status_output]
    )

    def process_audio(audio_path, session):
        """Full pipeline for a recorded clip: transcribe, query, synthesize."""
        if not audio_path:
            return "No audio provided", "", None, "Error: No audio input"
        try:
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()
            # Transcribe
            text = mock_transcribe(audio_bytes)
            # Get response
            response = run_llm_with_memory(text, session)
            # Synthesize
            audio_bytes = mock_synthesize_speech(response)
            temp_file = "temp_response.wav"
            with open(temp_file, "wb") as f:
                f.write(audio_bytes)
            return text, response, temp_file, "Processed successfully"
        except Exception as e:
            logger.error(f"Error: {str(e)}")
            return "", "", None, f"Error: {str(e)}"

    def transcribe_clip(audio_path):
        """Transcribe button handler.

        BUG FIX: the previous lambda opened the audio file without ever
        closing it; a with-block releases the handle deterministically.
        """
        if not audio_path:
            return "No audio provided"
        with open(audio_path, "rb") as f:
            return mock_transcribe(f.read())

    transcribe_btn.click(
        transcribe_clip,
        inputs=[audio_input],
        outputs=[transcription_output]
    )
    query_btn.click(
        lambda text, session: run_llm_with_memory(text, session),
        inputs=[text_input, session_id],
        outputs=[response_output]
    )

    def speak_response(text):
        """Speak button handler.

        BUG FIX: mock_synthesize_speech writes its tone to temp_audio.wav,
        but the previous handler returned "temp_response.wav" — a file that
        need not exist — so playback used stale or missing audio.
        """
        mock_synthesize_speech(text)
        return "temp_audio.wav"

    speak_btn.click(
        speak_response,
        inputs=[response_output],
        outputs=[audio_output]
    )
    # Full process whenever a new recording lands.
    audio_input.change(
        process_audio,
        inputs=[audio_input, session_id],
        outputs=[transcription_output, response_output, audio_output, status_output]
    )

    def check_endpoint():
        """Format the inference-endpoint connectivity status for display."""
        status = test_endpoint()
        if status["status"] == "connected":
            return f"✅ Connected to endpoint: {status['message']}"
        else:
            return f"❌ Error connecting to endpoint: {status['message']}"

    # BUG FIX: gradio has no gr.on_load; the page-load event is
    # Blocks.load(), so register the check on the Blocks instance.
    interface.load(check_endpoint, outputs=[endpoint_status])
| # Mount Gradio app | |
# Expose the Gradio UI under /gradio on the same FastAPI server.
app = gr.mount_gradio_app(app, interface, path="/gradio")

if __name__ == "__main__":
    # On HF Spaces (SPACE_ID set) honour the platform's PORT, defaulting to
    # 7860; for a local run bind the usual development port 8000.
    if os.environ.get("SPACE_ID"):
        serve_port = int(os.environ.get("PORT", 7860))
    else:
        serve_port = 8000
    uvicorn.run(app, host="0.0.0.0", port=serve_port)