import sys

# Make the repository root importable when this script is run from its own
# subdirectory; must happen before the `src`/`finetuning` imports below.
sys.path.append("..")

import os
import string
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
)

from finetuning.wav2vec2 import SpeechRecognizer
from src.datamodule import VLSP2020TarDataset
from src.datamodule.vlsp2020 import get_dataloader
# Translation table built once at import time; `remove_punctuation` is used as
# a per-sample `target_transform`, so rebuilding it on every call was wasted work.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def remove_punctuation(text: str) -> str:
    """Return *text* with all ASCII punctuation removed and lowercased.

    Note: only `string.punctuation` (ASCII) is stripped; Unicode punctuation
    passes through unchanged, matching the original behavior.
    """
    return text.translate(_PUNCTUATION_TABLE).lower()
def prepare_dataloader(data_dir, batch_size, num_workers):
    """Build the train and validation dataloaders for VLSP2020.

    Both loaders read a tar archive under *data_dir*, return transcripts,
    and normalize targets by stripping punctuation and lowercasing.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """

    def _build(tar_name):
        # One dataset + loader per archive; both splits share every setting
        # except the tar file name.
        dataset = VLSP2020TarDataset(os.path.join(data_dir, tar_name)).load()
        return get_dataloader(
            dataset,
            return_transcript=True,
            target_transform=remove_punctuation,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    return _build("vlsp2020_train_set.tar"), _build("vlsp2020_val_set.tar")
def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
    """Assemble a ``SpeechRecognizer`` from a pretrained Vietnamese wav2vec2.

    Loads the backbone, CTC tokenizer, and feature extractor from the same
    Hugging Face checkpoint and wires them together with the optimizer and
    tri-state LR scheduler configurations.
    """
    checkpoint = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
    backbone = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
    ctc_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint)
    extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)
    return SpeechRecognizer(
        backbone,
        ctc_tokenizer,
        extractor,
        adam_config,
        tristate_scheduler_config,
    )
def _str2bool(value: str) -> bool:
    """Strictly parse a command-line boolean string.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because
    any non-empty string is truthy. Raising ValueError here makes argparse
    report an invalid argument instead of silently misparsing.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise ValueError(f"expected a boolean, got {value!r}")


def main():
    """Parse CLI arguments, build the data/model, and run wav2vec2 fine-tuning."""
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--classifier_lr", type=float, default=1e-4)
    parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
    parser.add_argument("--max_epochs", type=int, default=10)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # warmup/constant steps are FRACTIONS of the total step count, not raw counts
    parser.add_argument("--warmup_steps", type=float, default=0.1)
    parser.add_argument("--constant_steps", type=float, default=0.4)
    parser.add_argument("--scheduler_factor", type=float, default=1e-3)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--ckpt_dir", type=str, default="ckpt")
    parser.add_argument("--ckpt_path", type=str, default=None)
    # BUG FIX: was type=bool, under which "--detect_anomaly False" parsed as True.
    parser.add_argument("--detect_anomaly", type=_str2bool, default=False)
    parser.add_argument("--grad_clip", type=float, default=None)
    parser.add_argument("--wandb_id", type=str, default=None)
    args = parser.parse_args()
    print(args)

    train_loader, val_loader = prepare_dataloader(
        args.data_dir, args.batch_size, args.num_workers
    )

    # 42_000 looks like a hard-coded estimate of the training-set size used to
    # derive the total optimizer steps for the tri-state schedule.
    # NOTE(review): confirm it matches the actual number of training samples.
    total_steps = args.max_epochs * 42_000 // args.batch_size
    warmup_steps = int(total_steps * args.warmup_steps)
    constant_steps = int(total_steps * args.constant_steps)
    model = prepare_model(
        {
            "wav2vec2_lr": args.wav2vec2_lr,
            "classifier_lr": args.classifier_lr,
            "weight_decay": args.weight_decay,
        },
        {
            "warmup_steps": warmup_steps,
            "constant_steps": constant_steps,
            "total_steps": total_steps,
            "factor": args.scheduler_factor,
        },
    )
    trainer = Trainer(
        accelerator=args.accelerator,
        callbacks=[
            # Keep the best checkpoint by validation WER, plus the latest one.
            ModelCheckpoint(
                args.ckpt_dir,
                monitor="val/wer",
                mode="min",
                save_top_k=1,
                save_last=True,
            ),
            LearningRateMonitor(logging_interval="step"),
        ],
        logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
        max_epochs=args.max_epochs,
        detect_anomaly=args.detect_anomaly,
        gradient_clip_val=args.grad_clip,
    )
    # BUG FIX: --ckpt_path was parsed but never used, so resuming from a
    # checkpoint silently did nothing. The default None preserves old behavior.
    trainer.fit(model, train_loader, val_loader, ckpt_path=args.ckpt_path)
# Script entry point: fix RNG seeds before any data/model setup so runs are
# reproducible (188 is an arbitrary but fixed seed).
if __name__ == "__main__":
    seed_everything(188)
    main()