# ctr-ll4/src/MLM/training_scripts/train_with_trainer.py
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import DatasetDict

from src.MLM.datasets.preprocess_dataset import preprocess_dataset
from src.MLM.training_scripts.utils import get_new_model_name


def train_with_trainer(
    model_checkpoint: str,
    tokenizer: AutoTokenizer,
    dataset: DatasetDict,
    model_name: str | None = None,
    data_collator=None,
    num_epochs: int = 3,
):
    # Load the pretrained masked-LM backbone and derive a name for the fine-tuned model.
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    model_name = get_new_model_name(model_checkpoint=model_checkpoint, model_name=model_name)

    # Tokenize the raw DatasetDict into model-ready features.
    dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

    # Default to dynamic token masking with a 15% mask probability.
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    training_args = TrainingArguments(
        model_name,  # output_dir; also reused as the Hub repo and W&B run name
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        push_to_hub=True,
        report_to="wandb",
        run_name=model_name,
        num_train_epochs=num_epochs,
        save_total_limit=1,
        save_strategy="epoch",
    )
    print(f"device: {training_args.device}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
        data_collator=data_collator,
    )
    trainer.train()
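

# --- Usage sketch (illustrative, not part of the original upload) ---
# A minimal example of how train_with_trainer might be called. The checkpoint,
# dataset, split names, and subset sizes below are assumptions for illustration;
# the real data must be something preprocess_dataset understands, and
# push_to_hub=True / report_to="wandb" require a logged-in HF account and W&B setup.
if __name__ == "__main__":
    from datasets import load_dataset

    checkpoint = "distilbert-base-uncased"  # assumed base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Build a DatasetDict with the "train" and "val" splits the function expects;
    # small subsets keep the sketch quick to run.
    raw = load_dataset("imdb")
    dataset = DatasetDict(
        {
            "train": raw["train"].select(range(1000)),
            "val": raw["test"].select(range(200)),
        }
    )

    train_with_trainer(
        model_checkpoint=checkpoint,
        tokenizer=tokenizer,
        dataset=dataset,
        num_epochs=3,
    )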