Spaces:
Running
Running
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
| from datetime import date | |
# Pick the compute device: use the GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the Granite Speech checkpoint. The processor turns (prompt, waveform)
# pairs into model inputs; its tokenizer is kept separately for chat
# templating and for decoding generated ids back to text.
model_name = "ibm-granite/granite-speech-3.3-8b"
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    device_map=device,
    # bfloat16 weights take 2 bytes per parameter instead of float32's 4.
    torch_dtype=torch.bfloat16,
)
def _load_mono_16k(audio_file, target_sr=16000):
    """Load ``audio_file`` as a single-channel waveform sampled at ``target_sr`` Hz."""
    wav, sr = torchaudio.load(audio_file, normalize=True)
    # Channels are on dim 0 (torchaudio's default channels-first layout);
    # average them to get one mono channel, keeping the leading dim.
    if wav.shape[0] != 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav


def _build_prompt(user_prompt):
    """Render the Granite chat template with an ``<|audio|>`` slot before ``user_prompt``."""
    today_str = date.today().strftime("%B %d, %Y")
    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {today_str}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant."
    )
    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=f"<|audio|>{user_prompt}"),
    ]
    # Return the templated string; the processor tokenizes it together with the audio.
    return tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )


def transcribe(audio_file, user_prompt):
    """Run Granite Speech on an uploaded audio file and return its text reply.

    Args:
        audio_file: Filepath from ``gr.Audio(type="filepath")``, or ``None``
            when nothing has been uploaded.
        user_prompt: Instruction placed after the ``<|audio|>`` placeholder,
            e.g. a transcription or translation request.

    Returns:
        The newly generated text (prompt tokens removed), whitespace-trimmed.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_file is None:
        # Without this guard torchaudio.load(None) fails with an opaque
        # traceback; gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please upload an audio file first.")

    wav = _load_mono_16k(audio_file)
    prompt = _build_prompt(user_prompt)

    model_inputs = processor(
        prompt,
        wav,
        device=device,
        return_tensors="pt",
    ).to(device)
    # Greedy decoding (no sampling, single beam) keeps output deterministic.
    # Generation is capped at 512 new tokens, so very long audio may be cut short.
    model_outputs = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
        num_beams=1,
    )

    # generate() returns prompt + continuation; keep only the new tokens.
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = model_outputs[0, num_input_tokens:]
    output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return output_text.strip()
# Gradio UI: an audio upload plus an editable prompt go in; the model's reply
# is shown in a textbox.
# NOTE(review): which widgets sit inside gr.Row() was reconstructed from a
# copy with flattened indentation; confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3 Speech-to-Text")
    # The two literals are concatenated implicitly, so the trailing space on
    # the first one is what separates the sentences.
    gr.Markdown(
        "Upload an audio file and Granite Speech 3.3 8b will transcribe it into text. "
        "You can also edit the prompt below to customize what Granite should do with the audio, like translation."
    )
    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            label="Upload Audio (16kHz mono preferred)",
        )
        output_text = gr.Textbox(label="Transcription", lines=5)
    user_prompt = gr.Textbox(
        label="User Prompt",
        value="Can you transcribe the speech into a written format?",
        lines=2,
    )
    transcribe_btn = gr.Button("Transcribe")
    # Event listeners must be registered inside the Blocks context.
    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )

demo.launch()