# Copyright (2024) Earth Species Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Literal, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from peft import LoraConfig, TaskType, get_peft_model
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from NatureLM.checkpoint_utils import save_model_checkpoint
from NatureLM.config import BeatsConfig, ModelConfig, save_config_as_yaml
from NatureLM.utils import universal_torch_load

from .beats.BEATs import BEATs, BEATsConfig
from .Qformer import BertConfig, BertLMHeadModel
from .utils import StoppingCriteriaSub

torch.backends.cuda.matmul.allow_tf32 = True

auth_token = os.getenv("llama", None)

# Marker inside a prompt where one audio chunk's embeddings are spliced in.
AUDIO_PLACEHOLDER = "<AudioHere>"


class AudioEncodingCache:
    """LRU cache for audio encodings, keyed by a content hash of the input.

    Cached values are moved to CPU so they do not occupy accelerator memory.
    """

    def __init__(self, capacity: int = 100):
        self.capacity = capacity
        self.cache = OrderedDict()
        self.hits = 0
        self.misses = 0

    @staticmethod
    def _tensor_bytes(tensor: torch.Tensor) -> bytes:
        """Return the raw bytes of ``tensor``.

        Reinterpreting as uint8 before ``.numpy()`` also works for dtypes that
        NumPy cannot represent (e.g. bfloat16).
        """
        flat = tensor.detach().cpu().contiguous().reshape(-1)
        return flat.view(torch.uint8).numpy().tobytes()

    def _compute_hash(
        self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor | None = None
    ) -> str:
        """Compute a cache key from the audio tensor and its padding mask.

        The full tensor contents are hashed rather than a sample of them. Two
        clips that agree only on a sampled subset would otherwise collide and be
        served each other's encodings. An example is zero-padded clips that start
        with identical silence.
        """
        hash_obj = hashlib.sha256(self._tensor_bytes(raw_wav))
        hash_obj.update(str(raw_wav.shape).encode())
        hash_obj.update(str(raw_wav.dtype).encode())
        if audio_padding_mask is not None:
            hash_obj.update(self._tensor_bytes(audio_padding_mask))
            hash_obj.update(str(audio_padding_mask.shape).encode())
            hash_obj.update(str(audio_padding_mask.dtype).encode())
        else:
            hash_obj.update(b"no_mask")
        return hash_obj.hexdigest()

    def get(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor | None = None):
        """Return the cached ``(audio_embeds, audio_atts)`` for this input, or None."""
        key = self._compute_hash(raw_wav, audio_padding_mask)
        if key in self.cache:
            self.hits += 1
            # Mark as most recently used.
            self.cache.move_to_end(key)
            return self.cache[key]
        self.misses += 1
        return None

    def put(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor, value: tuple):
        """Store an ``(audio_embeds, audio_atts)`` pair, evicting the LRU entry if full."""
        key = self._compute_hash(raw_wav, audio_padding_mask)
        audio_embeds, audio_atts = value
        # Keep cached tensors on CPU to save accelerator memory.
        self.cache[key] = (audio_embeds.cpu(), audio_atts.cpu())
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

    def clear(self):
        """Drop all entries and reset hit/miss counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

    def get_stats(self):
        """Return hit/miss counts, hit rate, current size and capacity."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
            "size": len(self.cache),
            "capacity": self.capacity,
        }


