from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset, DatasetDict

from src.MLM.datasets.preprocess_dataset import preprocess_dataset
from src.MLM.training_scripts.utils import get_new_model_name

def train_with_trainer(
    model_checkpoint: str,
    tokenizer: AutoTokenizer,
    dataset: DatasetDict,
    model_name: str | None = None,
    data_collator=None,
    num_epochs: int = 3,
):
    # Load the base masked-language model and derive a name for the fine-tuned run.
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    model_name = get_new_model_name(model_checkpoint=model_checkpoint, model_name=model_name)

    # Tokenize the raw splits into model-ready features.
    dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

    # Default collator masks 15% of tokens on the fly for the MLM objective.
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    training_args = TrainingArguments(
        model_name,  # output_dir; also used as the Hub repo name when pushing
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        push_to_hub=True,
        report_to="wandb",
        run_name=model_name,
        num_train_epochs=num_epochs,
        save_total_limit=1,
        save_strategy="epoch",
    )
    print(f"device: {training_args.device}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        data_collator=data_collator,
    )
    trainer.train()
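

# A minimal usage sketch (illustrative assumptions, not part of the training script):
# the checkpoint name, the toy DatasetDict, and the "text" column expected by
# preprocess_dataset are all hypothetical here; push_to_hub and wandb reporting
# additionally require a logged-in Hub account and W&B credentials.
if __name__ == "__main__":
    checkpoint = "distilbert-base-uncased"  # assumed checkpoint, any MLM model works
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    raw_dataset = DatasetDict(
        {
            "train": Dataset.from_dict({"text": ["an example training sentence"] * 8}),
            "val": Dataset.from_dict({"text": ["a held-out validation sentence"] * 2}),
        }
    )
    train_with_trainer(
        model_checkpoint=checkpoint,
        tokenizer=tokenizer,
        dataset=raw_dataset,
        num_epochs=1,
    )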