from datasets import DatasetDict
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from src.MLM.datasets.preprocess_dataset import preprocess_dataset
from src.MLM.training_scripts.utils import get_new_model_name

def train_with_trainer(
    model_checkpoint: str,
    tokenizer: AutoTokenizer,
    dataset: DatasetDict,
    model_name: str | None = None,
    data_collator=None,
    num_epochs: int = 3,
):
    """Fine-tune a masked language model on `dataset` using the Hugging Face Trainer."""
    # Load the pretrained MLM and derive a name for the fine-tuned model.
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    model_name = get_new_model_name(model_checkpoint=model_checkpoint, model_name=model_name)

    # Tokenize and chunk the raw dataset for masked-language modelling.
    dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

    # Default collator applies random masking (15% of tokens) on the fly.
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    training_args = TrainingArguments(
        model_name,  # output directory (also the Hub repo name when pushing)
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        push_to_hub=True,
        report_to="wandb",
        run_name=model_name,
        num_train_epochs=num_epochs,
        save_total_limit=1,
        save_strategy="epoch",
    )
    print(f"device: {training_args.device}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        data_collator=data_collator,
    )
    trainer.train()
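

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original script: it assumes a DatasetDict
    # with "train" and "val" splits and a text column that preprocess_dataset can
    # tokenize. The checkpoint and corpus below are placeholders, not values taken
    # from this repository; push_to_hub and wandb reporting also require being
    # logged in to the Hub and to Weights & Biases.
    from datasets import load_dataset

    checkpoint = "distilroberta-base"  # hypothetical base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    raw = load_dataset("imdb")  # placeholder corpus with a "text" column
    dataset = DatasetDict({"train": raw["train"], "val": raw["test"]})

    train_with_trainer(
        model_checkpoint=checkpoint,
        tokenizer=tokenizer,
        dataset=dataset,
        num_epochs=1,
    )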