File size: 6,379 Bytes
92e075b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b0dfe
 
 
 
 
92e075b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b0dfe
 
 
 
 
 
 
 
 
 
92e075b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b0dfe
92e075b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import gradio as gr
import torch
import os
import warnings
import sys
import os
fix_import=f"{os.getcwd()}/server"
sys.path.append(fix_import)
from inference.audio_chunker import AudioChunker
from inference.audio_sentence_alignment import AudioAlignment
from inference.mms_model_pipeline import MMSModel
from media_transcription_processor import MediaTranscriptionProcessor
from subtitle import make_subtitle
from lang_dict import lang_code  
import download_models

# warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
warnings.filterwarnings(
    "ignore",
    message=".*torchaudio.functional._alignment.forced_align.*",
    category=UserWarning
)


# ---- Setup Model Globals ----
_model_loaded = False
_model_loading = False

# ---- Initialize model ----
def load_model(model_name="omniASR_LLM_1B"):
    """Load the MMS/omniASR model stack exactly once per process.

    Args:
        model_name: Model card identifier forwarded to ``MMSModel``.

    Subsequent calls are no-ops while a load is in progress or after one
    has completed. NOTE(review): the check-then-set on the module flags is
    not thread-safe — assumes single-threaded startup; confirm with caller.
    """
    global _model_loaded, _model_loading
    if _model_loaded or _model_loading:
        return

    _model_loading = True
    print(f"🔄 Loading {model_name} model...")
    try:
        # Instantiating these performs their one-time setup side effects;
        # the instances themselves are intentionally discarded.
        AudioChunker()
        AudioAlignment()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        MMSModel(model_card=model_name, device=device)

        _model_loaded = True
        print("✅ Model loaded successfully.")
    finally:
        # Bug fix: always clear the in-progress flag. The original left
        # _model_loading = True after a failed load, which silently blocked
        # every future retry for the life of the process.
        _model_loading = False


# ---- Transcription function ----
def media_transcription(file_path, lang_code="eng_Latn"):
    """Transcribe a media file and, when alignment succeeds, build subtitles.

    Args:
        file_path: Path to the uploaded/recorded media file.
        lang_code: Language-with-script code (e.g. ``"eng_Latn"``).

    Returns:
        Tuple of (transcription text, sentence SRT, word-level SRT,
        shorts SRT). The three subtitle slots are ``None`` when the
        pipeline produced no aligned segments.
    """
    with open(file_path, "rb") as fh:
        media_bytes = fh.read()

    processor = MediaTranscriptionProcessor(
        media_bytes=media_bytes,
        filename=file_path,
        language_with_script=lang_code,
    )
    processor.convert_media()
    processor.transcribe_full_pipeline()
    results = processor.get_results()

    transcription = results["transcription"]
    word_timestamps = [
        {"word": seg["text"], "start": seg["start"], "end": seg["end"]}
        for seg in results.get("aligned_segments", [])
    ]

    # No aligned segments -> transcription only, no subtitle files.
    if not word_timestamps:
        return transcription, None, None, None

    sentence_srt, word_srt, shorts_srt = make_subtitle(word_timestamps, file_path)
    return transcription, sentence_srt, word_srt, shorts_srt



def transcribe_interface(audio, selected_lang):
    """Gradio click handler: transcribe *audio* in *selected_lang*.

    Args:
        audio: Filepath supplied by the ``gr.Audio`` component, or ``None``
            when nothing was uploaded/recorded.
        selected_lang: Human-readable language name (a key of ``lang_code``).

    Returns:
        Tuple of (transcription-or-error text, sentence SRT, word-level SRT,
        shorts SRT); the three file slots are ``None`` on failure.
    """
    if audio is None:
        return "Please upload or record audio.", None, None, None

    file_path = audio

    try:
        # Bug fix: the lang_code lookup now lives inside the try so an
        # unknown language label is reported as an error message instead
        # of raising an uncaught KeyError out of the handler.
        find_lang_code = lang_code[selected_lang]
        transcription, sentence_srt, word_level_srt, shorts_srt = media_transcription(
            file_path, find_lang_code
        )
        return transcription, sentence_srt, word_level_srt, shorts_srt
    except Exception as e:
        # Broad catch is deliberate: the UI must always render a message
        # rather than surface a traceback to the user.
        return f"❌ Error: {e}", None, None, None



