from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
import torch

from util import Config, GetDevice, GetNumParams


class Model:
    def __init__(self, config: Config):
        # Expose the config's attributes (e.g. params, tokenizer) directly on this wrapper.
        self.__dict__ = dict(config.__dict__)
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
        self.log()

    def log(self):
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")
        print(f"Model's vocabulary size is {self.params.vocab_size}")

    def parameters(self):
        return self.model.parameters()

    def unfreeze(self):
        self.model.train()

    def freeze(self):
        self.model.eval()

    def compute_loss(self, input_ids, labels=None, criterion=None):
        lm_logits = self.model(input_ids).logits
        # For causal language modeling the labels default to the input ids themselves,
        # shifted by one position below.
        labels = (labels if labels is not None else input_ids).to(GetDevice())
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

    def generate_text(self, seed_text, num_predict):
        with torch.no_grad():
            encoded_ids = self.tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
            # Budget max_length in tokens: prompt length plus the number of new tokens.
            max_len = num_predict + len(encoded_ids)
            output = self.model.generate(input_ids, max_length=max_len)
            # generate() returns token ids; decode them back to text.
            token_ids = output[0].tolist()
            text = self.tokenizer.decode(token_ids)
        return text

    def save_pretrained(self, path='./'):
        self.model.save_pretrained(path)
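

# Minimal usage sketch. Assumptions (not taken from this file): util.Config can be
# constructed directly and exposes a `params` namespace of MambaConfig fields plus a
# `tokenizer`; the training loop below is illustrative only.
if __name__ == "__main__":
    config = Config()  # hypothetical: however your project builds its Config
    model = Model(config)

    # One illustrative optimization step on a random batch of token ids.
    model.unfreeze()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batch = torch.randint(0, config.params.vocab_size, (2, 64)).to(GetDevice())
    loss = model.compute_loss(batch)  # next-token cross-entropy
    loss.backward()
    optimizer.step()

    # Generation in eval mode.
    model.freeze()
    print(model.generate_text("Once upon a time", num_predict=50))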