from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
import pandas as pd
import torch
from loguru import logger
from dotenv import load_dotenv

from src.regression.datasets import FullModelDatasetTorch
from src.regression.PL import FullModelPL


def train_full_model_PL(
    train: pd.DataFrame,
    test: pd.DataFrame,
    artifact_path: str | None = None,
    resume: bool | str = "must",
    run_id: str | None = None,
    run_name: str = "sanity",
    model_class=FullModelPL,
    max_epochs: int = 2,
    layer_norm: bool = False,
):
    """Train a full model with PyTorch Lightning, logging to Weights & Biases.

    If `artifact_path` is given, the Lightning module is restored from a W&B
    checkpoint artifact; otherwise a fresh `model_class` is built. With
    `run_name="sanity"`, resume settings are ignored and both frames are
    truncated to a handful of rows for a quick smoke test.
    """
    torch.set_default_dtype(torch.float32)
    load_dotenv()

    nontext_features = ["aov"]

    # Drop rows with a missing target (aov).
    train = train[train.aov.notna()].reset_index(drop=True)
    test = test[test.aov.notna()].reset_index(drop=True)

    # Sanity mode: always start a fresh run on a tiny slice of the data.
    # Note that `.loc[0:16]` is label-based and inclusive, so it keeps 17 rows.
    if run_name == "sanity":
        resume = False
        run_id = None
        train = train.loc[0:16]
        test = test.loc[0:16]

    # Datasets and dataloaders.
    train_dataset = FullModelDatasetTorch(df=train, nontext_features=nontext_features)
    train_dataloader = DataLoader(
        train_dataset, batch_size=8, shuffle=True, num_workers=8
    )
    test_dataset = FullModelDatasetTorch(df=test, nontext_features=nontext_features)
    test_dataloader = DataLoader(
        test_dataset, batch_size=8, shuffle=False, num_workers=8
    )

    # W&B logger; credentials are expected in the environment (see load_dotenv above).
    wandb_logger = WandbLogger(
        project="transformers",
        entity="sanjin_juric_fot",
        log_model=True,
        reinit=True,
        resume=resume,
        id=run_id,
        name=run_name,
    )

    # Restore the Lightning module from a W&B checkpoint artifact, or build a fresh one.
    if artifact_path is not None:
        artifact = wandb_logger.use_artifact(artifact_path)
        artifact_dir = artifact.download()
        litmodel = model_class.load_from_checkpoint(f"{artifact_dir}/model.ckpt").to("mps")
        logger.debug("loaded from checkpoint")
    else:
        litmodel = model_class(
            model_name="bert-base-uncased",
            nontext_features=nontext_features,
            layer_norm=layer_norm,
        ).to("mps")

    # Keep the checkpoint with the best validation loss; log the LR once per epoch.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    trainer = Trainer(
        accelerator="mps",
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
    )

    logger.debug("training...")
    trainer.fit(
        model=litmodel,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
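

# Minimal usage sketch. The parquet paths below are hypothetical and only for
# illustration; the sole hard requirement from train_full_model_PL is that both
# frames carry an `aov` column.
if __name__ == "__main__":
    train_df = pd.read_parquet("data/train.parquet")  # hypothetical path
    test_df = pd.read_parquet("data/test.parquet")  # hypothetical path

    # Quick smoke test: fresh W&B run on ~17 rows, two epochs.
    train_full_model_PL(train_df, test_df, run_name="sanity", max_epochs=2)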