|
|
from torch.utils.data import DataLoader |
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
from pytorch_lightning import Trainer |
|
|
import pandas as pd |
|
|
from loguru import logger |
|
|
from dotenv import load_dotenv |
|
|
import torch |
|
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
|
|
|
|
from src.regression.datasets import FullModelDatasetTorch |
|
|
from src.regression.PL import * |
|
|
|
|
|
|
|
|
def train_full_model_PL(
    train: pd.DataFrame,
    test: pd.DataFrame,
    artifact_path: str | None = None,
    resume: bool | str = "must",
    run_id: str | None = None,
    run_name: str = "sanity",
    model_class=FullModelPL,
    max_epochs: int = 2,
    layer_norm: bool = False,
    accelerator: str = "mps",
    batch_size: int = 8,
    num_workers: int = 8,
) -> None:
    """Train a Lightning regression model, logging to Weights & Biases.

    The model is either freshly built from ``bert-base-uncased`` or restored
    from a W&B model artifact, then fitted on ``train`` and validated on
    ``test``.

    Args:
        train: Training rows. Rows with a missing value in any non-text
            feature (currently only ``aov``) are dropped.
        test: Validation rows, filtered the same way as ``train``.
        artifact_path: Optional W&B artifact reference. When given, the
            artifact is downloaded and ``model.ckpt`` inside it is loaded via
            ``model_class.load_from_checkpoint``; ``layer_norm`` is then
            ignored because the checkpoint supplies the model configuration.
        resume: Passed to ``WandbLogger(resume=...)``. The default ``"must"``
            requires an existing run to resume (via ``run_id``, or presumably
            a run id provided through the environment loaded by
            ``load_dotenv`` — verify for your setup).
        run_id: W&B run id to resume.
        run_name: W&B run name. The special value ``"sanity"`` disables
            resuming and truncates both datasets to a tiny smoke-test subset.
        model_class: LightningModule class to build or restore.
        max_epochs: Maximum number of training epochs.
        layer_norm: Forwarded to ``model_class`` when building a fresh model.
        accelerator: Lightning accelerator name (default ``"mps"``, as before).
        batch_size: Batch size for both dataloaders.
        num_workers: Worker processes for both dataloaders.

    Side effects:
        Sets torch's global default dtype to ``float32``, loads environment
        variables from a ``.env`` file, and creates/resumes a W&B run.
    """
    from pathlib import Path

    # NOTE: this changes torch's process-wide default dtype, not just for this call.
    torch.set_default_dtype(torch.float32)
    # Presumably supplies W&B credentials/config from a .env file.
    load_dotenv()

    nontext_features = ["aov"]

    # Keep only rows where every non-text feature is present.
    train = train[train[nontext_features].notna().all(axis=1)].reset_index(drop=True)
    test = test[test[nontext_features].notna().all(axis=1)].reset_index(drop=True)

    if run_name == "sanity":
        # Smoke-test mode: always start a fresh run on a tiny data subset.
        resume = False
        run_id = None
        # .loc slicing is label-based and inclusive: labels 0..16 -> up to 17 rows.
        train = train.loc[0:16, :]
        test = test.loc[0:16, :]

    train_dataloader = DataLoader(
        FullModelDatasetTorch(df=train, nontext_features=nontext_features),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_dataloader = DataLoader(
        FullModelDatasetTorch(df=test, nontext_features=nontext_features),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    wandb_logger = WandbLogger(
        project="transformers",
        entity="sanjin_juric_fot",
        log_model=True,
        reinit=True,
        resume=resume,
        id=run_id,
        name=run_name,
    )

    if artifact_path is not None:
        artifact = wandb_logger.use_artifact(artifact_path)
        checkpoint_path = Path(artifact.download()) / "model.ckpt"
        # Load onto CPU so a checkpoint saved on MPS/GPU also loads on other
        # hosts; the Trainer moves the model to the target device in fit().
        litmodel = model_class.load_from_checkpoint(
            str(checkpoint_path), map_location="cpu"
        )
        logger.debug("loaded from checkpoint")
    else:
        litmodel = model_class(
            model_name="bert-base-uncased",
            nontext_features=nontext_features,
            layer_norm=layer_norm,
        )

    # Keep the checkpoint with the lowest validation loss; record LR per epoch.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    trainer = Trainer(
        accelerator=accelerator,
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
    )

    logger.debug("training...")
    trainer.fit(
        model=litmodel,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
|
|
|