import os

# ========== Force HF cache to /tmp ==========
# These must be set before transformers/huggingface_hub are imported,
# since both libraries read their cache locations at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"
for key in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_DATASETS_CACHE", "XDG_CACHE_HOME"]:
    path = os.environ.get(key)
    if path:
        os.makedirs(path, exist_ok=True)

import datetime
import logging
import tempfile

import ffmpeg
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
)
from torch.quantization import quantize_dynamic
# Silence all transformers and huggingface logging
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ========== Load Whisper Model (quantized) ==========
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = f"openai/whisper-{model_size}"
    processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
    model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
    if torch.cuda.is_available():
        # Dynamic quantization only targets CPU kernels, so on GPU we skip it
        # and run the full-precision model on CUDA instead.
        model.to("cuda")
    else:
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    model.eval()
    return processor, model
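# Note: other checkpoint sizes ("tiny", "base", "medium", ...) follow the same
# "openai/whisper-<size>" naming on the Hugging Face Hub, so model_size can be
# swapped without further changes here.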
# ========== Load Grammar Correction Model (quantized) ==========
def load_grammar_model(save_dir="/tmp/models_cache/grammar_corrector"):
    os.makedirs(save_dir, exist_ok=True)
    model_name = "prithivida/grammar_error_correcter_v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        # As above: dynamic quantization is CPU-only, so only apply it off-GPU.
        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    grammar_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if use_cuda else -1,
    )
    return grammar_pipeline
# ========== Optimized Audio Loader ==========
def load_audio(audio_path):
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_wav:
        tmp_wav_path = tmp_wav.name
    try:
        (
            ffmpeg
            .input(audio_path)
            .output(tmp_wav_path, format='wav', ac=1, ar='16k')
            .overwrite_output()
            .run(quiet=True)
        )
        waveform, sample_rate = torchaudio.load(tmp_wav_path)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
        return waveform.squeeze().numpy(), 16000
    finally:
        if os.path.exists(tmp_wav_path):
            os.remove(tmp_wav_path)
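# load_audio goes through the ffmpeg-python bindings imported above, which
# shell out to the ffmpeg binary; ffmpeg itself must be installed on the host.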
# ========== Audio Transcription ==========
def transcribe_audio(audio_file, processor, model):
    audio, _ = load_audio(audio_file)
    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(input_features)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def transcribe_long_audio(audio_file, processor, model, chunk_length_s=30):
    audio, sample_rate = load_audio(audio_file)
    audio_length_s = len(audio) / sample_rate
    if audio_length_s <= chunk_length_s:
        return transcribe_audio(audio_file, processor, model)
    chunk_size = int(chunk_length_s * sample_rate)
    transcription_chunks = []
    for i in range(0, len(audio), chunk_size):
        chunk = audio[i:i + chunk_size]
        # Skip a trailing chunk shorter than half the window; note this drops
        # up to chunk_length_s / 2 seconds of audio at the very end.
        if len(chunk) < 0.5 * chunk_size:
            continue
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(input_features)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcription_chunks.append(text)
    return " ".join(transcription_chunks)
# ========== Grammar Correction ==========
def correct_grammar(text, grammar_pipeline):
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    if not sentences:
        return text
    results = grammar_pipeline(sentences, batch_size=4)
    return '. '.join([r['generated_text'] for r in results])
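# Illustrative example (actual model output may differ):
#   correct_grammar("he go to school. she like apples", grammar_pipeline)
#   might return something like "He goes to school. She likes apples".
# Note the naive '.'-based splitting: abbreviations such as "e.g." will be
# treated as sentence boundaries.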
# ========== Initialize Models ==========
processor, whisper_model = load_whisper_model("small")
grammar_pipeline = load_grammar_model()
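# The first run downloads the checkpoints into the cache directories configured
# at the top of this file; later starts reuse them from /tmp (until the host
# clears it).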
# ========== Warm-Up Models ==========
def warm_up_models():
    # (batch, n_mels, frames) = (1, 80, 3000) matches Whisper's fixed 30 s
    # log-mel input, so this exercises the same code path as a real request.
    dummy_audio = torch.zeros(1, 80, 3000).to(whisper_model.device)
    with torch.no_grad():
        whisper_model.generate(dummy_audio)
    _ = correct_grammar("This is a warm up test.", grammar_pipeline)

warm_up_models()
# ========== Routes ==========
@app.api_route("/", methods=["GET", "HEAD"])
async def index(request: Request):
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        "timestamp": datetime.datetime.utcnow().isoformat() + "Z"
    })
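# Example health check (host/port as configured in the __main__ block below):
#   curl http://localhost:7860/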
@app.post('/transcribe')
async def transcribe(audio: UploadFile = File(...)):
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided.")
    os.makedirs("/tmp/temp_audio", exist_ok=True)
    # Keep only the basename so a crafted filename cannot escape /tmp/temp_audio.
    audio_path = f"/tmp/temp_audio/{os.path.basename(audio.filename)}"
    # Save uploaded file
    try:
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)
        transcription = transcribe_long_audio(audio_path, processor, whisper_model)
        corrected_text = correct_grammar(transcription, grammar_pipeline)
        return JSONResponse({
            "raw_transcription": transcription,
            "corrected_transcription": corrected_text
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        try:
            if os.path.exists(audio_path):
                os.remove(audio_path)
        except Exception:
            pass
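# Example request (sample.wav is a placeholder; the form field must be "audio"):
#   curl -X POST -F "audio=@sample.wav" http://localhost:7860/transcribe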
# ========== Run App ==========
if __name__ == '__main__':
    # Run with Uvicorn for FastAPI
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")