| import sys | |
| sys.path.append("..") | |
| import gradio | |
| import torch, torchaudio | |
| import numpy as np | |
| from transformers import ( | |
| Wav2Vec2ForPreTraining, | |
| Wav2Vec2CTCTokenizer, | |
| Wav2Vec2FeatureExtractor, | |
| ) | |
| from finetuning.wav2vec2 import SpeechRecognizer | |
def load_model(
    ckpt_path: str,
    model_name: str = "nguyenvulebinh/wav2vec2-base-vietnamese-250h",
) -> "SpeechRecognizer":
    """Rebuild the fine-tuned ``SpeechRecognizer`` from a checkpoint file.

    The wav2vec2 network, CTC tokenizer and feature extractor are first
    created with ``from_pretrained(model_name)`` and then handed to
    ``SpeechRecognizer.load_from_checkpoint`` together with ``ckpt_path``.

    Args:
        ckpt_path: Path of the checkpoint to restore.
        model_name: Identifier passed to every ``from_pretrained`` call.
            Defaults to the Vietnamese wav2vec2 base model this script was
            previously hard-coded to, so existing callers are unaffected.

    Returns:
        Whatever ``SpeechRecognizer.load_from_checkpoint`` returns, loaded
        with ``map_location="cpu"``.
    """
    # The three components must come from the same pretrained identifier so
    # that the tokenizer and feature extractor match the acoustic model.
    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

    # map_location="cpu" lets the checkpoint load on machines without a GPU.
    return SpeechRecognizer.load_from_checkpoint(
        ckpt_path,
        wav2vec2=wav2vec2,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        map_location="cpu",
    )
# The recognizer is restored once, at import time, and shared by every
# request handled by `transcribe`.
CHECKPOINT_PATH = "checkpoints/last.ckpt"

model = load_model(CHECKPOINT_PATH)
# Put the module into evaluation mode before serving predictions
# (presumably a torch/Lightning module -- TODO confirm).
model.eval()
def transcribe(audio):
    """Transcribe one Gradio audio recording to text.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gradio.Audio(type="numpy")``, or ``None`` when nothing was
            recorded. ``samples`` is a NumPy array; a 2-D array is assumed
            to be laid out as ``(num_samples, num_channels)`` -- TODO confirm
            against the installed Gradio version.

    Returns:
        The first element of ``model.predict(...)`` for the recording, or an
        empty string when there is no audio to transcribe.
    """
    # Gradio invokes the handler with None when the user submits without
    # recording anything; unpacking it would raise TypeError.
    if audio is None:
        return ""

    sample_rate, waveform = audio

    # A zero-length recording has nothing to resample or decode.
    if waveform.size == 0:
        return ""

    # Keep only the first channel of a multi-channel recording.
    if waveform.ndim == 2:
        waveform = waveform[:, 0]

    # Gradio delivers integer PCM (typically int16). Rescale it to the
    # conventional [-1, 1] floating-point range instead of feeding raw
    # integer magnitudes to the model.
    if np.issubdtype(waveform.dtype, np.signedinteger):
        info = np.iinfo(waveform.dtype)
        full_scale = float(max(abs(info.min), info.max))
        waveform = waveform.astype(np.float32) / full_scale

    # Shape (1, num_samples): a batch containing a single utterance.
    batch = torch.from_numpy(waveform).float().unsqueeze(0)
    # Bring the recording to 16 kHz before handing it to the recognizer.
    batch = torchaudio.functional.resample(batch, sample_rate, 16_000)

    transcript = model.predict(batch)[0]
    return transcript
# Web demo: microphone audio is passed to `transcribe` as a NumPy
# (sample_rate, samples) pair and the returned text is shown in a textbox.
microphone_input = gradio.Audio(source="microphone", type="numpy")
demo = gradio.Interface(
    fn=transcribe,
    inputs=microphone_input,
    outputs="textbox",
)
demo.launch()