import spaces
import uuid
import warnings
import traceback
import numpy as np
from pathlib import Path
from typing import Optional
from collections import Counter

import gradio as gr
import torch
import torchaudio
import soundfile as sf
import matplotlib.pyplot as plt

from NatureLM.config import Config
from NatureLM.models.NatureLM import NatureLM
from NatureLM.infer import Pipeline
from data_store import upload_data

warnings.filterwarnings("ignore")

SAMPLE_RATE = 16000  # Default sample rate for NatureLM-audio
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
MIN_AUDIO_DURATION: float = 0.5  # seconds
MAX_HISTORY_TURNS = (
    3  # Maximum number of conversation turns to include in context (user + assistant pairs)
)

# Load the model at startup
print(f"Device: {DEVICE}")
model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
model = model.eval().to(DEVICE)
model = Pipeline(model)


def check_audio_duration_greater(audio_path: str) -> None:
    """Raise a gr.Error if the audio file is shorter than MIN_AUDIO_DURATION."""
    info = sf.info(audio_path)
    duration = info.duration  # info.num_frames / info.sample_rate
    if duration < MIN_AUDIO_DURATION:
        raise gr.Error(f"Audio duration must be at least {MIN_AUDIO_DURATION} seconds.")


def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
    """Generate a spectrogram figure from the audio tensor."""
    spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024)(audio)
    spectrogram = spectrogram.numpy()[0].squeeze()

    # Convert to a matplotlib figure with imshow
    fig, ax = plt.subplots(figsize=(13, 5))
    ax.imshow(np.log(spectrogram + 1e-4), aspect="auto", origin="lower", cmap="viridis")
    ax.set_title("Spectrogram")
    ax.set_xlabel("Time")

    # Set x ticks to span 0 to the audio duration in seconds
    if audio.dim() > 1:
        duration = audio.size(1) / SAMPLE_RATE
    else:
        duration = audio.size(0) / SAMPLE_RATE
    ax.set_xticks([0, spectrogram.shape[1]])
    ax.set_xticklabels(["0s", f"{duration:.2f}s"])

    ax.set_ylabel("Frequency")
    # Set y ticks to span 0 to the Nyquist frequency (sample_rate / 2)
    nyquist_freq = SAMPLE_RATE / 2
    ax.set_yticks(
        [
            0,
            spectrogram.shape[0] // 4,
            spectrogram.shape[0] // 2,
            3 * spectrogram.shape[0] // 4,
            spectrogram.shape[0] - 1,
        ]
    )
    ax.set_yticklabels(
        [
            "0 Hz",
            f"{nyquist_freq / 4:.0f} Hz",
            f"{nyquist_freq / 2:.0f} Hz",
            f"{3 * nyquist_freq / 4:.0f} Hz",
            f"{nyquist_freq:.0f} Hz",
        ]
    )
    fig.tight_layout()
    return fig


def take_majority_vote(results: list[list[dict]]) -> list[str]:
    """For each audio file, take the majority vote of the labels across all windows."""
    outputs = []
    for result in results:
        predictions = [window["prediction"] for window in result]
        if not predictions:
            continue
        # Count occurrences of each label and keep the most common one
        counts = Counter(predictions)
        most_common_label, _ = counts.most_common(1)[0]
        outputs.append(most_common_label)
    return outputs
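
# Usage sketch for take_majority_vote. The helper and its data are illustrative only:
# the pipeline returns one list per audio file, with one dict per analysis window,
# and each dict carries the model's text output under the "prediction" key.
def _majority_vote_example() -> list[str]:
    example_results = [
        [
            {"prediction": "American Robin"},
            {"prediction": "American Robin"},
            {"prediction": "Blue Jay"},
        ],
        [{"prediction": "Green Tree Frog"}],
    ]
    # Expected output: ["American Robin", "Green Tree Frog"]
    return take_majority_vote(example_results)
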

@spaces.GPU
def prompt_lm(
    audios: list[str],
    queries: list[str] | str,
    window_length_seconds: float = 10.0,
    hop_length_seconds: float = 10.0,
) -> list[list[dict]]:
    """Generate responses using the model.

    Args:
        audios (list[str]): List of audio file paths
        queries (list[str] | str): Query or list of queries to process
        window_length_seconds (float): Length of the window for processing audio
        hop_length_seconds (float): Hop length for processing audio

    Returns:
        list[list[dict]]: Per-window predictions for each audio-query pair
    """
    if model is None:
        raise gr.Error("❌ Model not loaded. Please check the model configuration.")

    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        results: list[list[dict]] = model(
            audios,
            queries,
            window_length_seconds=window_length_seconds,
            hop_length_seconds=hop_length_seconds,
            input_sample_rate=None,
        )
    return results


def make_spectrogram_figure(audio_input: str) -> plt.Figure:
    """Load audio with torchaudio and return its spectrogram figure."""
    if not audio_input:
        # Return an empty figure if no audio input is provided
        return get_spectrogram(torch.zeros(1, SAMPLE_RATE))

    # Check that the file exists and is accessible
    try:
        if not Path(audio_input).exists():
            print(f"Audio file does not exist: {audio_input}")
            return get_spectrogram(torch.zeros(1, SAMPLE_RATE))
        if not Path(audio_input).is_file():
            print(f"Path is not a valid file: {audio_input}")
            return get_spectrogram(torch.zeros(1, SAMPLE_RATE))

        audio_tensor, sample_rate = torchaudio.load(audio_input)
        spectrogram_fig = get_spectrogram(audio_tensor)
        return spectrogram_fig
    except Exception as e:
        print(f"Error loading audio file {audio_input}: {e}")
        # Return an empty spectrogram on error
        return get_spectrogram(torch.zeros(1, SAMPLE_RATE))


def add_user_query(chatbot_history: list[dict], chat_input: str) -> list[dict]:
    """Add the user message to the chat history."""
    # Validate input
    if not chat_input.strip():
        return chatbot_history
    chatbot_history.append({"role": "user", "content": chat_input.strip()})
    return chatbot_history


def send_data_to_hub(chatbot_history: list[dict], audio: str, session_id: str):
    """Upload the latest question/answer pair to the hub."""
    if not chatbot_history or len(chatbot_history) < 2:
        return
    user_text = chatbot_history[-2]["content"]
    model_response = chatbot_history[-1]["content"]
    upload_data(audio, user_text, model_response, session_id)


def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
    """Generate a model response from the user input, the audio file, and the conversation history."""
    try:
        # Warn if the conversation is getting long
        num_turns = len(chatbot_history)
        if num_turns > MAX_HISTORY_TURNS * 2:  # Each turn = user + assistant message
            gr.Warning(
                "⚠️ Long conversations may affect response quality. Consider starting a new conversation with the Clear button."
            )

        # Build conversation context from history
        conversation_context = []
        for message in chatbot_history:
            if message["role"] == "user":
                conversation_context.append(f"User: {message['content']}")
            elif message["role"] == "assistant":
                conversation_context.append(f"Assistant: {message['content']}")

        # Get the last user message
        last_user_message = ""
        for message in reversed(chatbot_history):
            if message["role"] == "user":
                last_user_message = message["content"]
                break

        # Format the full prompt with conversation history
        if len(conversation_context) > 2:  # More than just the current query
            # Include previous turns (limit to last MAX_HISTORY_TURNS exchanges)
            # recent_context = conversation_context[
            #     -(MAX_HISTORY_TURNS + 1) : -1
            # ]  # Exclude current message
            recent_context = conversation_context
            full_prompt = (
                "Previous conversation:\n"
                + "\n".join(recent_context)
                + "\n\nCurrent question: "
                + last_user_message
            )
        else:
            full_prompt = last_user_message

        print("\nFull prompt with history:", full_prompt)

        response = prompt_lm(
            audios=[audio_input],
            queries=[full_prompt.strip()],
            window_length_seconds=100_000,
            hop_length_seconds=100_000,
        )
        # Take the first prediction for the first (and only) audio file
        if isinstance(response, list) and len(response) > 0:
            response = response[0][0]["prediction"]
            print("Model response:", response)
        else:
            response = "No response generated."
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        response = "Error generating response. Please try again."

    # Add the model response to the chat history
    chatbot_history.append({"role": "assistant", "content": response})
    return chatbot_history
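
# Direct (non-UI) usage sketch for the pipeline above. The audio path and the
# 10-second windowing values are illustrative assumptions, not taken from the app:
# prompt_lm returns one list of per-window predictions per file, which
# take_majority_vote reduces to a single label per file.
def _pipeline_example(audio_path: str = "assets/example.wav") -> list[str]:
    results = prompt_lm(
        audios=[audio_path],
        queries=["What is the common name for the focal species in the audio?"],
        window_length_seconds=10.0,
        hop_length_seconds=10.0,
    )
    return take_majority_vote(results)
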

def main(
    assets_dir: Path,
):
    # Check that the assets directory exists; create it if missing
    if not assets_dir.exists():
        print(f"Warning: Assets directory {assets_dir} does not exist")
        assets_dir.mkdir(exist_ok=True)

    # Paths to the example audio files
    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
    whale_audio = assets_dir / "Humpback Whale - Megaptera novaeangliae.wav"
    crow_audio = assets_dir / "American Crow - Corvus brachyrhynchos.mp3"

    examples = {
        "Identifying Focal Species (Lazuli Bunting)": [
            str(laz_audio),
            "What is the common name for the focal species in the audio?",
        ],
        "Caption the audio (Green Tree Frog)": [
            str(frog_audio),
            "Caption the audio, using the common name for any animal species.",
        ],
        "Caption the audio (American Robin)": [
            str(robin_audio),
            "Caption the audio, using the scientific name for any animal species.",
        ],
        "Identifying Focal Species (Megaptera novaeangliae)": [
            str(whale_audio),
            "What is the scientific name for the focal species in the audio?",
        ],
        "Speaker Count (American Crow)": [
            str(crow_audio),
            "How many individuals are vocalizing in this audio?",
        ],
        "Caption the audio (Humpback Whale)": [str(whale_audio), "Caption the audio."],
    }

    gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])

    with gr.Blocks(
        title="NatureLM-audio",
        theme=gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]),
    ) as app:
        with gr.Row():
            gr.HTML("""
                <div class="banner">
                    <!-- ESP Logo -->
                    <div class="banner-header">NatureLM-audio <sup>BETA</sup></div>
                </div>
            """)

        with gr.Tabs():
            with gr.Tab("Analyze Audio"):
                session_id = gr.State(str(uuid.uuid4()))
                # uploaded_audio = gr.State()

                # Status indicator
                # status_text = gr.Textbox(
                #     value=model_manager.get_status(),
                #     label="Model Status",
                #     interactive=False,
                #     visible=True,
                # )

                with gr.Column(visible=True) as onboarding_message:
                    gr.HTML(
                        """ """,
                        padding=False,
                    )

                with gr.Column(visible=True) as upload_section:
                    audio_input = gr.Audio(
                        type="filepath",
                        container=True,
                        interactive=True,
                        sources=["upload"],
                    )
                    # Check that the audio duration is at least MIN_AUDIO_DURATION;
                    # otherwise a gr.Error is raised
                    audio_input.change(
                        fn=check_audio_duration_greater,
                        inputs=[audio_input],
                        outputs=[],
                    )

                with gr.Accordion(
                    label="Toggle Spectrogram", open=False, visible=False
                ) as spectrogram:
                    plotter = gr.Plot(
                        get_spectrogram(torch.zeros(1, SAMPLE_RATE)),
                        label="Spectrogram",
                        visible=False,
                        elem_id="spectrogram-plot",
                    )

                with gr.Column(visible=False) as tasks:
                    task_dropdown = gr.Dropdown(
                        [
                            "What are the common names for the species in the audio, if any?",
                            "Caption the audio, using the scientific name for any animal species.",
                            "Caption the audio, using the common name for any animal species.",
                            "What is the scientific name for the focal species in the audio?",
                            "What is the common name for the focal species in the audio?",
                            "What is the family of the focal species in the audio?",
                            "What is the genus of the focal species in the audio?",
                            "What is the taxonomic name of the focal species in the audio?",
                            "What call types are heard from the focal species in the audio?",
                            "What is the life stage of the focal species in the audio?",
                        ],
                        label="Pre-Loaded Tasks",
                        info="Select a task, or write your own prompt below.",
                        allow_custom_value=False,
                        value=None,
                    )

                with gr.Group(visible=False) as chat:
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        height=250,
                        type="messages",
                        label="Chat",
                        render_markdown=False,
                        group_consecutive_messages=False,
                        feedback_options=[
                            "like",
                            "dislike",
                            "wrong species",
                            "incorrect response",
                            "other",
                        ],
                        resizeable=True,
                    )
                    with gr.Column() as text:
                        chat_input = gr.Textbox(
                            placeholder="Type your message and press Enter to send",
                            type="text",
                            lines=1,
                            show_label=False,
                            submit_btn="Send",
                            container=True,
                            autofocus=False,
                            elem_id="chat-input",
                        )

                with gr.Column() as examples_section:
                    gr.Examples(
                        list(examples.values()),
                        [audio_input, chat_input],
                        [audio_input, chat_input],
                        example_labels=list(examples.keys()),
                        examples_per_page=20,
                    )

                def validate_and_submit(chatbot_history, chat_input):
                    if not chat_input or not chat_input.strip():
                        gr.Warning("Please enter a question or message before sending.")
                        return chatbot_history, chat_input
                    updated_history = add_user_query(chatbot_history, chat_input)
                    return updated_history, ""

                clear_button = gr.ClearButton(
                    components=[chatbot, chat_input, audio_input, plotter],
                    visible=False,
                )

                # If a task is selected in the dropdown, copy it into the chat input
                def set_query(task):
                    if task:
                        return gr.update(value=task)
                    return gr.update(value="")

                task_dropdown.select(
                    fn=set_query,
                    inputs=[task_dropdown],
                    outputs=[chat_input],
                )

                def start_chat_interface(audio_path):
                    return (
                        gr.update(visible=False),  # hide onboarding message
                        gr.update(visible=True),  # show upload section
                        gr.update(visible=True),  # show spectrogram
                        gr.update(visible=True),  # show tasks
                        gr.update(visible=True),  # show chat box
                        gr.update(visible=True),  # show plotter
                    )

                # When audio is added, reveal the interface, then render the spectrogram
                audio_input.change(
                    fn=start_chat_interface,
                    inputs=[audio_input],
                    outputs=[
                        onboarding_message,
                        upload_section,
                        spectrogram,
                        tasks,
                        chat,
                        plotter,
                    ],
                ).then(
                    fn=make_spectrogram_figure,
                    inputs=[audio_input],
                    outputs=[plotter],
                )

                # When the user submits a message:
                # 1. Validate the input, add it to the chat history, and clear the text box
                # 2. Get a response from the model
                # 3. Show the Clear button
                # 4. Upload the question/answer pair to the hub
                chat_input.submit(
                    validate_and_submit,
                    inputs=[chatbot, chat_input],
                    outputs=[chatbot, chat_input],
                ).then(
                    get_response,
                    inputs=[chatbot, audio_input],
                    outputs=[chatbot],
                ).then(
                    lambda: gr.update(visible=True),  # Show clear button
                    None,
                    [clear_button],
                ).then(
                    send_data_to_hub,
                    [chatbot, audio_input, session_id],
                    None,
                )

                clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])

            with gr.Tab("Sample Library"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Download Sample Audio")
                        gr.Markdown(
                            """Feel free to explore these sample audio files. To download, click the button in the top-right corner of each audio file.
                            You can also find a large collection of publicly available animal sounds on [xeno-canto](https://xeno-canto.org/explore/taxonomy) and the [Watkins Marine Mammal Sound Database](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm)."""
                        )

                samples = [
                    (
                        "assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.m4a",
                        "Lazuli Bunting",
                    ),
                    (
                        "assets/nri-GreenTreeFrogEvergladesNP.mp3",
                        "Green Tree Frog",
                    ),
                    (
                        "assets/American Crow - Corvus brachyrhynchos.mp3",
                        "American Crow",
                    ),
                    (
                        "assets/Gray Wolf - Canis lupus italicus.m4a",
                        "Gray Wolf",
                    ),
                    (
                        "assets/Humpback Whale - Megaptera novaeangliae.wav",
                        "Humpback Whale",
                    ),
                    ("assets/Walrus - Odobenus rosmarus.wav", "Walrus"),
                ]

                for row_i in range(0, len(samples), 3):
                    with gr.Row():
                        for filepath, label in samples[row_i : row_i + 3]:
                            with gr.Column():
                                gr.Audio(
                                    filepath,
                                    label=label,
                                    type="filepath",
                                    show_download_button=True,
                                )

            with gr.Tab("πŸ’‘ Help"):
                gr.HTML("""

                <div class="guide-section">
                    <h3>Getting Started</h3>
                    <ol>
                        <li><b>Upload your audio or click on a pre-loaded example.</b> Drag and drop your audio file containing animal vocalizations, or click on an example.</li>
                        <li><b>Trim your audio (if needed)</b> by clicking the scissors icon on the bottom right of the audio panel. Try to keep your audio to 10 seconds or less.</li>
                        <li><b>View the Spectrogram (optional).</b> You can easily view/hide the spectrogram of your audio for closer analysis.</li>
                        <li><b>Select a task or write your own.</b> Select an option from pre-loaded tasks. This will auto-fill the text box with a prompt, so all you have to do is hit Send. Or, type a custom prompt directly into the chat.</li>
                        <li><b>Send and Analyze Audio.</b> Press "Send" or type Enter to begin processing your audio. Ask follow-up questions or press "Clear" to start a new conversation.</li>
                    </ol>
                </div>

                <div class="guide-section">
                    <h3>Tips</h3>
                    <h4>Prompting Best Practices</h4>
                    <h4>Audio Files</h4>
                </div>
                <div class="guide-section">
                    <h3>Learn More</h3>
                </div>
                """)

    app.css = """
    #chat-input textarea {
        background: white;
        flex: 1;
    }
    #chat-input .submit-button {
        padding: 10px;
        margin: 2px 6px;
        align-self: center;
    }
    #spectrogram-plot {
        padding: 12px;
        margin: 12px;
    }
    .banner {
        background: white;
        border: 1px solid #e5e7eb;
        border-radius: 8px;
        padding: 16px 20px;
        display: flex;
        align-items: center;
        justify-content: space-between;
        margin-bottom: 16px;
        margin-left: 0;
        margin-right: 0;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
    }
    .banner .banner-header {
        font-size: 16px;
        font-weight: 600;
        color: #374151;
        margin-bottom: 4px;
    }
    .banner .banner-text {
        font-size: 14px;
        color: #6b7280;
        line-height: 1.4;
    }
    .link-btn {
        padding: 6px 12px;
        border-radius: 6px;
        font-size: 13px;
        font-weight: 500;
        cursor: pointer;
        border: none;
        background: #3b82f6;
        color: white;
        text-decoration: none;
        display: inline-block;
        transition: background 0.2s ease;
    }
    .link-btn:hover {
        background: #2563eb;
    }
    .guide-section {
        margin-bottom: 32px;
        border-radius: 8px;
        padding: 14px;
        border: 1px solid #e5e7eb;
    }
    .guide-section h3 {
        margin-top: 4px;
        margin-bottom: 16px;
        border-bottom: 1px solid #e5e7eb;
        padding-bottom: 12px;
    }
    .guide-section h4 {
        color: #1f2937;
        margin-top: 4px;
    }
    @media (prefers-color-scheme: dark) {
        #chat-input {
            background: #1e1e1e;
        }
        #chat-input textarea {
            background: #1e1e1e;
            color: white;
        }
        .banner {
            background: #1e1e1e;
            color: white;
        }
        .banner .banner-header {
            color: white;
        }
    }
    """

    return app


# Create and launch the app
app = main(
    assets_dir=Path("assets"),
)

if __name__ == "__main__":
    app.launch()