import torch
import os
import logging
import torch.nn.functional as F
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)

from slam_llm.utils.train_utils import print_model_size

from typing import List, Optional
from slam_llm.utils.metric import compute_accuracy

from transformers import T5ForConditionalGeneration
from tqdm import tqdm

from utils.tts_adapter_utils import setup_tts_adapter
from utils.codec_utils import setup_codec
from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only
from utils.snac_utils import layershift

logger = logging.getLogger(__name__)

def model_factory(train_config, model_config, ckpt_path, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    if train_config.task_type == "s2s" or train_config.task_type == "asr":
        encoder = setup_encoder(train_config, model_config, **kwargs)
    elif train_config.task_type == "tts":
        encoder = None
    else:
        raise NotImplementedError

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    if encoder is not None:
        encoder_projector = setup_encoder_projector(
            train_config, model_config, **kwargs
        )
    else:
        encoder_projector = None

    codec_decoder = None
    if model_config.codec_decode:
        codec_decoder = setup_codec(train_config, model_config, **kwargs)

    tts_adapter = None
    if model_config.tts_adapter:
        adapter_config = model_config.tts_adapter_config
        tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs)

    model = slam_model_s2s(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    )

    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    if train_config.train_audio_embed_only:
        partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize)

    if train_config.train_embed_only:
        train_embedding_layer_only(model)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer

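# Usage sketch (illustrative only, not one of the original entry points): model_factory
# is normally driven by the SLAM-LLM training / inference scripts with Hydra-style
# config objects exposing the fields referenced above (task_type, codec_decode,
# tts_adapter, vocab_config, ...). The call below assumes such configs are already built.
#
#   model, tokenizer = model_factory(train_config, model_config, ckpt_path=None)
#   model.to("cuda").eval()
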
class slam_model_s2s(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

        # resize llm embedding layer
        self.original_vocabsize = self.llm.lm_head.weight.size(0)
        if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize:
            self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize)

            if int(os.environ.get("RANK", "0")) == 0:
                logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize))

        self.codec_decoder = codec_decoder
        self.tts_adapter = tts_adapter
        self.code_layer = self.model_config.vocab_config.code_layer
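
    # Combined vocabulary layout assumed by the slicing in forward() and generate() below
    # (read off from the code itself, not from a separate spec):
    #   [0, padded_text_vocabsize)                                        -> text tokens
    #   [padded_text_vocabsize + i * padded_audio_vocabsize,
    #    padded_text_vocabsize + (i + 1) * padded_audio_vocabsize)        -> audio codec layer i
    # for i in range(code_layer). The embedding table is resized to total_vocabsize above
    # so that a single set of LLM logits covers all code_layer + 1 parallel streams.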
    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper
        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)

        modality_mask = kwargs.get("modality_mask", None)

        encoder_outs = None
        if audio_mel is not None or audio is not None:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            if self.model_config.encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))  # bs*seq*dim
            if self.model_config.encoder_name == "wavlm":
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)  # (FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
            if self.model_config.encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1 - audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
                if self.model_config.encoder_type == "finetune":
                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                    encoder_outs = encoder_outs.transpose(0, 1)
            if self.encoder is None:
                encoder_outs = audio_mel if audio_mel is not None else audio

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
            if self.model_config.encoder_projector == "cov1d-linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        if input_ids is not None:
            input_ids[input_ids == -1] = 0  # [btz, 8, seq_length]

            if isinstance(self.llm, T5ForConditionalGeneration):
                inputs_embeds = self.llm.shared(input_ids)
            else:
                if hasattr(self.llm.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.embed_tokens(input_ids)  # [btz, 8, seq_length, emb_dim]
                elif hasattr(self.llm.model.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
                else:
                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if modality_mask is not None and encoder_outs is not None:
            modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1)  # [btz, 8, seq_length]
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2)
            modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                for j in range(self.code_layer):
                    start_idx = modality_mask_start_indices[i, j].item()
                    length = modality_lengths[i][j]
                    encoder_outs_pad[i, j, start_idx:start_idx + length] = encoder_outs[i, :length]

            inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None])

        inputs_embeds = torch.mean(inputs_embeds, dim=1)  # [btz, seq_length, emb_dim], average over the 8 layers

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask
        text_labels = labels[:, self.code_layer] if labels is not None else None
        audio_labels = labels[:, :self.code_layer] if labels is not None else None
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels)  # here we use the text token layer as the target label

        # parallel generation
        # TODO: add tts adapter forward
        x_ori = model_outputs.logits
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        xt = x_ori[..., :text_vocab_size]
        xa = []
        for i in range(self.code_layer):
            xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])

        loss_recorder = []
        total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
        model_outputs.loss = total_loss

        text_acc = -1
        audio_acc = [-1 for _ in range(self.code_layer)]
        if self.metric:
            with torch.no_grad():
                preds = torch.argmax(xt, -1)
                text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100)

                preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)]
                audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)]

        # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder}

        return model_outputs, text_acc, audio_acc, loss_recorder
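
    # Shape sketch for the parallel split above (illustrative numbers only): with
    # code_layer = 3, padded_text_vocabsize = T and padded_audio_vocabsize = A, the
    # logits x_ori have vocab dimension T + 3 * A; xt takes [..., :T] and xa[i] takes
    # [..., T + i * A : T + (i + 1) * A]. All streams share the same hidden states and
    # are supervised jointly by compute_parallel_loss below.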
    def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
        """
        Compute the parallel loss for text and audio layers.
        """
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        layer_loss = [0 for _ in range(self.code_layer + 1)]

        if text_labels is not None:
            # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100)
            text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100)
            layer_loss[self.code_layer] = text_loss
        else:
            text_loss = 0

        total_audio_loss = 0
        single_audio_loss = 0
        for i in range(self.code_layer):
            if audio_labels[:, i] is not None:
                # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100)
                single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100)
                layer_loss[i] = single_audio_loss
                total_audio_loss += single_audio_loss

        total_loss = (text_loss + total_audio_loss) / (self.code_layer + 1)
        return total_loss, layer_loss
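
    # Loss bookkeeping sketch (symbols are illustrative, not from a real run): with
    # code_layer = 3 and per-stream cross-entropies L_a0, L_a1, L_a2, L_text, the method
    # returns total_loss = (L_text + L_a0 + L_a1 + L_a2) / 4 and
    # layer_loss = [L_a0, L_a1, L_a2, L_text] (text stream stored last). Every stream is
    # shifted by one position, i.e. position t predicts token t + 1.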
    def generate(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[List[torch.FloatTensor]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 labels: Optional[torch.LongTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **kwargs,
                 ):
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        generated_ids = [[] for _ in range(self.code_layer + 1)]
        current_input_text = None
        current_audio_tokens = [None for _ in range(self.code_layer)]
        # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
        past_key_values = None

        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize

        max_new_tokens = kwargs.get("max_new_tokens", 360)
        repetition_penalty = kwargs.get("repetition_penalty", 1.0)
        decode_text_only = kwargs.get("decode_text_only", False)

        pad_t = self.model_config.vocab_config.pad_t
        pad_a = self.model_config.vocab_config.pad_a
        eot = self.model_config.vocab_config.eot
        eoa = self.model_config.vocab_config.eoa

        text_end = False    # Track whether text generation has ended
        audio_end = False   # Track whether audio generation has ended

        # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search
        for step in tqdm(range(max_new_tokens), desc="Generating"):
            if current_input_text is not None:
                audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1)
                combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1)
                inputs_embeds = self.llm.model.embed_tokens(combined_input_ids)
                inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1)

            outputs = self.llm(
                inputs_embeds=inputs_embeds,        # [btz, seq_len / 1, emb_dim]
                attention_mask=attention_mask,      # single sample, no need for attention mask
                past_key_values=past_key_values,
                # position_ids=input_pos,
                use_cache=True,
            )

            logits = outputs.logits
            past_key_values = outputs.past_key_values   # Update past_key_values for the next step

            # Split logits into text and audio layers based on vocab size
            xt_logits = logits[..., :text_vocab_size]
            xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)]

            # Apply repetition penalty to the logits
            if repetition_penalty != 1.0:
                xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty)
                for i in range(self.code_layer):
                    xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty)

            if not text_end:
                next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs)
            else:
                next_token_text = torch.tensor([pad_t], device=input_ids.device)

            next_tokens_audio = []
            for i in range(self.code_layer):
                if not audio_end and not decode_text_only:
                    next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs)
                else:
                    next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device)
                next_tokens_audio.append(next_token_audio)

            if next_tokens_audio[-1] == eoa or decode_text_only:
                audio_end = True
            if next_token_text == eot:
                text_end = True

            # Update input_ids for the next step
            current_input_text = next_token_text
            for i in range(self.code_layer):
                current_audio_tokens[i] = next_tokens_audio[i]

            # if input_pos.size(-1) > 1:
            #     input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0)
            # else:
            #     input_pos = input_pos.add_(1)
            attention_mask = torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1)

            if audio_end and text_end:
                break

            # Append generated tokens to the list
            for i in range(self.code_layer):
                generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0])   # Audio layers
            generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0])  # Text layer

        # Concatenate the generated tokens to form the complete sequence
        text_tokens = generated_ids[-1]
        generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens
        generated_ids = [torch.tensor(layer) for layer in generated_ids]
        return generated_ids
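
    # Output format of generate(): a list of code_layer + 1 1-D tensors.
    # generated_ids[0 .. code_layer - 1] hold the audio codec streams in their local
    # (un-layershifted) index space, padded with pad_a once audio generation has ended;
    # generated_ids[code_layer] holds the text stream, truncated at the first eot.
    # layershift() is only applied when tokens are fed back through the embedding table.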
    def sample_next_token(self, logits, **kwargs):
        """
        Generate the next token based on the model output logits.
        Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling.
        """
        do_sample = kwargs.get("do_sample", False)
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 50)
        top_p = kwargs.get("top_p", 1.0)
        num_samples = kwargs.get("num_samples", 1)

        # Adjust logits with temperature
        logits = logits.squeeze(0)
        logits = logits / temperature

        # Top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Make sure top_k is within the vocab size
            values, indices = torch.topk(logits, top_k)
            logits[logits < values[..., [-1]]] = -float('Inf')  # Filter tokens not in top_k

        # Top-p filtering (nucleus sampling)
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = -float('Inf')

        if do_sample:
            # Perform sampling
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
        else:
            # Greedy decoding (argmax)
            return torch.argmax(logits, dim=-1, keepdim=True)
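
    # Decoding knobs arrive through **kwargs from generate(): do_sample switches between
    # multinomial sampling and plain argmax, temperature rescales the logits, and the
    # top-k / top-p filters are applied before either path. With the defaults here
    # (do_sample=False, top_k=50, top_p=1.0) the top-k filter is computed but the final
    # pick is still the argmax, so decoding stays greedy.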
    def repetition_penalty(self, logits, generated_ids, repetition_penalty):
        """
        Apply repetition penalty to the logits.
        """
        for token_id in set(generated_ids):
            if logits[0, -1, token_id] < 0:
                logits[0, -1, token_id] *= repetition_penalty
            else:
                logits[0, -1, token_id] /= repetition_penalty
        return logits
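

# ---------------------------------------------------------------------------
# Standalone sanity check (illustrative sketch, safe to delete). It re-runs the
# same top-k / top-p filtering recipe used in sample_next_token on random logits
# so the filtering math can be exercised without building the full model. The
# helper below is a local re-implementation for demonstration purposes, not an
# API of slam_model_s2s.
# ---------------------------------------------------------------------------
def _demo_top_k_top_p(vocab_size: int = 32, top_k: int = 5, top_p: float = 0.9) -> torch.Tensor:
    logits = torch.randn(vocab_size)

    # Top-k: mask out everything below the k-th largest logit.
    values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < values[..., [-1]]] = -float("Inf")

    # Top-p: drop the tail of the sorted distribution once cumulative mass exceeds top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    to_remove = cumulative_probs > top_p
    to_remove[..., 1:] = to_remove[..., :-1].clone()
    to_remove[..., 0] = False
    logits[sorted_indices[to_remove]] = -float("Inf")

    # Sample one token id from the filtered distribution.
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)


if __name__ == "__main__":
    print(_demo_top_k_top_p())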