|
|
from fastapi import FastAPI, File, UploadFile, Form |
|
|
from fastapi.responses import JSONResponse |
|
|
import gradio as gr |
|
|
import whisperx |
|
|
import torch |
|
|
import tempfile |
|
|
import os |
|
|
import uvicorn |
|
|
from threading import Thread |
|
|
|
|
|
|
|
|
# Select the compute device: prefer a CUDA GPU when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# float16 is fast on GPU; int8 keeps CPU inference memory-friendly.
compute_type = "float16" if device == "cuda" else "int8"




# NOTE(review): the leading "π" looks like a mojibake'd emoji — confirm the
# file's encoding; the literal is left untouched here.
print(f"π Device: {device}, Compute: {compute_type}")




# FastAPI application; the Gradio UI is mounted onto it further below.
app = FastAPI(title="WhisperX Alignment API")
|
|
|
|
|
def process_audio(audio_path: str, language: str = "en"):
    """Transcribe an audio file with WhisperX and return word-level timestamps.

    Args:
        audio_path: Path to an audio file readable by WhisperX/ffmpeg.
        language: Language code passed to both the ASR and alignment models.

    Returns:
        On success, a dict with keys ``word_segments`` (list of
        ``{"word", "start", "end"}``), ``duration``, ``word_count``,
        ``language`` and ``device``. On any failure, ``{"error": <message>}``
        — callers rely on this function never raising.
    """
    try:
        print(f"π Processing {audio_path} ({language})...")

        # NOTE(review): both models are reloaded on every call; consider
        # caching them at module level if throughput matters.
        model = whisperx.load_model("base", device=device, compute_type=compute_type)

        result = model.transcribe(audio_path, language=language)

        # Forced alignment adds per-word timestamps to each segment.
        align_model, metadata = whisperx.load_align_model(language_code=language, device=device)
        aligned = whisperx.align(result["segments"], align_model, metadata, audio_path, device=device)

        # Flatten segments into one word list. Words the aligner could not
        # time lack "start"/"end"; skip them instead of letting a KeyError
        # fail the entire request (the previous behavior).
        word_segments = []
        for segment in aligned["segments"]:
            for word in segment.get("words", []):
                if "start" not in word or "end" not in word:
                    continue
                word_segments.append({
                    "word": word["word"].strip(),
                    "start": round(word["start"], 2),
                    "end": round(word["end"], 2)
                })

        # Approximate total duration as the end of the last aligned segment.
        duration = aligned["segments"][-1]["end"] if aligned["segments"] else 0

        return {
            "word_segments": word_segments,
            "duration": round(duration, 2),
            "word_count": len(word_segments),
            "language": language,
            "device": device
        }
    except Exception as e:
        # Deliberate catch-all: API/UI callers expect an error dict, never an
        # exception.
        print(f"β Error: {e}")
        return {"error": str(e)}
|
|
|
|
|
|
|
|
@app.post("/align")
async def align_audio_api(
    audio_file: UploadFile = File(...),
    language: str = Form("en")
):
    """REST endpoint: persist the uploaded audio to a temp file, run
    `process_audio` on it, and return the result as JSON.

    The temp file is always removed, even when processing fails.
    """
    temp_path = None
    try:
        # Keep the uploaded file's extension so downstream decoding gets the
        # correct format hint; fall back to ".mp3" (the previous hard-coded
        # behavior) when the client supplied no usable filename.
        suffix = os.path.splitext(audio_file.filename or "")[1] or ".mp3"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            content = await audio_file.read()
            tmp.write(content)
            temp_path = tmp.name

        result = process_audio(temp_path, language)
        return JSONResponse(result)

    finally:
        # Best-effort cleanup of the temp copy.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
|
|
|
|
|
@app.get("/")
def health():
    """Liveness probe: report service status and the active compute device."""
    payload = {"status": "healthy"}
    payload["device"] = device
    return payload
|
|
|
|
|
|
|
|
def align_gradio(audio_file, language="en"):
    """Adapter between the Gradio form inputs and the core alignment routine.

    Returns an error dict when no file was provided, otherwise delegates to
    `process_audio`.
    """
    if audio_file:
        return process_audio(audio_file, language)
    return {"error": "No file"}
|
|
|
|
|
# Build the browser UI: an audio upload plus a language box feeding
# align_gradio, with the result shown as JSON.
_ui_inputs = [
    gr.Audio(type="filepath", label="Audio"),
    gr.Textbox(value="en", label="Language"),
]
_ui_output = gr.JSON(label="Result")
gradio_app = gr.Interface(
    fn=align_gradio,
    inputs=_ui_inputs,
    outputs=_ui_output,
    title="π― WhisperX Alignment",
    description="Upload audio for word-level timestamps",
)

# Serve the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, gradio_app, path="/")
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Serve the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|