from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import DatasetDict

from src.MLM.datasets.preprocess_dataset import preprocess_dataset
from src.MLM.training_scripts.utils import get_new_model_name


def train_with_trainer(
    model_checkpoint: str,
    tokenizer: AutoTokenizer,
    dataset: DatasetDict,
    model_name: str | None = None,
    data_collator=None,
    num_epochs: int = 3,
):
    """Fine-tune a masked language model checkpoint with the Hugging Face Trainer.

    The dataset is prepared with the project's `preprocess_dataset` helper
    (which receives the tokenizer) and must provide "train" and "val" splits.
    Checkpoints are pushed to the Hub and runs are logged to Weights & Biases
    under the name derived by `get_new_model_name`.
    """
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    model_name = get_new_model_name(model_checkpoint=model_checkpoint, model_name=model_name)

    # Prepare the raw dataset with the project's preprocessing helper.
    dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

    # Default collator randomly masks 15% of tokens for the MLM objective.
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    training_args = TrainingArguments(
        model_name,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        push_to_hub=True,
        report_to="wandb",
        run_name=model_name,
        num_train_epochs=num_epochs,
        save_total_limit=1,
        save_strategy="epoch",
    )
    print(f"device: {training_args.device}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        data_collator=data_collator,
    )
    trainer.train()
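

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity; not part of the original
# pipeline). It shows how train_with_trainer expects to be called: a
# checkpoint name, its tokenizer, and a DatasetDict with "train" and "val"
# splits. The checkpoint ("distilroberta-base") and corpus ("imdb") are
# placeholders, running it assumes Hub and wandb credentials are configured
# (push_to_hub=True, report_to="wandb"), and it assumes preprocess_dataset
# can handle the placeholder corpus's text column.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import load_dataset

    checkpoint = "distilroberta-base"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Placeholder corpus, split into the "train"/"val" keys the trainer expects.
    raw = load_dataset("imdb")
    splits = raw["train"].train_test_split(test_size=0.1, seed=42)
    dataset = DatasetDict({"train": splits["train"], "val": splits["test"]})

    train_with_trainer(
        model_checkpoint=checkpoint,
        tokenizer=tokenizer,
        dataset=dataset,
        num_epochs=3,
    )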