import gradio as gr import torch import os import warnings import sys import os fix_import=f"{os.getcwd()}/server" sys.path.append(fix_import) from inference.audio_chunker import AudioChunker from inference.audio_sentence_alignment import AudioAlignment from inference.mms_model_pipeline import MMSModel from media_transcription_processor import MediaTranscriptionProcessor from subtitle import make_subtitle from lang_dict import lang_code import download_models # warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio") warnings.filterwarnings( "ignore", message=".*torchaudio.functional._alignment.forced_align.*", category=UserWarning ) # ---- Setup Model Globals ---- _model_loaded = False _model_loading = False # ---- Initialize model ---- def load_model(model_name="omniASR_LLM_1B"): """Load MMS model on startup - only once.""" global _model_loaded, _model_loading if _model_loaded or _model_loading: return _model_loading = True print(f"š Loading {model_name} model...") AudioChunker() AudioAlignment() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") MMSModel(model_card=model_name, device=device) _model_loaded = True _model_loading = False print("ā Model loaded successfully.") # ---- Transcription function ---- def media_transcription(file_path, lang_code="eng_Latn"): """Perform transcription + subtitle generation.""" with open(file_path, "rb") as f: media_bytes = f.read() processor = MediaTranscriptionProcessor( media_bytes=media_bytes, filename=file_path, language_with_script=lang_code ) processor.convert_media() processor.transcribe_full_pipeline() results = processor.get_results() transcription = results['transcription'] word_level_timestamps = [ {"word": s['text'], "start": s['start'], "end": s['end']} for s in results.get('aligned_segments', []) ] if word_level_timestamps: sentence_srt, word_level_srt, shorts_srt = make_subtitle(word_level_timestamps, file_path) return transcription, sentence_srt, word_level_srt, shorts_srt else: return transcription,None,None,None def transcribe_interface(audio, selected_lang): """Main Gradio wrapper.""" if audio is None: return "Please upload or record audio.", None, None, None # Save uploaded/recorded audio file_path = audio find_lang_code = lang_code[selected_lang] # print(f"š Transcribing {file_path} in {selected_lang} ({find_lang_code})...") try: transcription, sentence_srt, word_level_srt, shorts_srt = media_transcription(file_path, find_lang_code) return transcription, sentence_srt, word_level_srt, shorts_srt except Exception as e: return f"ā Error: {e}", None, None, None def ui(): lang_list = list(lang_code.keys()) custom_css = """.gradio-container { font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif; }""" with gr.Blocks(theme=gr.themes.Soft(),css=custom_css) as demo: gr.HTML("""
Converting the official facebook/omniasr-transcriptions Flask application into a Gradio App. Running omniASR_LLM_300M on CPU.
š Run on Google Colab