import re
import tempfile
from collections import Counter
from pathlib import Path
from typing import Literal, Optional

import gradio as gr
import torch
from NatureLM.config import Config
from NatureLM.models.NatureLM import NatureLM
from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms

import spaces
class ModelManager:
    """Manages model loading and state"""

    def __init__(self):
        self.model: Optional[NatureLM] = None
        self.config: Optional[Config] = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False
    def check_availability(self) -> tuple[bool, str]:
        """Check if the model is available for download"""
        try:
            from huggingface_hub import model_info

            # Raises if the repository is unreachable or access is denied
            model_info("EarthSpeciesProject/NatureLM-audio")
            return True, "Model is available"
        except Exception as e:
            return False, f"Model not available: {str(e)}"
    def reset_state(self):
        """Reset the model loading state to allow retrying after a failure"""
        self.model = None
        self.is_loaded = False
        self.is_loading = False
        self.load_failed = False
        return self.get_status()

    def get_status(self) -> str:
        """Get the current model loading status"""
        if self.is_loaded:
            return "✅ Model loaded and ready"
        elif self.is_loading:
            return "🔄 Loading model... Please wait"
        elif self.load_failed:
            return "❌ Model failed to load. Please check the configuration."
        else:
            return "⏳ Ready to load model on first use"
    def load_model(self) -> Optional[NatureLM]:
        """Load the model if needed"""
        if self.is_loaded:
            return self.model
        if self.is_loading or self.load_failed:
            return None
        try:
            self.is_loading = True
            print("Loading model...")

            # Check if model is available first
            available, message = self.check_availability()
            if not available:
                raise Exception(f"Model not available: {message}")

            model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
            model.to("cuda")
            model.eval()

            self.model = model
            self.is_loaded = True
            self.is_loading = False
            print("Model loaded successfully!")
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            self.is_loading = False
            self.load_failed = True
            return None


# Global model manager instance
model_manager = ModelManager()
# On ZeroGPU Spaces, spaces.GPU requests a GPU for the duration of this call
@spaces.GPU
def prompt_lm(audios: list[str], messages: list[dict[str, str]]) -> str:
    """Generate response using the model"""
    model = model_manager.load_model()
    if model is None:
        if model_manager.is_loading:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
        elif model_manager.load_failed:
            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again using the retry button."
        else:
            return "Demo mode: Model not loaded. Please check the model configuration."

    cuda_enabled = torch.cuda.is_available()
    samples = prepare_sample_waveforms(audios, cuda_enabled)

    prompt_text = model.llama_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    ).removeprefix(model.llama_tokenizer.bos_token)
    # Strip the default system header injected by the chat template
    prompt_text = re.sub(
        r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
        "",
        prompt_text,
    )
    # Replace newline characters with visible "\n" escapes
    prompt_text = re.sub("\\n", r"\\n", prompt_text)
    print(f"{prompt_text=}")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        llm_answer = model.generate(samples, model_manager.config.generate, prompts=[prompt_text])
    return llm_answer[0]
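# Illustrative sketch (not part of the app): a single-turn call to prompt_lm with one
# audio clip. The file path and question are placeholders; "<Audio><AudioHere></Audio>"
# marks where the waveform is spliced into the prompt.
#
#     answer = prompt_lm(
#         ["assets/example.wav"],
#         [{"role": "user", "content": "<Audio><AudioHere></Audio> Caption the audio."}],
#     )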
def _multimodal_textbox_factory():
    return gr.MultimodalTextbox(
        value=None,
        interactive=True,
        sources="microphone",
        placeholder="Enter message...",
        show_label=False,
        autofocus=True,
        submit_btn="Send",
    )


def user_message(content):
    return {"role": "user", "content": content}


def add_message(history, message):
    for x in message["files"]:
        history.append(user_message({"path": x}))
    if message["text"]:
        history.append(user_message(message["text"]))
    return history, _multimodal_textbox_factory()
def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
    messages = []
    files = []
    for msg in msgs:
        print(msg, messages, files)
        match msg:
            case {"content": (path,)}:
                messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
                files.append(path)
            case _:
                messages.append(msg)

    # Join consecutive messages from the same role
    joined_messages = []
    for msg in messages:
        if joined_messages and joined_messages[-1]["role"] == msg["role"]:
            joined_messages[-1]["content"] += msg["content"]
        else:
            joined_messages.append(msg)

    return {"messages": joined_messages, "files": files}
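# Illustrative sketch (assumed inputs): two consecutive user messages -- an audio file
# followed by a text question -- collapse into a single user turn with the audio
# placeholder prepended, while the file path is collected separately:
#
#     combine_model_inputs([
#         {"role": "user", "content": ("clip.wav",)},
#         {"role": "user", "content": "What species is this?"},
#     ])
#     # -> {"messages": [{"role": "user",
#     #                   "content": "<Audio><AudioHere></Audio> What species is this?"}],
#     #     "files": ["clip.wav"]}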
def bot_response(history: list):
    print(type(history))
    combined_inputs = combine_model_inputs(history)
    response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
    history.append({"role": "assistant", "content": response})
    return history
def _chat_tab(examples):
    # Status indicator
    status_text = gr.Textbox(
        value=model_manager.get_status(),
        label="Model Status",
        interactive=False,
        visible=True,
    )
    chatbot = gr.Chatbot(
        label="Chat",
        elem_id="chatbot",
        bubble_full_width=False,
        type="messages",
        render_markdown=False,
        resizeable=True,
    )
    chat_input = _multimodal_textbox_factory()
    send_all = gr.Button("Send all", elem_id="send-all")
    clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)

    chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = send_all.click(
        bot_response,
        [chatbot],
        [chatbot],
        api_name="bot_response",
    )
    # Update status after bot response
    bot_msg.then(lambda: model_manager.get_status(), None, [status_text])
    bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
    clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])

    gr.Examples(
        list(examples.values()),
        chatbot,
        chatbot,
        example_labels=list(examples.keys()),
        examples_per_page=20,
    )
def summarize_batch_results(results):
    summary = Counter(results)
    summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
    return summary_str
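# Illustrative sketch (assumed labels): identical model answers are tallied and listed
# most-common first, e.g. summarize_batch_results(["Wood Thrush", "Wood Thrush", "None"])
# returns "Wood Thrush: 2\nNone: 1".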
def run_batch_inference(files, task, progress=gr.Progress()) -> str:
    model = model_manager.load_model()
    if model is None:
        if model_manager.is_loading:
            return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment."
        elif model_manager.load_failed:
            return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease check your connection and try again."
        else:
            return "Demo mode: Model not loaded. Please check the model configuration."

    outputs = []
    prompt = "<Audio><AudioHere></Audio> " + task
    for file in progress.tqdm(files):
        outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))

    batch_summary: str = summarize_batch_results(outputs)
    report = f"Batch summary:\n{batch_summary}\n\n"
    return report
def multi_extension_glob_mask(mask_base, *extensions):
    mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
    if not mask_ext or len(set(len(e) for e in extensions)) > 1:
        mask_ext.append("*")
    return mask_base + "".join(mask_ext)
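# Illustrative sketch: multi_extension_glob_mask("**.", "mp3", "flac", "wav") builds a
# positional character-class glob such as "**.[mfw][pla][3av]*" (character order within
# each class may vary); the trailing "*" is appended because the extensions differ in length.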
def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
    if file_selection == "explorer":
        files = gr.FileExplorer(
            glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
            label="Select audio files",
            file_count="multiple",
        )
    elif file_selection == "upload":
        files = gr.Files(label="Uploaded files", file_types=["audio"], height=300)

    task = gr.Textbox(label="Task", placeholder="Enter task...", show_label=True)
    process_btn = gr.Button("Process")
    output = gr.TextArea()

    process_btn.click(
        run_batch_inference,
        [files, task],
        [output],
    )
def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
    def get_line(row, start, end, annotation):
        return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"

    raven_output = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"]
    current_offset = 0
    last_label = ""
    row = 1
    for offset, label in sorted(outputs.items()):
        if label != last_label and last_label:
            raven_output.append(get_line(row, current_offset, offset, last_label))
            current_offset = offset
            row += 1
        if not last_label:
            current_offset = offset
        if label != "None":
            last_label = label
        else:
            last_label = ""
    if last_label:
        raven_output.append(get_line(row, current_offset, current_offset + chunk_len, last_label))
    return "\n".join(raven_output)
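# Illustrative sketch (assumed labels): to_raven_format({0: "Coyote", 5: "Coyote", 10: "None"},
# chunk_len=10) merges the consecutive "Coyote" chunks into one selection row spanning
# 0-10 s, yielding the header line plus "1\tSpectrogram 1\t1\t0\t10\t0\t8000\tCoyote".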
def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
    # Check if model is loading
    if model_manager.is_loading:
        return "🔄 Loading model... This may take a few minutes on first use. Please try again in a moment.", None

    # Check if model failed to load
    if model_manager.load_failed:
        return "❌ Model failed to load. This could be due to:\n• No internet connection\n• Insufficient disk space\n• Model repository access issues\n\nPlease refresh the page to try again.", None

    model = model_manager.load_model()
    if model is None:
        return "Demo mode: Model not loaded. Please check the model configuration.", None

    cuda_enabled = torch.cuda.is_available()
    outputs = {}
    offset = 0
    prompt = f"<Audio><AudioHere></Audio> {task}"
    prompt = model_manager.config.model.prompt_template.format(prompt)

    for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
        prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
        with torch.cuda.amp.autocast(dtype=torch.float16):
            llm_answers = model.generate(batch, model_manager.config.generate, prompts=prompt_strs)
        for answer in llm_answers:
            outputs[offset] = answer
            offset += hop_len

    report = f"Number of chunks: {len(outputs)}\n\n"
    for offset in sorted(outputs.keys()):
        report += f"{offset:02d}s:\t{outputs[offset]}\n"

    raven_output = to_raven_format(outputs, chunk_len=chunk_len)
    with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
        f.write(raven_output)
        raven_file = f.name
    return report, raven_file
def _long_recording_tab():
    audio_input = gr.Audio(label="Upload audio file", type="filepath")
    task = gr.Dropdown(
        [
            "What are the common names for the species in the audio, if any?",
            "Caption the audio.",
            "Caption the audio, using the scientific name for any animal species.",
            "Caption the audio, using the common name for any animal species.",
            "What is the scientific name for the focal species in the audio?",
            "What is the common name for the focal species in the audio?",
            "What is the family of the focal species in the audio?",
            "What is the genus of the focal species in the audio?",
            "What is the taxonomic name of the focal species in the audio?",
            "What call types are heard from the focal species in the audio?",
            "What is the life stage of the focal species in the audio?",
        ],
        label="Tasks",
        allow_custom_value=True,
    )
    with gr.Accordion("Advanced options", open=False):
        hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
        chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
    process_btn = gr.Button("Process")
    output = gr.TextArea()
    download_raven = gr.DownloadButton("Download Raven file")

    process_btn.click(
        _run_long_recording_inference,
        [audio_input, task, chunk_len, hop_len],
        [output, download_raven],
    )
def main(
    assets_dir: Path,
    cfg_path: str | Path,
    options: list[str] = [],
    device: str = "cuda",
):
    # Load configuration
    try:
        cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
        model_manager.config = cfg
        print("Configuration loaded successfully")
    except Exception as e:
        print(f"Warning: Could not load config: {e}")
        print("Running in demo mode")
        model_manager.config = None
    # Create the assets directory if it does not exist
    if not assets_dir.exists():
        print(f"Warning: Assets directory {assets_dir} does not exist")
        assets_dir.mkdir(exist_ok=True)

    # Paths to the example audio files (these are not created automatically)
    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
    vireo_audio = assets_dir / "yell-YELLWarblingVireoMammoth20150614T29ms.mp3"
    examples = {
        "Caption the audio (Lazuli Bunting)": [
            [
                user_message({"path": str(laz_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Green Tree Frog)": [
            [
                user_message({"path": str(frog_audio)}),
                user_message("Caption the audio, using the common name for any animal species."),
            ]
        ],
        "Caption the audio (American Robin)": [
            [
                user_message({"path": str(robin_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Warbling Vireo)": [
            [
                user_message({"path": str(vireo_audio)}),
                user_message("Caption the audio."),
            ]
        ],
    }
    with gr.Blocks(
        title="NatureLM-audio",
        theme=gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]),
    ) as app:
        header = gr.HTML("""
            <div style="display: flex; align-items: center; gap: 12px;"><h2 style="margin: 0;">NatureLM-audio<span style="font-size: 0.55em; color: #28a745; background: #e6f4ea; padding: 2px 6px; border-radius: 4px; margin-left: 8px; display: inline-block; vertical-align: top;">BETA</span></h2></div>
        """)

        with gr.Tabs():
            with gr.Tab("Analyze Audio"):
                uploaded_audio = gr.State()

                with gr.Column(visible=True) as onboarding_message:
                    gr.HTML("""
                        <div style="
                            background: transparent;
                            border: 1px solid #e5e7eb;
                            border-radius: 8px;
                            padding: 16px 20px;
                            display: flex;
                            align-items: center;
                            justify-content: space-between;
                            margin-bottom: 16px;
                            margin-left: 0;
                            margin-right: 0;
                            box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
                        ">
                            <div style="display: flex; padding: 0px; align-items: center; flex: 1;">
                                <div style="font-size: 20px; margin-right: 12px;">👋</div>
                                <div style="flex: 1;">
                                    <div style="font-size: 16px; font-weight: 600; color: #374151; margin-bottom: 4px;">Welcome to NatureLM-audio!</div>
                                    <div style="font-size: 14px; color: #6b7280; line-height: 1.4;">Upload your first audio file below or try a sample from our library.</div>
                                </div>
                            </div>
                            <a href="https://www.earthspecies.org/blog" target="_blank" style="
                                padding: 6px 12px;
                                border-radius: 6px;
                                font-size: 13px;
                                font-weight: 500;
                                cursor: pointer;
                                border: none;
                                background: #3b82f6;
                                color: white;
                                text-decoration: none;
                                display: inline-block;
                                transition: background 0.2s ease;
                            "
                            onmouseover="this.style.background='#2563eb';"
                            onmouseout="this.style.background='#3b82f6';"
                            >View Tutorial</a>
                        </div>
                    """, padding=False)
                with gr.Column(visible=True) as upload_section:
                    audio_input = gr.Audio(
                        type="filepath",
                        container=True,
                        interactive=True,
                        sources=["upload"],
                    )

                with gr.Group(visible=False) as chat:
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        type="messages",
                        render_markdown=False,
                        feedback_options=["like", "dislike", "wrong species", "incorrect response", "other"],
                        resizeable=True,
                    )
                    chat_input = _multimodal_textbox_factory()
                    send_all = gr.Button("Send all")

                def start_chat_interface(audio_path):
                    return (
                        gr.update(visible=False),  # hide onboarding message
                        gr.update(visible=True),   # show upload section
                        gr.update(visible=True),   # show chat box
                    )

                audio_input.change(
                    fn=start_chat_interface,
                    inputs=[audio_input],
                    outputs=[onboarding_message, upload_section, chat],
                )

                chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
                send_all.click(bot_response, [chatbot], [chatbot])
            with gr.Tab("Sample Library"):
                gr.Markdown("## Sample Library\n\nExplore example audio files below.")
                gr.Examples(
                    list(examples.values()),
                    chatbot,
                    chatbot,
                    example_labels=list(examples.keys()),
                    examples_per_page=20,
                )

            with gr.Tab("💡 Help"):
                gr.Markdown("## User Guide")  # to fill out
                gr.Markdown("## Share Feedback")  # to fill out
                gr.Markdown("## FAQs")  # to fill out
    app.css = """
    .welcome-banner {
        background: transparent !important;
        border: 1px solid #e5e7eb !important;
        border-radius: 8px !important;
        padding: 16px 20px !important;
        margin-bottom: 16px !important;
        box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1) !important;
    }
    .welcome-banner > div {
        background: transparent !important;
    }
    .welcome-banner button {
        margin: 0 4px !important;
    }
    """

    # Disabling Batch and Long Recording tabs for now
    # with gr.Tab("Batch"):
    #     _batch_tab()
    # with gr.Tab("Long Recording"):
    #     _long_recording_tab()

    return app
# Create and launch the app
app = main(
    assets_dir=Path("assets"),
    cfg_path=Path("configs/inference.yml"),
    options=[],
    device="cuda",
)

if __name__ == "__main__":
    app.launch()