class NatureLM(nn.Module, PyTorchModelHubMixin):
    """Audio-conditioned causal LM.

    A BEATs audio encoder is followed by a Q-Former (or max pooling) and a
    linear projection into the Llama embedding space. Optionally, LoRA
    adapters are added on the attention projections.
    """

    def __init__(
        self,
        *,
        llama_path: Path,
        beats_path: Path | os.PathLike | None = None,
        beats_cfg: BeatsConfig,
        freeze_beats: bool = True,
        use_audio_Qformer: bool = True,
        max_pooling: bool = False,
        num_audio_query_token: int = 1,
        freeze_audio_QFormer: bool = False,
        window_level_Qformer: bool = True,
        second_per_window: float = 0.333333,
        second_stride: float = 0.333333,
        downsample_factor: int = 4,
        audio_llama_proj_model: Path | os.PathLike | None = None,
        freeze_audio_llama_proj: bool = False,
        lora: bool = True,
        lora_rank: int = 8,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        flash_attn: Literal["eager", "flash_attention_2"] = "eager",
        prompt_template: str = "",
        max_txt_len: int = 128,
        end_sym: str = "",
        device: str = "cuda",
        audio_encoding_cache_size: int = 100,
    ):
        """Build the tokenizer, LLM (+LoRA), BEATs encoder and audio aligner.

        Args:
            llama_path: Path/hub id passed to ``from_pretrained`` for tokenizer and LLM.
            beats_path: Optional BEATs checkpoint; its ``"model"`` entry is loaded.
            beats_cfg: BEATs configuration (converted with ``dict(...)``).
            audio_llama_proj_model: Optional checkpoint whose ``"model"`` entry is
                loaded non-strictly into the whole model.
            device: ``"cpu"`` loads the LLM in float32 with eager attention; any
                other value loads it in bfloat16 with ``flash_attn``.
            audio_encoding_cache_size: Capacity of the inference-time audio
                encoding cache; ``0`` (or less) disables it.

        Raises:
            NotImplementedError: if neither the Q-Former nor max pooling (nor a
                subclass-provided ``htsat`` aligner) is selected.
        """
        super().__init__()

        self.audio_encoding_cache = (
            AudioEncodingCache(capacity=audio_encoding_cache_size)
            if audio_encoding_cache_size > 0
            else None
        )

        self.beats_path = beats_path
        self.beats_cfg = beats_cfg
        self.use_audio_Qformer = use_audio_Qformer
        self.max_pooling = max_pooling
        self.window_level_Qformer = window_level_Qformer
        self.second_per_window = second_per_window
        self.second_stride = second_stride
        self.downsample_factor = downsample_factor
        self.lora = lora
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.prompt_template = prompt_template
        self.flash_attn = flash_attn

        logging.info(f"Llama path: {llama_path}")
        logging.info("Loading Llama Tokenizer")
        self.llama_tokenizer = AutoTokenizer.from_pretrained(
            llama_path, use_fast=False, use_auth_token=auth_token
        )
        self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.llama_tokenizer.padding_side = "right"

        logging.info("Loading Llama Model")
        if device == "cpu":
            self.llama_model = AutoModelForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float32,
                attn_implementation="eager",
                device_map="cpu",
            )
            # An issue with tiny-llama is that pad_token_id was set to -1, but
            # model.save_pretrained checks generation configs and does not allow -1 as
            # pad_token_id
            self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id
        else:
            self.llama_model = AutoModelForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.bfloat16,
                attn_implementation=flash_attn,
            )

        # The tokenizer gained a [PAD] token above, so grow the embedding table.
        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        if self.lora:
            for param in self.llama_model.parameters():
                param.requires_grad = False
        logging.info("Loading LLaMA Done")
        self.llama_embed_tokens = self.llama_model.model.embed_tokens

        if self.lora:
            logging.info("Setting up LoRA for llama model")
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)
            # PEFT wraps the model one level deeper.
            self.llama_embed_tokens = self.llama_model.model.model.embed_tokens
            self.llama_model.print_trainable_parameters()
            logging.info("LoRA Training")

        logging.info("Loading BEATs Model")
        self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))
        if self.beats_path:
            beats_ckpt = universal_torch_load(
                self.beats_path, cache_mode="none", map_location="cpu"
            )
            self.beats.load_state_dict(beats_ckpt["model"])
        self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
        if freeze_beats:
            for param in self.beats.parameters():
                param.requires_grad = False
            self.beats.eval()
            logging.info("freeze BEATs")

        if self.use_audio_Qformer:
            self.audio_Qformer, self.audio_query_tokens = self.init_audio_Qformer(
                num_query_token=num_audio_query_token,
                audio_width=self.beats.cfg.encoder_embed_dim,
            )
            # Only the Q-Former's query/cross-attention stack is used, so drop the
            # text embeddings, feed-forward sub-layers and LM head.
            self.audio_Qformer.bert.embeddings.word_embeddings = None
            self.audio_Qformer.bert.embeddings.position_embeddings = None
            for layer in self.audio_Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.audio_Qformer.cls = None
            if freeze_audio_QFormer:
                for param in self.audio_Qformer.parameters():
                    param.requires_grad = False
                self.audio_Qformer.eval()
                self.audio_query_tokens.requires_grad = False
                logging.info("freeze audio QFormer")

            logging.info("Loading audio LLAMA proj")
            self.audio_llama_proj = nn.Linear(
                self.audio_Qformer.config.hidden_size,
                self.llama_model.config.hidden_size,
            )
            if audio_llama_proj_model:
                logging.info(f"Loading audio LLAMA proj from {audio_llama_proj_model}")
                audio_llama_proj_weight = universal_torch_load(
                    audio_llama_proj_model, cache_mode="use", map_location="cpu"
                )
                self.load_state_dict(audio_llama_proj_weight["model"], strict=False)

            if freeze_audio_llama_proj:
                for param in self.audio_llama_proj.parameters():
                    param.requires_grad = False
                self.audio_llama_proj.eval()
                logging.info("freeze audio LLAMA proj")

        elif self.max_pooling:
            # This class only sets up BEATs. The aves/htsat/aves_large flags are
            # never assigned here, so read them defensively. Accessing them directly
            # raised AttributeError unless a subclass defined them.
            if getattr(self, "aves", False):
                hidden_size = 768
            elif getattr(self, "htsat", False):
                hidden_size = 768
            elif getattr(self, "aves_large", False):
                hidden_size = 1024
            else:
                hidden_size = self.beats.cfg.encoder_embed_dim
            # Single embedding, just project to LLM.
            self.audio_llama_proj = nn.Linear(hidden_size, self.llama_model.config.hidden_size)
        elif getattr(self, "htsat", False):
            # Single embedding, just project to LLM.
            self.audio_llama_proj = nn.Linear(512, self.llama_model.config.hidden_size)
        else:
            # feel free to add other aligners here
            raise NotImplementedError("Have to use audio qformer")

        self.config: ModelConfig = None  # set this in from_config

    @classmethod
    def from_config(cls, config: ModelConfig):
        """Construct a model from a ``ModelConfig`` and optionally load ``config.ckpt``.

        ``beats_cfg`` is forwarded explicitly because ``__init__`` requires it as a
        keyword-only argument. Omitting it made this constructor raise TypeError.
        """
        model = cls(
            llama_path=config.llama_path,
            beats_path=config.beats_path,
            beats_cfg=config.beats_cfg,
            freeze_beats=config.freeze_beats,
            use_audio_Qformer=config.use_audio_Qformer,
            max_pooling=config.max_pooling,
            num_audio_query_token=config.num_audio_query_token,
            freeze_audio_QFormer=config.freeze_audio_QFormer,
            window_level_Qformer=config.window_level_Qformer,
            second_per_window=config.second_per_window,
            second_stride=config.second_stride,
            downsample_factor=config.downsample_factor,
            audio_llama_proj_model=config.audio_llama_proj_model,
            freeze_audio_llama_proj=config.freeze_audio_llama_proj,
            lora=config.lora,
            lora_rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            prompt_template=config.prompt_template,
            max_txt_len=config.max_txt_len,
            end_sym=config.end_sym,
            flash_attn=config.flash_attn,
            device=config.device,
        )
        model.config = config
        ckpt_path = config.ckpt
        if ckpt_path:
            logging.info(f"⏳ Load NatureLM ckpt from: {ckpt_path}")
            ckpt = universal_torch_load(ckpt_path, cache_mode="use", map_location="cpu")
            model.load_state_dict(ckpt["model"], strict=False)
            logging.info("✅ Finished loading from ckpt")

        return model

    def _save_to_local(
        self,
        output_dir: Union[str, os.PathLike],
        use_distributed: bool = False,
        drop_untrained_params: bool = False,
    ) -> None:
        """Write config, full checkpoint, tokenizer/LLM, BEATs and projection to a dir.

        Note: ``save_config_as_yaml`` is given ``self.config``, which is only set
        by ``from_config`` (None otherwise).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        config_path = output_dir / "model_config.yaml"
        save_config_as_yaml(self.config, config_path)

        model_path = output_dir / "model.pt"
        save_model_checkpoint(
            self,
            model_path,
            drop_untrained_params=drop_untrained_params,
            use_distributed=use_distributed,
        )

        # Tokenizer and LLM share one directory.
        tokenizer_path = output_dir / "llama"
        self.llama_tokenizer.save_pretrained(tokenizer_path)
        self.llama_model.save_pretrained(tokenizer_path)

        if self.beats_path:
            beats_path = output_dir / "beats.pt"
            save_model_checkpoint(
                self.beats,
                beats_path,
                drop_untrained_params=drop_untrained_params,
                cfg=self.beats_cfg,
            )

        audio_llama_proj_path = output_dir / "audio_llama_proj.pt"
        save_model_checkpoint(
            self.audio_llama_proj,
            audio_llama_proj_path,
            drop_untrained_params=drop_untrained_params,
        )

    @staticmethod
    def init_audio_Qformer(num_query_token, audio_width, num_hidden_layers=2):
        """Create a BERT-based Q-Former and its learnable query tokens.

        Returns:
            ``(Qformer, query_tokens)``. ``query_tokens`` has shape
            ``(1, num_query_token, hidden_size)`` and is initialised from
            N(0, initializer_range).
        """
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = audio_width
        # insert a cross-attention layer in every block (cross_attention_freq = 1)
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 1
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @property
    def device(self):
        """Device of the first parameter. Does not materialise the full parameter list."""
        return next(self.parameters()).device

    def _encode_auditory_feature(self, audio_embeds, audio_pad_mask):
        """Map encoder features to LLM-space embeddings plus an attention mask.

        Returns:
            ``(audio_embeds, audio_atts)``. In the max-pooling path ``audio_atts``
            is boolean (True = valid frame). In the Q-Former path it is all-ones
            long.
        """
        if self.max_pooling:
            # Downsample along time with 1D max pooling, then project to LLM width.
            audio_embeds = F.max_pool1d(
                audio_embeds.transpose(1, 2),
                kernel_size=self.downsample_factor,
                stride=self.downsample_factor,
            ).transpose(1, 2)
            audio_embeds = self.audio_llama_proj(audio_embeds)

            # A pooled position is valid if any frame in its window was valid.
            audio_atts = ~audio_pad_mask
            audio_atts = F.max_pool1d(
                audio_atts.unsqueeze(1).float(),
                kernel_size=self.downsample_factor,
                stride=self.downsample_factor,
            ).squeeze(1)
            audio_atts = audio_atts > 0

        elif self.use_audio_Qformer:
            audio_embeds = self.ln_audio(audio_embeds)
            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
                audio_embeds.device
            )

            if self.window_level_Qformer:
                B, T, C = audio_embeds.shape
                # Window and stride in frames, from the 1500-frames-per-30-s formula.
                # NOTE(review): confirm this frame rate matches the encoder output.
                kernel = round(1500 * self.second_per_window / 30.0)
                stride = round(1500 * self.second_stride / 30.0)
                kernel = (1, kernel)
                stride = (1, stride)
                # Unfold the time axis into (possibly overlapping) windows.
                audio_embeds_tr = audio_embeds.transpose(1, 2).unsqueeze(2)
                audio_embeds_overlap = F.unfold(
                    audio_embeds_tr,
                    kernel_size=kernel,
                    dilation=1,
                    padding=0,
                    stride=stride,
                )
                _, _, L = audio_embeds_overlap.shape
                audio_embeds_overlap = audio_embeds_overlap.view(B, -1, kernel[1], L)
                # -> (B, num_windows, kernel_size, C)
                audio_embeds_overlap = torch.permute(audio_embeds_overlap, [0, 3, 2, 1])
                # Every window becomes its own Q-Former "sample".
                audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
                audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
                    audio_embeds.device
                )

            query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
            query_output = self.audio_Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=audio_embeds,
                encoder_attention_mask=audio_atts,
                return_dict=True,
            )
            audio_embeds = self.audio_llama_proj(query_output.last_hidden_state)

            if self.window_level_Qformer:
                # Regroup per-window query outputs back into one sequence per sample.
                audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()

            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
                audio_embeds.device
            )

        elif getattr(self, "htsat", False):
            # HTSAT path (only reachable if a subclass enables it).
            audio_embeds = self.ln_audio(audio_embeds)
            audio_embeds = self.audio_llama_proj(audio_embeds).reshape(
                -1, 30, self.llama_model.config.hidden_size
            )
            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
                audio_embeds.device
            )

        else:
            raise NotImplementedError("no audio qformer or max pooling")

        return audio_embeds, audio_atts

    def encode_audio(self, raw_wav, audio_padding_mask=None):
        """Encode raw audio with BEATs + aligner, using the cache in eval mode.

        Returns:
            ``(audio_embeds, audio_atts)`` on the model's device.
        """
        use_cache = self.audio_encoding_cache is not None and not self.training
        if use_cache:
            cached_result = self.audio_encoding_cache.get(raw_wav, audio_padding_mask)
            if cached_result is not None:
                logging.debug("#### Audio encoding cache hit ####")
                audio_embeds, audio_atts = cached_result
                return audio_embeds.to(self.device), audio_atts.to(self.device)

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
            result = self._encode_auditory_feature(
                audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask
            )

        if use_cache:
            self.audio_encoding_cache.put(raw_wav, audio_padding_mask, result)

        return result

    def clear_audio_embed_cache(self):
        """Clear the audio encoding cache (no-op when caching is disabled)."""
        if self.audio_encoding_cache is not None:
            self.audio_encoding_cache.clear()

    def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
        """Merge audio embeddings with embeddings of the tokens in the prompt.

        Args:
            audio_embeds (list): Per-sample tensors (or lists) of audio chunk embeddings.
            audio_atts (list): Per-sample attention masks matching ``audio_embeds``.
                Note: entries of both lists are replaced in place in the
                max-pooling path.
            prompt (list): One string per sample. Each must contain exactly one
                "<AudioHere>" placeholder per audio chunk of that sample. The
                chunk's embeddings are inserted at the placeholder.

        Returns:
            tuple: ``(wrapped_embeds, wrapped_atts)``, right-padded with zeros to
            the longest sample.

        Raises:
            ValueError: if a prompt's placeholder count does not match its number
                of audio chunks.
        """

        def interleave_lists(longer: list, shorter: list) -> list:
            """Interleave two lists where the first list is one element longer.

            Example:
                >>> interleave_lists(['a1', 'a2', 'a3'], ['b1', 'b2'])
                ['a1', 'b1', 'a2', 'b2', 'a3']
            """
            interleaved_list = []
            for i in range(len(shorter)):
                interleaved_list.append(longer[i])
                interleaved_list.append(shorter[i])
            interleaved_list.append(longer[-1])  # last element is from longer
            return interleaved_list

        device = audio_embeds[0].device

        wrapped_embeds_list = []
        wrapped_atts_list = []
        batch_size = len(prompt)
        for i in range(batch_size):
            prompt_parts = prompt[i].split(AUDIO_PLACEHOLDER)
            num_chunks = len(audio_embeds[i])
            # A mismatch would otherwise silently drop text parts or fail later
            # with an opaque IndexError.
            if len(prompt_parts) != num_chunks + 1:
                raise ValueError(
                    f"Prompt {i} contains {len(prompt_parts) - 1} '{AUDIO_PLACEHOLDER}' "
                    f"placeholder(s) but {num_chunks} audio chunk(s) were provided."
                )

            wrapped_embeds = []
            wrapped_atts = []
            for part in prompt_parts:
                tokens = self.llama_tokenizer(
                    part, return_tensors="pt", add_special_tokens=False
                ).to(device)
                part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
                part_atts = tokens.attention_mask.squeeze(0)
                wrapped_embeds.append(part_embeds)
                wrapped_atts.append(part_atts)

            if self.max_pooling:
                # Drop padded positions from each chunk (boolean masks in this path).
                audio_embeds[i] = list(audio_embeds[i].unbind(0))
                audio_atts[i] = list(audio_atts[i].unbind(0))
                for j in range(len(audio_embeds[i])):
                    audio_embeds[i][j] = audio_embeds[i][j][audio_atts[i][j]]
                    audio_atts[i][j] = audio_atts[i][j][audio_atts[i][j]]

            wrapped_embeds = interleave_lists(wrapped_embeds, audio_embeds[i])
            wrapped_atts = interleave_lists(wrapped_atts, audio_atts[i])

            wrapped_embeds_list.append(torch.cat(wrapped_embeds, dim=0))
            wrapped_atts_list.append(torch.cat(wrapped_atts, dim=0))

        wrapped_embeds = pad_sequence(wrapped_embeds_list, batch_first=True)
        wrapped_atts = pad_sequence(wrapped_atts_list, batch_first=True)
        return wrapped_embeds, wrapped_atts

    def forward(self, samples, verbose=True):
        """Compute the LM loss of ``samples["text"]`` given prompt + audio.

        Reads ``samples`` keys ``"prompt"``, ``"raw_wav"``, ``"padding_mask"``,
        ``"audio_chunk_sizes"`` and ``"text"``.

        Returns:
            dict with ``"loss"`` and ``"per_example_loss"``. When ``verbose`` is
            true, token-accuracy counts are also included.
        """
        # An empty template means "use the prompt as-is". Otherwise
        # "".format(p) would discard every prompt.
        if self.prompt_template:
            prompt = [self.prompt_template.format(p) for p in samples["prompt"]]
        else:
            prompt = list(samples["prompt"])

        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)
        audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask)

        # Regroup the flat batch of audio chunks by sample.
        audio_chunk_sizes = samples["audio_chunk_sizes"]
        split_audio_embeds = list(torch.split(audio_embeds, audio_chunk_sizes, dim=0))
        split_audio_atts = list(torch.split(audio_atts, audio_chunk_sizes, dim=0))

        audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompt)

        text = [t + self.end_sym for t in samples["text"]]
        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False,
        ).to(audio_embeds.device)

        to_regress_embeds = self.llama_embed_tokens(to_regress_tokens.input_ids)

        # Padding tokens are ignored by the loss.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        batch_size = audio_embeds.size(0)
        bos_token_id = self.llama_tokenizer.bos_token_id
        bos = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device
        )
        bos_embeds = self.llama_embed_tokens(bos)

        inputs_embeds_list = []
        attention_mask_list = []
        targets_list = []

        for i in range(batch_size):
            # Strip padding so each sample is [BOS, audio+prompt, text] contiguously.
            audio_valid = audio_atts[i].bool()
            audio_embed = audio_embeds[i][audio_valid]
            audio_att = audio_atts[i][audio_valid]

            text_valid = to_regress_tokens.attention_mask[i].bool()
            text_embed = to_regress_embeds[i][text_valid]
            text_att = to_regress_tokens.attention_mask[i][text_valid]
            target = targets[i][text_valid]

            input_embeds = torch.cat([bos_embeds[i], audio_embed, text_embed], dim=0)

            att_mask = torch.cat(
                [
                    torch.ones(1, device=audio_embeds.device, dtype=audio_att.dtype),
                    audio_att,
                    text_att,
                ],
                dim=0,
            )

            # BOS and audio/prompt positions are ignored (-100); text is supervised.
            ignore_targets = torch.full(
                (1 + audio_embed.size(0),),
                -100,
                device=audio_embeds.device,
                dtype=targets.dtype,
            )
            sample_targets = torch.cat([ignore_targets, target], dim=0)

            inputs_embeds_list.append(input_embeds)
            attention_mask_list.append(att_mask)
            targets_list.append(sample_targets)

        inputs_embeds_padded = pad_sequence(inputs_embeds_list, batch_first=True)
        attention_mask_padded = pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
        targets_padded = pad_sequence(targets_list, batch_first=True, padding_value=-100)

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds_padded,
                attention_mask=attention_mask_padded,
                return_dict=True,
                labels=targets_padded,
            )
            loss = outputs.loss  # Original batch loss

        # Per-example loss. Vocab size comes from the logits themselves, and the
        # cross-entropy runs in float32 so low-precision logits do not skew it.
        logits = outputs.logits
        nvocab = logits.size(-1)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets_padded[..., 1:].contiguous()

        loss_fct_per_example = CrossEntropyLoss(reduction="none")
        loss_per_token = loss_fct_per_example(
            shift_logits.view(-1, nvocab).float(),  # [batch_size * (seq_len-1), vocab]
            shift_labels.view(-1),  # [batch_size * (seq_len-1)]
        )
        loss_per_token = loss_per_token.view(shift_labels.size())  # [batch_size, seq_len-1]

        mask = shift_labels != -100  # [batch_size, seq_len-1]
        loss_per_token = loss_per_token * mask.float()
        loss_per_example = loss_per_token.sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        if verbose:
            predicted_tokens = shift_logits.argmax(dim=-1)  # [batch_size, seq_len-1]
            correct_per_sample = (
                ((predicted_tokens == shift_labels) & mask).sum(dim=1).float()
            )  # [batch_size]
            total_tokens_per_sample = mask.sum(dim=1).float()  # [batch_size]

            correct = correct_per_sample.sum()
            total = total_tokens_per_sample.sum()

            return {
                "loss": loss,
                "correct": correct,
                "total": total,
                "per_example_loss": loss_per_example,
                "correct_per_sample": correct_per_sample,
                "total_per_sample": total_tokens_per_sample,
            }

        return {"loss": loss, "per_example_loss": loss_per_example}

    def model_merging_scaling(self, merging_alpha, adapter_name="default"):
        """
        Performs model merging with the base model by adjusting the scaling of the
        LoRA adapters as described in "Model Merging Improves Zero-Shot Generalization
        in Bioacoustic Foundation Models" (https://arxiv.org/abs/2511.05171).

        The best value for alpha is task- and dataset-specific, but the paper found
        alpha values between 0.4 and 0.6 to perform generally well.

        Each call scales relative to the adapter's ORIGINAL scaling, so calls do
        not compound, and ``merging_alpha=1.0`` restores the original. Originals
        are remembered per ``adapter_name``. The new scaling persists on the model
        (it also affects later ``forward`` calls).

        Args:
            merging_alpha: The merging_alpha used for interpolation.
            adapter_name (str): The name of the adapter to rescale when merging.
        """
        if not hasattr(self, "_original_lora_scaling"):
            self._original_lora_scaling = {}
        if adapter_name not in self._original_lora_scaling:
            originals = {}
            for name, module in self.llama_model.named_modules():
                if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
                    originals[name] = module.scaling[adapter_name]
            self._original_lora_scaling[adapter_name] = originals

        originals = self._original_lora_scaling[adapter_name]
        for name, module in self.llama_model.named_modules():
            if name in originals:
                module.scaling[adapter_name] = merging_alpha * originals[name]

    @torch.inference_mode()
    def generate(self, samples, generate_cfg, prompts) -> list[str]:
        """Generate text for each prompt, conditioned on the matching audio chunks.

        ``prompts`` are used verbatim (``prompt_template`` is not applied here).
        ``generate_cfg.merging_alpha`` (default 1.0) rescales the LoRA adapters
        via ``model_merging_scaling``.
        """
        merging_alpha = getattr(generate_cfg, "merging_alpha", 1.0)
        # Also rescale when a previous call changed the scaling, so that returning
        # to alpha=1.0 actually restores the unmerged adapters.
        if merging_alpha != 1.0 or getattr(self, "_original_lora_scaling", None):
            self.model_merging_scaling(merging_alpha)

        batch_size = len(prompts)

        raw_wav = samples["raw_wav"]
        audio_padding_mask = samples.get("padding_mask", None)

        audio_embeds, audio_atts = self.encode_audio(raw_wav, audio_padding_mask=audio_padding_mask)
        split_audio_embeds = list(torch.split(audio_embeds, samples["audio_chunk_sizes"], dim=0))
        split_audio_atts = list(torch.split(audio_atts, samples["audio_chunk_sizes"], dim=0))
        audio_embeds, audio_atts = self.prompt_wrap(split_audio_embeds, split_audio_atts, prompts)

        bos = (
            torch.ones(
                [batch_size, 1],
                dtype=torch.int32,
                device=audio_embeds.device,
            )
            * self.llama_tokenizer.bos_token_id
        )
        bos_embeds = self.llama_embed_tokens(bos)
        # The BOS position reuses the first column of the wrapped attention mask.
        atts_bos = audio_atts[:, :1]

        embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
        attns = torch.cat([atts_bos, audio_atts], dim=1)

        # NOTE(review): stop token id 2 is hard-coded; confirm it matches the
        # tokenizer in use (presumably EOS of a Llama-2-style vocabulary).
        stop_words_ids = [torch.tensor([2]).to(audio_embeds.device)]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
            outputs = self.llama_model.generate(
                inputs_embeds=embeds.bfloat16(),
                max_new_tokens=generate_cfg.max_new_tokens,
                stopping_criteria=stopping_criteria,
                num_beams=generate_cfg.num_beams,
                do_sample=generate_cfg.do_sample,
                min_length=generate_cfg.min_length,
                temperature=generate_cfg.temperature,
                repetition_penalty=generate_cfg.repetition_penalty,
                length_penalty=generate_cfg.length_penalty,
                attention_mask=attns.bfloat16(),
            )

        text = self.llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return text