"""Entry point for a config-driven PyTorch training run.

Pipeline: parse CLI args -> load config -> build dataset -> train tokenizer
on the corpus text -> encode -> build model -> train -> save model and
tokenizer artifacts -> run the export step.
"""

from argparse import ArgumentParser

from util import ConfigParser
from logger import Wandb
from trainer import Trainer
from dataset import Dataset
from tokenizer import Tokenizer
from model import Model
from export import ExportAll

parser = ArgumentParser(
    prog='Trainer implementation, using Pytorch',
    description='',
)


def main() -> None:
    """Run the full train / save / export pipeline.

    Reads the config file path from the ``-p``/``--config_path`` CLI flag,
    wires all components together through the parsed config object, trains,
    and persists the resulting artifacts.
    """
    parser.add_argument('-p', '--config_path')
    args = parser.parse_args()
    config = ConfigParser(args.config_path).config

    # Tokenizer is trained directly on the raw dataset text before encoding.
    dataset = Dataset(config.dataset)
    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)

    # Inject tokenizer-derived values into the model config before the model
    # is constructed, since vocab_size determines the embedding dimensions.
    config.model.tokenizer = tokenizer
    config.model.params.vocab_size = tokenizer.vocab_size

    batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches

    # The trainer receives the model and the wandb logger via the config
    # object rather than as constructor arguments.
    model = Model(config.model)
    wandb = Wandb(config.wandb)
    config.trainer.model = model
    config.trainer.wandb = wandb

    trainer = Trainer(config.trainer)
    trainer.train(batches)

    # Persist artifacts, then run the export step (side effects only —
    # presumably ExportAll picks up the saved files; verify against export.py).
    model.save_pretrained()
    tokenizer.to_file('tokenizer.bin')
    ExportAll()


if __name__ == '__main__':
    main()