mamba / trainer.py
flpelerin's picture
Update file trainer.py
0ebab57
import torch
from util import Config
class Trainer:
    """Simple epoch/batch training loop configured from a ``Config`` object.

    The trainer adopts the config's instance attributes as its own. The
    methods below read the following of them:

    * ``model`` -- must provide ``parameters()``, ``unfreeze()``,
      ``compute_loss(ids)`` (returning a scalar tensor) and
      ``generate_text(seed_text, n_predict)``.
    * ``learning_rate`` -- learning rate for the Adam optimizer.
    * ``num_epochs`` / ``num_batches`` -- loop bounds used by :meth:`train`.
    * ``inference`` -- object with ``frequency`` (0 disables sampling),
      ``seed_text`` and ``n_predict``.
    * ``wandb`` -- callable invoked with one metrics dict per step
      (presumably a Weights & Biases logging hook -- confirm at the caller).
    """

    def __init__(self, config: "Config"):
        """Adopt a shallow copy of *config*'s instance attributes.

        The attribute dict is copied rather than shared, so attributes set
        during training (``optimizer``, ``epoch``, ``batch``) are not written
        back onto *config*. The attribute values themselves (e.g. ``model``)
        are still the same objects.
        """
        self.__dict__ = dict(config.__dict__)

    def log(self, loss: float):
        """Report one training step and periodically print a text sample.

        Prints the progress line, sends ``{'epoch', 'batch', 'loss'}`` to
        ``self.wandb``, and, when ``self.inference.frequency`` is non-zero,
        prints ``model.generate_text(...)`` every ``frequency`` batches
        (including batch 0 of each epoch).

        Args:
            loss: scalar loss of the step that was just taken.
        """
        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\tBatch: {self.batch} / {self.num_batches}\t\tLoss: {round(loss, 4)}")
        self.wandb({'epoch': self.epoch, 'batch': self.batch, 'loss': loss})

        frequency = self.inference.frequency
        # A frequency of 0 disables sampling; checking it first also avoids
        # a modulo-by-zero.
        if frequency != 0 and self.batch % frequency == 0:
            print(self.model.generate_text(self.inference.seed_text, self.inference.n_predict))

    def train(self, batches):
        """Optimize ``self.model`` with Adam for ``self.num_epochs`` epochs.

        Each epoch visits ``batches[0]`` .. ``batches[num_batches - 1]`` in
        order. It takes one optimizer step per batch and calls :meth:`log`
        after every step. The loop indices are stored on ``self.epoch`` and
        ``self.batch`` so that :meth:`log` can report them.

        Args:
            batches: indexable collection of model inputs; ``batches[i]`` is
                passed to ``self.model.compute_loss``. It must hold at least
                ``self.num_batches`` items, otherwise indexing fails
                (``IndexError`` for a list).
        """
        # Unfreeze before building the optimizer so that it is constructed
        # over the model's final set of parameters. This matters if
        # unfreeze() adds or replaces parameters rather than only toggling
        # requires_grad.
        self.model.unfreeze()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        for self.epoch in range(self.num_epochs):
            for self.batch in range(self.num_batches):
                loss = self.model.compute_loss(batches[self.batch])
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.log(loss.item())