import os
import tempfile
from threading import Thread

import gradio as gr
import torch
import uvicorn
import whisperx
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

# Device setup: use CUDA with float16 when available, otherwise CPU with int8.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"🚀 Device: {device}, Compute: {compute_type}")

# Create FastAPI app
app = FastAPI(title="WhisperX Alignment API")


def _round_timestamp(value):
    """Round a timestamp to two decimals, passing ``None`` through unchanged."""
    return None if value is None else round(value, 2)


def _extract_word_segments(segments):
    """Flatten aligned segments into a list of ``word``/``start``/``end`` dicts.

    Words that carry no ``start``/``end`` key are kept with ``None`` timestamps
    instead of raising ``KeyError`` and failing the whole request.
    NOTE(review): WhisperX is understood to omit timings for tokens it cannot
    align (e.g. numerals) -- confirm against the installed version.
    """
    word_segments = []
    for segment in segments:
        for word in segment.get("words", []):
            word_segments.append({
                "word": word["word"].strip(),
                "start": _round_timestamp(word.get("start")),
                "end": _round_timestamp(word.get("end")),
            })
    return word_segments


def process_audio(audio_path: str, language: str = "en"):
    """Transcribe an audio file and return word-level timestamps.

    Args:
        audio_path: Path of the audio file to process.
        language: Language code used for both transcription and alignment.

    Returns:
        On success, a dict with ``word_segments`` (list of word/start/end
        dicts), ``duration`` (end of the last aligned segment, or 0 when there
        are none), ``word_count``, ``language`` and ``device``.
        On any failure, ``{"error": <message>}``; this function never raises.
    """
    try:
        print(f"📝 Processing {audio_path} ({language})...")

        # Load model (a fresh model is loaded on every call)
        model = whisperx.load_model("base", device=device, compute_type=compute_type)

        # Transcribe
        result = model.transcribe(audio_path, language=language)

        # Align
        align_model, metadata = whisperx.load_align_model(
            language_code=language, device=device
        )
        aligned = whisperx.align(
            result["segments"], align_model, metadata, audio_path, device=device
        )

        segments = aligned["segments"]
        word_segments = _extract_word_segments(segments)
        duration = segments[-1]["end"] if segments else 0

        return {
            "word_segments": word_segments,
            "duration": round(duration, 2),
            "word_count": len(word_segments),
            "language": language,
            "device": device,
        }
    except Exception as e:
        # Boundary handler: both the REST endpoint and the Gradio UI expect an
        # error dict rather than an exception.
        print(f"❌ Error: {e}")
        return {"error": str(e)}


# FastAPI endpoint
@app.post("/align")
async def align_audio_api(
    audio_file: UploadFile = File(...),
    language: str = Form("en"),
):
    """REST API endpoint for audio alignment.

    Saves the upload to a temporary file, runs ``process_audio`` on it and
    returns the result as JSON. Responds with HTTP 400 for an empty upload and
    HTTP 500 when processing reports an error; the JSON body keeps the same
    shape as ``process_audio``'s return value in every case.
    """
    content = await audio_file.read()
    if not content:
        return JSONResponse({"error": "Empty audio file"}, status_code=400)

    temp_path = None
    try:
        # Save temp file. The path is recorded before writing so the file is
        # still cleaned up if the write itself fails.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            temp_path = tmp.name
            tmp.write(content)

        # Process in a worker thread so the blocking model work does not stall
        # the event loop for other requests.
        result = await run_in_threadpool(process_audio, temp_path, language)

        status_code = 500 if "error" in result else 200
        return JSONResponse(result, status_code=status_code)
    finally:
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
# Health check lives on its own path: Gradio is mounted at "/" below, and an
# explicit GET "/" route registered first would answer instead of the UI.
@app.get("/health")
def health():
    """Return a small status payload including the active compute device."""
    return {"status": "healthy", "device": device}


# Gradio interface
def align_gradio(audio_file, language="en"):
    """Gradio UI wrapper around ``process_audio``.

    Args:
        audio_file: Path supplied by the ``gr.Audio(type="filepath")`` input;
            falsy when nothing was uploaded.
        language: Language code from the textbox. Blank or whitespace-only
            input falls back to ``"en"``.

    Returns:
        ``{"error": "No file"}`` when no audio was provided, otherwise the
        dict returned by ``process_audio``.
    """
    if not audio_file:
        return {"error": "No file"}
    language = (language or "").strip() or "en"
    return process_audio(audio_file, language)


gradio_app = gr.Interface(
    fn=align_gradio,
    inputs=[
        gr.Audio(type="filepath", label="Audio"),
        gr.Textbox(value="en", label="Language"),
    ],
    outputs=gr.JSON(label="Result"),
    title="🎯 WhisperX Alignment",
    description="Upload audio for word-level timestamps",
)

# Mount Gradio to FastAPI
app = gr.mount_gradio_app(app, gradio_app, path="/")

# Launch
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)