import logging
import time
from sys import platform

import gradio as gr
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from languages import get_language_names
from subtitle_manager import Subtitle

logging.basicConfig(level=logging.INFO)

last_model = None
pipe = None


# Utility function to save text to file
def write_file(output_file, subtitle):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)


# Create the Whisper pipeline
def create_pipe(model_name, flash):
    # Prefer CUDA, fall back to Apple Silicon (MPS) on macOS, otherwise CPU
    device = "cuda:0" if torch.cuda.is_available() else "mps" if platform == "darwin" else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_name)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


# Main transcription function
def transcribe_webui(model_name, language_name, url_data, uploaded_files, audio_file,
                     task, flash, chunk_length_s, batch_size, progress=gr.Progress()):
    global last_model, pipe
    progress(0, desc="Initializing...")

    # Load or reload the model if it changed
    if last_model != model_name or pipe is None:
        logging.info("Loading model...")
        torch.cuda.empty_cache()
        pipe = create_pipe(model_name, flash)
        last_model = model_name

    # Prepare subtitle generators
    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    # Collect all input files
    files = []
    if uploaded_files:
        files += uploaded_files
    if url_data:
        files.append(url_data)
    if audio_file:
        files.append(audio_file)
    logging.info(f"Processing files: {files}")

    # Prepare Whisper generation options (English-only models accept neither)
    generate_kwargs = {}
    if language_name != "Automatic Detection" and not model_name.endswith(".en"):
        generate_kwargs["language"] = language_name
    if not model_name.endswith(".en"):
        generate_kwargs["task"] = task

    files_out = []
    final_vtt, final_txt = "", ""
    for file in progress.tqdm(files, desc="Transcribing..."):
        start_time = time.time()
        outputs = pipe(
            file,
            chunk_length_s=int(chunk_length_s),
            batch_size=int(batch_size),
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.info(f"Transcription completed in {time.time() - start_time:.2f} sec for {file}")

        # Write SRT, VTT and plain-text outputs next to the working directory
        file_name = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_name + ".srt", srt)
        write_file(file_name + ".vtt", vtt)
        write_file(file_name + ".txt", txt)
        files_out += [file_name + ".srt", file_name + ".vtt", file_name + ".txt"]
        final_vtt, final_txt = vtt, txt

    progress(1, desc="Completed!")
    # Plain text goes to the "Transcription" box, timestamped VTT to "Segments"
    return files_out, final_txt, final_vtt


# Gradio Interface
whisper_models = [
    "openai/whisper-tiny", "openai/whisper-tiny.en",
    "openai/whisper-base", "openai/whisper-base.en",
    "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
    "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
    "openai/whisper-large", "openai/whisper-large-v1", "openai/whisper-large-v2",
    "distil-whisper/distil-large-v2",
    "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
"xaviviro/whisper-large-v3-catalan-finetuned-v2", "antony66/whisper-large-v3-russian", "openai/whisper-large-v3-turbo", "efficient-speech/lite-whisper-large-v3-turbo", "distil-whisper/distil-large-v3.5", ] with gr.Blocks(title="Insanely Fast Whisper") as demo: gr.Markdown("## Insanely Fast Whisper\nTranscribe audio on-device using Whisper and Transformers!") with gr.Row(): model_input = gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model") language_input = gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language") url_input = gr.Text(label="URL (optional)") files_input = gr.File(label="Upload Files", file_types=None, file_count="multiple") audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Microphone / Upload") task_input = gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task") flash_input = gr.Checkbox(label="Use Flash Attention", value=False) chunk_input = gr.Number(label="Chunk Length (s)", value=30) batch_input = gr.Number(label="Batch Size", value=24) output_files = gr.File(label="Download Files", file_types=None, file_count="multiple") output_text = gr.Text(label="Transcription") output_segments = gr.Text(label="Segments") submit_btn = gr.Button("Transcribe") submit_btn.click( fn=transcribe_webui, inputs=[model_input, language_input, url_input, files_input, audio_input, task_input, flash_input, chunk_input, batch_input], outputs=[output_files, output_text, output_segments] ) if __name__ == "__main__": demo.launch(share=True)