def ui(model_name="omniASR_LLM_300M"):
    """Build the Gradio Blocks interface (not yet launched).

    Args:
        model_name: Model card shown in the page banner. Generalized from
            the previously hard-coded "omniASR_LLM_300M" label so the
            header can match whichever model was actually loaded; the
            default preserves the original text exactly.

    Returns:
        The constructed ``gr.Blocks`` demo object.
    """
    lang_list = list(lang_code.keys())
    custom_css = """.gradio-container { font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif; }"""
    with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.HTML(f"""
        <div style="text-align: center; margin: 20px auto; max-width: 800px;">
            <h1 style="font-size: 2.5em; margin-bottom: 10px;">Meta Omnilingual ASR</h1>
            <p style="font-size: 1.2em; color: #333; margin-bottom: 15px;">
                Converting the official 
                <a href="https://huggingface.co/spaces/facebook/omniasr-transcriptions" 
                   target="_blank" 
                   style="color: #1a73e8; font-weight: 600; text-decoration: none;">
                   facebook/omniasr-transcriptions
                </a> 
                Flask application into a Gradio App. Running {model_name} on CPU.
            </p>

            <a href="https://github.com/NeuralFalconYT/omnilingual-asr-colab" target="_blank" style="display: inline-block; padding: 10px 20px; background-color: #4285F4; color: white; border-radius: 6px; text-decoration: none; font-size: 1em;">😇 Run on Google Colab</a>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎙 Upload or Record Audio")
                language_dropdown = gr.Dropdown(
                    choices=lang_list,
                    value=lang_list[0],
                    label="🌐 Select Language",
                )
                transcribe_btn = gr.Button("🚀 Transcribe")
            with gr.Column():
                transcription_output = gr.Textbox(label="Transcription", lines=8, show_copy_button=True)
                with gr.Accordion("🎬 Subtitle (Not Accurate)", open=False):
                    sentence_srt_out = gr.File(label="Sentence-level Subtitle File")
                    word_srt_out = gr.File(label="Word-level Subtitle File")
                    shorts_srt_out = gr.File(label="Shorts Subtitle File")

        transcribe_btn.click(
            fn=transcribe_interface,
            inputs=[audio_input, language_dropdown],
            outputs=[transcription_output, sentence_srt_out, word_srt_out, shorts_srt_out],
        )

    return demo




import click

# Model cards the CLI accepts; the default mirrors the hosted Space.
_MODEL_CHOICES = [
    "omniASR_CTC_300M",
    "omniASR_CTC_1B",
    "omniASR_CTC_3B",
    "omniASR_CTC_7B",
    "omniASR_LLM_300M",
    "omniASR_LLM_1B",
    "omniASR_LLM_3B",
    "omniASR_LLM_7B",
    "omniASR_LLM_7B_ZS",
]


@click.command()
@click.option(
    "--debug",
    is_flag=True,
    default=False,
    help="Enable debug mode (shows detailed logs).",
)
@click.option(
    "--share",
    is_flag=True,
    default=False,
    help="Create a public Gradio share link (for Colab or remote usage).",
)
@click.option(
    "--model",
    default="omniASR_LLM_300M",
    type=click.Choice(_MODEL_CHOICES),
    help="Choose the OmniASR model to load.",
)
def main(debug, share, model):
    """Universal CLI entry point for omniASR transcription UI."""
    print(f"\n🚀 Starting omniASR UI with model: {model}")
    # Load the requested model once, then build and serve the UI.
    load_model(model)
    demo = ui()
    demo.queue().launch(share=share, debug=debug)


if __name__ == "__main__":
    main()