# stt-model / main.py
import os
import logging
import tempfile
import datetime

# ========== Force HF cache to /tmp (must be set before transformers is imported) ==========
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"
for key in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_DATASETS_CACHE", "XDG_CACHE_HOME"]:
    path = os.environ.get(key)
    if path:
        os.makedirs(path, exist_ok=True)

import torch
import torchaudio
import ffmpeg
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
)
from torch.quantization import quantize_dynamic
# Silence all transformers and huggingface logging
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ========== Load Whisper Model (quantized) ==========
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = f"openai/whisper-{model_size}"
    processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
    model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        # Dynamic int8 quantization only runs on CPU; skip it when a GPU is available
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.to(device)
    return processor, model
# ========== Load Grammar Correction Model (quantized) ==========
def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = "prithivida/grammar_error_correcter_v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
    if not torch.cuda.is_available():
        # Dynamic int8 quantization only runs on CPU; skip it when a GPU is available
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    grammar_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )
    return grammar_pipeline
# ========== Optimized Audio Loader ==========
def load_audio(audio_path):
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_wav:
        tmp_wav_path = tmp_wav.name
    try:
        # Convert any input format to 16 kHz mono WAV via ffmpeg
        (
            ffmpeg
            .input(audio_path)
            .output(tmp_wav_path, format='wav', ac=1, ar='16k')
            .overwrite_output()
            .run(quiet=True)
        )
        waveform, sample_rate = torchaudio.load(tmp_wav_path)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
        return waveform.squeeze().numpy(), 16000
    finally:
        if os.path.exists(tmp_wav_path):
            os.remove(tmp_wav_path)
# ========== Audio Transcription ==========
def transcribe_audio(audio_file, processor, model):
    audio, _ = load_audio(audio_file)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(input_features)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def transcribe_long_audio(audio_file, processor, model, chunk_length_s=30):
    audio, sample_rate = load_audio(audio_file)
    audio_length_s = len(audio) / sample_rate
    if audio_length_s <= chunk_length_s:
        return transcribe_audio(audio_file, processor, model)
    chunk_size = int(chunk_length_s * sample_rate)
    transcription_chunks = []
    for i in range(0, len(audio), chunk_size):
        chunk = audio[i:i + chunk_size]
        # Skip trailing fragments shorter than half a chunk to avoid spurious output
        if len(chunk) < 0.5 * chunk_size:
            continue
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(input_features)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcription_chunks.append(text)
    return " ".join(transcription_chunks)
# ========== Grammar Correction ==========
def correct_grammar(text, grammar_pipeline):
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    if not sentences:
        return text
    results = grammar_pipeline(sentences, batch_size=4)
    return '. '.join([r['generated_text'] for r in results])
# ========== Initialize Models ==========
processor, whisper_model = load_whisper_model("small")
grammar_pipeline = load_grammar_model()
# ========== Warm-Up Models ==========
def warm_up_models():
    # Dummy log-mel features: (batch=1, n_mels=80, frames=3000), Whisper's expected input shape
    dummy_audio = torch.zeros(1, 80, 3000).to(whisper_model.device)
    with torch.no_grad():
        whisper_model.generate(dummy_audio)
    _ = correct_grammar("This is a warm up test.", grammar_pipeline)

warm_up_models()
# ========== Routes ==========
@app.api_route("/", methods=["GET", "HEAD"])
async def index(request: Request):
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        "timestamp": datetime.datetime.utcnow().isoformat() + "Z"
    })
@app.post('/transcribe')
async def transcribe(audio: UploadFile = File(...)):
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided.")
    os.makedirs("/tmp/temp_audio", exist_ok=True)
    # Use only the base name so a crafted filename cannot escape the temp directory
    audio_path = f"/tmp/temp_audio/{os.path.basename(audio.filename)}"
    # Save uploaded file
    try:
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)
        transcription = transcribe_long_audio(audio_path, processor, whisper_model)
        corrected_text = correct_grammar(transcription, grammar_pipeline)
        return JSONResponse({
            "raw_transcription": transcription,
            "corrected_transcription": corrected_text
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        try:
            if os.path.exists(audio_path):
                os.remove(audio_path)
        except Exception:
            pass
# ========== Run App ==========
if __name__ == '__main__':
    # Run with Uvicorn for FastAPI
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
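
# Example client call (illustrative sketch only; assumes the server is reachable
# on localhost:7860 and that a local file named sample.wav exists -- adjust the
# host, port, and file name to your deployment):
#
#   curl -X POST "http://localhost:7860/transcribe" \
#        -F "audio=@sample.wav;type=audio/wav"
#
# The response is a JSON object with "raw_transcription" and
# "corrected_transcription" fields, as returned by the /transcribe route above.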