from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
import torch

from util import Config, GetDevice, GetNumParams


class Model:
    def __init__(self, config: Config):
        # Adopt the config's attributes (e.g. params, tokenizer) directly onto this wrapper.
        self.__dict__ = dict(config.__dict__)
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
        self.log()

    def log(self):
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")
        print(f"Model's vocabulary size is {self.params.vocab_size}")

    def parameters(self): return self.model.parameters()

    def unfreeze(self): self.model.train()

    def freeze(self): self.model.eval()

    def compute_loss(self, input_ids, labels=None, criterion=None):
        lm_logits = self.model(input_ids).logits
        # Next-token prediction: if no labels are given, the inputs shifted by one serve as targets.
        labels = (labels if labels is not None else input_ids).to(GetDevice())
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

    def generate_text(self, seed_text, num_predict):
        with torch.no_grad():
            encoded_ids = self.tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
            # max_length is measured in tokens, so count the prompt in tokens rather than characters.
            max_len = num_predict + len(encoded_ids)
            output = self.model.generate(input_ids, max_length=max_len)
            token_ids = output[0].tolist()  # generated token ids, not logits
            text = self.tokenizer.decode(token_ids)
        return text

    def save_pretrained(self, path='./'):
        self.model.save_pretrained(path)
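

# Minimal usage sketch (an assumption-laden example, not part of the class above):
# it assumes Config exposes a `params` namespace holding MambaConfig fields
# (d_model, n_layer, vocab_size, ...) and a `tokenizer` with encode/decode.
# SimpleNamespace and the GPT-NeoX tokenizer are stand-ins; swap in the real
# Config from util.py and whatever tokenizer the project actually uses.
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    config = SimpleNamespace(
        params=SimpleNamespace(d_model=256, n_layer=4, vocab_size=50277),
        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b"),
    )
    model = Model(config)  # logs parameter count and vocabulary size

    # One next-token-prediction loss on a random dummy batch of shape (batch, seq_len).
    dummy_ids = torch.randint(0, config.params.vocab_size, (2, 64)).to(GetDevice())
    print(f"loss: {model.compute_loss(dummy_ids).item():.4f}")

    # Greedy generation from a short text prompt.
    model.freeze()
    print(model.generate_text("Once upon a time", num_predict=20))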