AbdoIR committed
Commit a462c9a · verified · 1 Parent(s): 6e0af0a

Create main.py

Files changed (1)
  1. main.py +155 -0
main.py ADDED
@@ -0,0 +1,155 @@
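+"""Flask service: transcribes uploaded audio with a (CPU-quantized) Whisper
+model, then post-corrects the transcript with a seq2seq grammar model."""
+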
+import os
+import torch
+import torchaudio
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+from transformers import (
+    WhisperProcessor,
+    WhisperForConditionalGeneration,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    pipeline
+)
+from torch.quantization import quantize_dynamic
+from werkzeug.utils import secure_filename
+import logging
+import ffmpeg
+import tempfile
+
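+# Runtime dependencies (assumed package names for the imports above):
+#   pip install flask flask-cors transformers torch torchaudio ffmpeg-python
+# Note: `import ffmpeg` is provided by the ffmpeg-python package, which also
+# requires the ffmpeg binary to be available on PATH.
+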
+# Silence all transformers and huggingface logging
+logging.getLogger("transformers").setLevel(logging.ERROR)
+logging.getLogger("urllib3").setLevel(logging.ERROR)
+logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
+
+app = Flask(__name__)
+CORS(app)
+
+# ========== Load Whisper Model (quantized + small) ==========
+def load_whisper_model(model_size="small", save_dir="/tmp/saved_models/whisper"):
+    os.makedirs(save_dir, exist_ok=True)
+    model_name = f"openai/whisper-{model_size}"
+    processor = WhisperProcessor.from_pretrained(model_name, cache_dir=save_dir)
+    model = WhisperForConditionalGeneration.from_pretrained(model_name, cache_dir=save_dir)
+    if torch.cuda.is_available():
+        model.to("cuda")
+    else:
+        # Dynamic int8 quantization of the Linear layers shrinks the model and
+        # speeds up CPU inference, but it is CPU-only, so skip it on GPU.
+        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
+    return processor, model
+
+# ========== Load Grammar Correction Model (quantized) ==========
+def load_grammar_model(save_dir="/tmp/saved_models/grammar_corrector"):
+    os.makedirs(save_dir, exist_ok=True)
+    model_name = "prithivida/grammar_error_correcter_v1"
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=save_dir)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=save_dir)
+    if not torch.cuda.is_available():
+        # Same as above: dynamic quantization only applies on CPU.
+        model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
+    grammar_pipeline = pipeline(
+        "text2text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        device=0 if torch.cuda.is_available() else -1
+    )
+    return grammar_pipeline
+
+# ========== Optimized Audio Loader ==========
+def load_audio(audio_path):
+    # Transcode whatever container/codec the client sends to 16 kHz mono WAV,
+    # which is the input format Whisper's feature extractor expects.
+    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_wav:
+        tmp_wav_path = tmp_wav.name
+    try:
+        (
+            ffmpeg
+            .input(audio_path)
+            .output(tmp_wav_path, format='wav', ac=1, ar='16k')
+            .overwrite_output()
+            .run(quiet=True)
+        )
+        waveform, sample_rate = torchaudio.load(tmp_wav_path)
+        if sample_rate != 16000:
+            # Safety net; ffmpeg should already have resampled to 16 kHz.
+            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+            waveform = resampler(waveform)
+        return waveform.squeeze().numpy(), 16000
+    finally:
+        if os.path.exists(tmp_wav_path):
+            os.remove(tmp_wav_path)
+
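+# Example (hypothetical path): load_audio("/tmp/clip.mp3") -> (waveform, 16000),
+# where waveform is a 1-D float32 numpy array of mono samples at 16 kHz.
+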
+# ========== Audio Transcription ==========
+def transcribe_audio(audio_file, processor, model):
+    audio, _ = load_audio(audio_file)
+    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
+    input_features = input_features.to(model.device)
+    with torch.no_grad():
+        generated_ids = model.generate(input_features)
+    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+def transcribe_long_audio(audio_file, processor, model, chunk_length_s=30):
+    audio, sample_rate = load_audio(audio_file)
+    audio_length_s = len(audio) / sample_rate
+    if audio_length_s <= chunk_length_s:
+        return transcribe_audio(audio_file, processor, model)
+
+    chunk_size = int(chunk_length_s * sample_rate)
+    transcription_chunks = []
+
+    for i in range(0, len(audio), chunk_size):
+        chunk = audio[i:i + chunk_size]
+        # Skip trailing fragments shorter than half a chunk (their audio is dropped).
+        if len(chunk) < 0.5 * chunk_size:
+            continue
+        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt")
+        input_features = inputs.input_features.to(model.device)
+        with torch.no_grad():
+            generated_ids = model.generate(input_features)
+        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        transcription_chunks.append(text)
+
+    return " ".join(transcription_chunks)
+
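+# Note: chunks are cut on fixed boundaries with no overlap, so a word that
+# straddles a boundary can be clipped, and a short trailing fragment is
+# skipped entirely. Overlapping windows, or the transformers ASR pipeline
+# with its chunk_length_s argument, would be a more robust alternative.
+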
+# ========== Grammar Correction ==========
+def correct_grammar(text, grammar_pipeline):
+    # Naive sentence split on '.'; drops '?'/'!' boundaries and the final period.
+    sentences = [s.strip() for s in text.split('.') if s.strip()]
+    results = grammar_pipeline(sentences, batch_size=4)
+    return '. '.join([r['generated_text'] for r in results])
+
+# ========== Initialize Models ==========
+processor, whisper_model = load_whisper_model("small")
+grammar_pipeline = load_grammar_model()
+
+# ========== Warm-Up Models ==========
+def warm_up_models():
+    # (1, 80, 3000) matches Whisper's 30 s log-mel input: 80 mel bins x 3000 frames.
+    dummy_audio = torch.zeros(1, 80, 3000).to(whisper_model.device)
+    with torch.no_grad():
+        whisper_model.generate(dummy_audio)
+    _ = correct_grammar("This is a warm up test.", grammar_pipeline)
+
+warm_up_models()
+
+# ========== Flask Route ==========
+@app.route('/transcribe', methods=['POST'])
+def transcribe():
+    if 'audio' not in request.files:
+        return jsonify({"error": "No audio file provided."}), 400
+
+    audio_file = request.files['audio']
+    os.makedirs("/tmp/temp_audio", exist_ok=True)
+    # Sanitize the client-supplied filename so the path stays inside /tmp/temp_audio.
+    filename = secure_filename(audio_file.filename) or "upload"
+    audio_path = f"/tmp/temp_audio/{filename}"
+    audio_file.save(audio_path)
+
+    try:
+        transcription = transcribe_long_audio(audio_path, processor, whisper_model)
+        corrected_text = correct_grammar(transcription, grammar_pipeline)
+
+        return jsonify({
+            "raw_transcription": transcription,
+            "corrected_transcription": corrected_text
+        })
+
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
+
+    finally:
+        if os.path.exists(audio_path):
+            os.remove(audio_path)
+
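+# Example request (hypothetical file name; route, form field, and port are as
+# defined in this file):
+#   curl -X POST -F "audio=@sample.wav" http://localhost:7860/transcribe
+# -> {"raw_transcription": "...", "corrected_transcription": "..."}
+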
+# ========== Run App ==========
+if __name__ == '__main__':
+    # logging is already imported at module level; just quiet werkzeug's
+    # per-request log lines.
+    log = logging.getLogger('werkzeug')
+    log.setLevel(logging.WARNING)
+    app.run(host="0.0.0.0", debug=False, port=7860)