# mamba/model.py
import torch

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from util import Config, GetDevice, GetNumParams


class Model:
    def __init__(self, config: Config):
        # Copy every config attribute (params, tokenizer, ...) onto the instance
        self.__dict__ = dict(config.__dict__)
        self.model = MambaLMHeadModel(MambaConfig(**self.params.__dict__)).to(GetDevice())
        self.log()

    def log(self):
        model_size, rounded_model_size = GetNumParams(self.model)
        print(f"Model has {model_size} ({rounded_model_size}) parameters")
        print(f"Model's vocabulary size is {self.params.vocab_size}")

    def parameters(self): return self.model.parameters()

    # Toggle the underlying module between training and evaluation mode
    def unfreeze(self): self.model.train()
    def freeze(self): self.model.eval()

    def compute_loss(self, input_ids, labels=None, criterion=None):
        lm_logits = self.model(input_ids).logits

        # Default to next-token prediction: the inputs serve as their own labels
        labels = (labels if labels is not None else input_ids).to(GetDevice())

        # Shift so that position i predicts token i + 1
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss_fct = criterion or torch.nn.CrossEntropyLoss()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    def generate_text(self, seed_text, num_predict):
        with torch.no_grad():
            encoded_ids = self.tokenizer.encode(seed_text)
            input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())

            # max_length is measured in tokens, so count the encoded prompt, not characters
            max_len = len(encoded_ids) + num_predict

            output = self.model.generate(input_ids, max_length=max_len)
            output_ids = output[0].tolist()
            text = self.tokenizer.decode(output_ids)
        return text

    def save_pretrained(self, path='./'):
        self.model.save_pretrained(path)
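

# Usage sketch: `util.Config`'s constructor is not defined in this file, so the
# example below builds a stand-in object carrying just the attributes Model reads,
# `params` (keyword arguments for mamba_ssm's MambaConfig) and `tokenizer`
# (anything exposing encode()/decode()). The byte-level tokenizer, model sizes,
# and learning rate are illustrative assumptions, not values from this repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ByteTokenizer:
        # Toy stand-in tokenizer: one token per UTF-8 byte
        vocab_size = 256
        def encode(self, text): return list(text.encode("utf-8"))
        def decode(self, ids): return bytes(ids).decode("utf-8", errors="replace")

    tokenizer = ByteTokenizer()
    config = SimpleNamespace(
        params=SimpleNamespace(d_model=256, n_layer=4, vocab_size=tokenizer.vocab_size),
        tokenizer=tokenizer,
    )
    model = Model(config)

    # One next-token-prediction step on a random (batch, seq_len) batch of token ids
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    model.unfreeze()
    batch = torch.randint(0, tokenizer.vocab_size, (2, 64)).to(GetDevice())
    loss = model.compute_loss(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Greedy generation from a text prompt, then save the weights
    model.freeze()
    print(model.generate_text("Once upon a time", num_predict=32))
    model.save_pretrained("./checkpoint")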