import os

# The Hugging Face cache locations must be set before transformers and
# huggingface_hub are imported, because both libraries resolve their cache
# paths at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"

for key in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_DATASETS_CACHE", "XDG_CACHE_HOME"]:
    os.makedirs(os.environ[key], exist_ok=True)

import datetime
import logging
import tempfile

import ffmpeg
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from torch.quantization import quantize_dynamic
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
)

# Quiet down noisy third-party loggers.
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
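

# Speech-to-text model: Whisper weights are fetched once and cached under
# /tmp/models_cache, so restarts only pay the download cost the first time.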
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = f"openai/whisper-{model_size}"
    processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
    model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
    if torch.cuda.is_available():
        model.to("cuda")
    else:
        # Dynamic int8 quantization only supports CPU execution, so apply it
        # only when the model is not being moved to a GPU.
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    return processor, model
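

# Grammar-correction model, exposed as a text2text-generation pipeline so a
# batch of sentences can be corrected in a single call.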
def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = "prithivida/grammar_error_correcter_v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
    device = 0 if torch.cuda.is_available() else -1
    if device == -1:
        # As with Whisper above, dynamic quantization is CPU-only.
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    grammar_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return grammar_pipeline
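

# Decode any input container/codec to 16 kHz mono WAV. Note that the
# ffmpeg-python package only assembles command lines; the actual ffmpeg
# binary must be installed on the host for this to work.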
def load_audio(audio_path):
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_wav:
        tmp_wav_path = tmp_wav.name
    try:
        (
            ffmpeg
            .input(audio_path)
            .output(tmp_wav_path, format='wav', ac=1, ar='16k')
            .overwrite_output()
            .run(quiet=True)
        )
        waveform, sample_rate = torchaudio.load(tmp_wav_path)
        # Defensive resample in case ffmpeg did not deliver exactly 16 kHz.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
        return waveform.squeeze().numpy(), 16000
    finally:
        if os.path.exists(tmp_wav_path):
            os.remove(tmp_wav_path)
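

# Single-window transcription for clips that fit in Whisper's 30-second
# receptive field.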
def transcribe_audio(audio_file, processor, model):
    audio, _ = load_audio(audio_file)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(input_features)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
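

# Chunked transcription for longer audio. The split is a plain fixed-size
# window with no overlap, so a word straddling a chunk boundary may be
# garbled; overlapping windows with merging would be the usual refinement.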
def transcribe_long_audio(audio_file, processor, model, chunk_length_s=30):
    audio, sample_rate = load_audio(audio_file)
    audio_length_s = len(audio) / sample_rate
    if audio_length_s <= chunk_length_s:
        return transcribe_audio(audio_file, processor, model)

    chunk_size = int(chunk_length_s * sample_rate)
    transcription_chunks = []

    for i in range(0, len(audio), chunk_size):
        chunk = audio[i:i + chunk_size]
        # Skip only near-empty slivers (< 0.5 s); skipping anything shorter
        # than half a window would silently drop up to 15 s of trailing audio.
        if len(chunk) < 0.5 * sample_rate:
            continue
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(input_features)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcription_chunks.append(text)

    return " ".join(transcription_chunks)
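

# Sentence-level grammar correction. The period-only splitter is deliberately
# simple: '?' and '!' boundaries are not split, and sentences are re-joined
# with '. '.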
def correct_grammar(text, grammar_pipeline):
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    if not sentences:
        return text
    results = grammar_pipeline(sentences, batch_size=4)
    return '. '.join([r['generated_text'] for r in results])
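

# Load both models at import time so the first request does not pay the
# download and initialization cost.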
processor, whisper_model = load_whisper_model("small")
grammar_pipeline = load_grammar_model()
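

# One dummy forward pass per model so lazy allocations happen at startup
# rather than on the first user request. The zeros tensor matches Whisper's
# log-mel input shape: (batch, 80 mel bins, 3000 frames).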
def warm_up_models():
    dummy_audio = torch.zeros(1, 80, 3000).to(whisper_model.device)
    with torch.no_grad():
        whisper_model.generate(dummy_audio)
    _ = correct_grammar("This is a warm up test.", grammar_pipeline)


warm_up_models()
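

# Health-check endpoint; HEAD is allowed so uptime monitors can ping cheaply.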
@app.api_route("/", methods=["GET", "HEAD"])
async def index(request: Request):
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        "timestamp": datetime.datetime.utcnow().isoformat() + "Z"
    })
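

# Main endpoint. Expects a multipart upload in a form field named "audio", e.g.:
#   curl -X POST -F "audio=@recording.mp3" http://localhost:7860/transcribe
# Responds with both the raw and the grammar-corrected transcription.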
@app.post('/transcribe')
async def transcribe(audio: UploadFile = File(...)):
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided.")

    os.makedirs("/tmp/temp_audio", exist_ok=True)
    # basename() guards against path traversal through a crafted filename.
    audio_path = f"/tmp/temp_audio/{os.path.basename(audio.filename or 'upload')}"

    try:
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)

        transcription = transcribe_long_audio(audio_path, processor, whisper_model)
        corrected_text = correct_grammar(transcription, grammar_pipeline)

        return JSONResponse({
            "raw_transcription": transcription,
            "corrected_transcription": corrected_text
        })

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        try:
            if os.path.exists(audio_path):
                os.remove(audio_path)
        except Exception:
            pass
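

# Local entry point. The "main:app" import string assumes this file is saved
# as main.py.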
if __name__ == '__main__':
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")