import torch

from util import Config


class Trainer:
    def __init__(self, config: Config):
        # Copy every config attribute onto the trainer (model, learning_rate,
        # num_epochs, num_batches, wandb, inference, ...).
        self.__dict__ = dict(config.__dict__)

    def log(self, loss: float):
        print(f"Epoch: {self.epoch} / {self.num_epochs}\t\t"
              f"Batch: {self.batch} / {self.num_batches}\t\t"
              f"Loss: {loss:.4f}")
        # self.wandb is assumed to be a logging callable supplied via the
        # config (e.g. wandb.log).
        args = {'epoch': self.epoch, 'batch': self.batch, 'loss': loss}
        self.wandb(args)
        # Periodically sample from the model; a frequency of 0 disables this.
        if self.inference.frequency != 0 and self.batch % self.inference.frequency == 0:
            print(self.model.generate_text(self.inference.seed_text, self.inference.n_predict))

    def train(self, batches):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.model.unfreeze()
        # Loop indices are stored on the instance so log() can read them.
        for self.epoch in range(self.num_epochs):
            for self.batch in range(self.num_batches):
                ids = batches[self.batch]
                loss = self.model.compute_loss(ids)
                # Standard PyTorch update: reset gradients, backprop, step.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.log(loss.item())
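

# Minimal usage sketch (hypothetical; how Config is constructed depends on
# util.Config, which is not shown here). Trainer expects the config to carry
# at least: model (with parameters(), unfreeze(), compute_loss(),
# generate_text()), learning_rate, num_epochs, num_batches, wandb (a logging
# callable), and inference (with .frequency, .seed_text, .n_predict).
#
#   config = Config(...)        # populate with the attributes listed above
#   trainer = Trainer(config)
#   trainer.train(batches)      # batches: indexable collection of token-id tensors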