from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
import gradio as gr
import whisperx
import torch
import tempfile
import os
import uvicorn

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

print(f"🚀 Device: {device}, Compute: {compute_type}")

# Create FastAPI app
app = FastAPI(title="WhisperX Alignment API")

def process_audio(audio_path: str, language: str = "en"):
    """Core alignment logic"""
    try:
        print(f"📁 Processing {audio_path} ({language})...")
        
        # Load ASR model (reloaded on every call; caching it would speed up repeated requests)
        model = whisperx.load_model("base", device=device, compute_type=compute_type)
        
        # Transcribe
        result = model.transcribe(audio_path, language=language)
        
        # Align
        align_model, metadata = whisperx.load_align_model(language_code=language, device=device)
        aligned = whisperx.align(result["segments"], align_model, metadata, audio_path, device=device)
        
        # Extract word-level segments; WhisperX may omit timestamps for tokens it
        # cannot align (e.g. numerals), so skip words without start/end
        word_segments = []
        for segment in aligned["segments"]:
            for word in segment.get("words", []):
                if "start" not in word or "end" not in word:
                    continue
                word_segments.append({
                    "word": word["word"].strip(),
                    "start": round(word["start"], 2),
                    "end": round(word["end"], 2)
                })
        
        duration = aligned["segments"][-1]["end"] if aligned["segments"] else 0
        
        return {
            "word_segments": word_segments,
            "duration": round(duration, 2),
            "word_count": len(word_segments),
            "language": language,
            "device": device
        }
    except Exception as e:
        print(f"❌ Error: {e}")
        return {"error": str(e)}
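
# Illustrative shape of a successful process_audio() result (values below are
# examples only, not real output):
#
#   {
#       "word_segments": [
#           {"word": "hello", "start": 0.52, "end": 0.84},
#           {"word": "world", "start": 0.91, "end": 1.30}
#       ],
#       "duration": 1.30,
#       "word_count": 2,
#       "language": "en",
#       "device": "cuda"
#   }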

# FastAPI endpoint
@app.post("/align")
async def align_audio_api(
    audio_file: UploadFile = File(...),
    language: str = Form("en")
):
    """REST API endpoint for audio alignment"""
    temp_path = None
    try:
        # Save temp file
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            content = await audio_file.read()
            tmp.write(content)
            temp_path = tmp.name
        
        # Process
        result = process_audio(temp_path, language)
        return JSONResponse(result)
    
    finally:
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)

# Health check (kept off "/" so it does not shadow the Gradio UI mounted there)
@app.get("/health")
def health():
    return {"status": "healthy", "device": device}

# Gradio interface
def align_gradio(audio_file, language="en"):
    """Gradio UI wrapper"""
    if not audio_file:
        return {"error": "No file"}
    return process_audio(audio_file, language)

gradio_app = gr.Interface(
    fn=align_gradio,
    inputs=[
        gr.Audio(type="filepath", label="Audio"),
        gr.Textbox(value="en", label="Language")
    ],
    outputs=gr.JSON(label="Result"),
    title="🎯 WhisperX Alignment",
    description="Upload audio for word-level timestamps"
)

# Mount Gradio to FastAPI
app = gr.mount_gradio_app(app, gradio_app, path="/")

# Launch
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
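
# Example client call against the /align endpoint (a sketch: assumes the server
# is running locally on port 7860 and a file named "sample.mp3" exists; the
# field names match the route parameters defined above):
#
#   import requests
#   with open("sample.mp3", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/align",
#           files={"audio_file": f},
#           data={"language": "en"},
#       )
#   print(resp.json())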