import sys

sys.path.append(".")

from src.config import model as conf
from src.model import Wav2Vec2PretrainingModule
from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


if __name__ == "__main__":
    # Build the wav2vec 2.0 pretraining module from the config.
    model = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)

    # Load the VLSP2020 data via the WebDataset converter and wrap it
    # in the pretraining DataModule.
    dts = WebDatasetConverter(conf.dataset.path).get_dataset()
    dtm = VLSP2020ForPretrainingDataModule(dts, **conf.dataset)

    # Checkpoint on validation loss and clip gradients during pretraining.
    trainer = Trainer(
        callbacks=[
            ModelCheckpoint(
                monitor="val/loss",
                dirpath=conf["checkpoint_dir"],
            )
        ],
        gradient_clip_val=1.0,
        accelerator="gpu",
    )
    trainer.fit(model, dtm)