import pandas as pd
import torch
from dotenv import load_dotenv
from loguru import logger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from src.regression.datasets import DecoderDatasetTorch
from src.regression.PL import DecoderPL
load_dotenv()


def train_decoder_PL(
    train: pd.DataFrame,
    test: pd.DataFrame,
    artifact_path: str | None = None,
    resume: bool | str = "must",
    run_id: str | None = None,
    run_name: str = "sanity",
    model_class=DecoderPL,
    max_epochs: int = 2,
    layer_norm: bool = True,
    embedding_column: str = "my_full_mean_embedding",
    device: str = "mps",
    *args,
    **kwargs
):
    """Train (or resume) a decoder regression model with PyTorch Lightning and W&B logging.

    Rows without an `aov` target are dropped. When `run_name` is "sanity", the run is
    reduced to a 17-row slice and 2 epochs so the full pipeline can be smoke-tested.
    """
    torch.set_default_dtype(torch.float32)

    # Drop rows without a target value.
    train = train[train.aov.notna()].reset_index(drop=True)
    test = test[test.aov.notna()].reset_index(drop=True)

    # Sanity runs start from scratch on a tiny slice with a short schedule.
    if run_name == "sanity":
        resume = False
        run_id = None
        max_epochs = 2
        train = train.loc[0:16, :]  # .loc is end-inclusive: 17 rows
        test = test.loc[0:16]

    # Initialize datasets and dataloaders.
    train_dataset = DecoderDatasetTorch(df=train, embedding_column=embedding_column)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)
    test_dataset = DecoderDatasetTorch(df=test, embedding_column=embedding_column)
    test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=8)
    # W&B logger; `resume` and `run_id` allow continuing a previous run.
    wandb_logger = WandbLogger(
        project="transformers",
        entity="sanjin_juric_fot",
        log_model=True,
        reinit=True,
        resume=resume,
        id=run_id,
        name=run_name,
    )

    # Either restore the Lightning module from a logged checkpoint artifact or build a fresh one.
    if artifact_path is not None:
        artifact = wandb_logger.use_artifact(artifact_path)
        artifact_dir = artifact.download()
        litmodel = model_class.load_from_checkpoint(artifact_dir + "/" + "model.ckpt").to(device)
        logger.debug("loaded from checkpoint")
        torch.multiprocessing.set_sharing_strategy("file_system")
    else:
        litmodel = model_class(
            *args, input_dim=len(train.at[0, embedding_column]), layer_norm=layer_norm, **kwargs
        ).to(device)
    # Checkpoint on the best validation loss and track the learning rate once per epoch.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    trainer = Trainer(
        accelerator=str(device),
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
    )

    logger.debug("training...")
    trainer.fit(
        model=litmodel,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
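

# Minimal usage sketch: the parquet paths below are hypothetical, and both frames are
# assumed to carry an "aov" target column plus the embedding column expected by
# DecoderDatasetTorch. device="mps" assumes Apple Silicon hardware.
if __name__ == "__main__":
    train_df = pd.read_parquet("data/train.parquet")  # hypothetical path
    test_df = pd.read_parquet("data/test.parquet")  # hypothetical path
    train_decoder_PL(
        train=train_df,
        test=test_df,
        run_name="sanity",  # sanity run: 17-row slice, 2 epochs
        device="mps",
    )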