# mamba / trainer.cli.py
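"""Command-line training entry point.

Reads a config file (passed via -p/--config_path), builds the dataset,
tokenizer, model and Wandb logger, runs training, then saves the model
and tokenizer and runs the export step.
"""
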
from argparse import ArgumentParser
from util import ConfigParser
from logger import Wandb
from trainer import Trainer
from dataset import Dataset
from tokenizer import Tokenizer
from model import Model
from export import ExportAll

parser = ArgumentParser(
    description='Trainer implementation, using PyTorch'
)

if __name__ == '__main__':
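    # Parse CLI arguments and load the training configuration.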
    parser.add_argument('-p', '--config_path', required=True,
                        help='path to the training config file')
    args = parser.parse_args()
    config = ConfigParser(args.config_path).config
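
    # Load the raw text dataset, train the tokenizer on it,
    # and encode the text into token ids.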
    dataset = Dataset(config.dataset)

    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)
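
    # Propagate tokenizer info into the model config and batch the token ids.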
    config.model.tokenizer = tokenizer
    config.model.params.vocab_size = tokenizer.vocab_size

    batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches
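
    # Build the model and the Wandb logger, wire them into the trainer
    # config, and run training.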
    model = Model(config.model)
    wandb = Wandb(config.wandb)

    config.trainer.model = model
    config.trainer.wandb = wandb

    trainer = Trainer(config.trainer)
    trainer.train(batches)
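
    # Persist the trained model and tokenizer, then run the export step.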
    model.save_pretrained()
    tokenizer.to_file('tokenizer.bin')
    ExportAll()