import pandas as pd
import torch
from dotenv import load_dotenv
from loguru import logger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from src.regression.datasets import FullModelDatasetTorch
from src.regression.PL import FullModelPL


def train_full_model_PL(
    train: pd.DataFrame,
    test: pd.DataFrame,
    artifact_path: str | None = None,
    resume: bool | str = "must",
    run_id: str | None = None,
    run_name: str = "sanity",
    model_class=FullModelPL,
    max_epochs: int = 2,
    layer_norm: bool = False,
):
    """Train a Lightning regression model on text + non-text features.

    Either initializes `model_class` from scratch or restores it from a W&B
    checkpoint artifact (`artifact_path`), then fits it with a PyTorch
    Lightning `Trainer` on the MPS accelerator, logging to Weights & Biases.
    """
    # make sure all tensors default to float32
    torch.set_default_dtype(torch.float32)

    # load W&B credentials and other settings from a local .env file
    load_dotenv()

    nontext_features = ["aov"]

    train = train[train.aov.notna()].reset_index(drop=True)
    test = test[test.aov.notna()].reset_index(drop=True)

    if run_name == "sanity":
        # sanity run: start a fresh W&B run and train on a tiny subset
        resume = False
        run_id = None
        train = train.loc[0:16, :]
        test = test.loc[0:16, :]

    # build datasets and dataloaders for the train and validation splits
    train_dataset = FullModelDatasetTorch(df=train, nontext_features=nontext_features)
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)

    test_dataset = FullModelDatasetTorch(df=test, nontext_features=nontext_features)
    test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=8)

    # W&B logger: `resume`/`run_id` allow continuing a previous run,
    # `log_model=True` uploads checkpoints as artifacts
    wandb_logger = WandbLogger(
        project="transformers",
        entity="sanjin_juric_fot",
        log_model=True,
        reinit=True,
        resume=resume,
        id=run_id,
        name=run_name,
    )

    # either restore the LightningModule from a W&B checkpoint artifact
    # or initialize a fresh one below
    if artifact_path is not None:
        artifact = wandb_logger.use_artifact(artifact_path)
        artifact_dir = artifact.download()
        litmodel = model_class.load_from_checkpoint(f"{artifact_dir}/model.ckpt").to("mps")

        logger.debug("loaded model from checkpoint")

        # for name, layer in litmodel.named_modules():
        #     if isinstance(layer, nn.Linear) and name == "linear2":
        #         break

        # layer_dict = {"linear2": layer}

        # litmodel = LitAdModelLHS(
        #     nontext_features=nontext_features, layer_dict=layer_dict
        # )

    else:
        # no artifact given: initialize the model from scratch on top of BERT
        litmodel = model_class(
            model_name="bert-base-uncased",
            nontext_features=nontext_features,
            layer_norm=layer_norm,
        ).to("mps")

    # keep the checkpoint with the lowest validation loss and log the
    # learning rate once per epoch
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    trainer = Trainer(
        accelerator="mps",
        devices=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, lr_monitor],
    )
    logger.debug("training...")
    trainer.fit(
        model=litmodel,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
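

# Example usage sketch (added for illustration, not part of the original
# script): the __main__ guard and the parquet paths below are assumptions.
# It shows how the function might be invoked for a quick sanity run.
if __name__ == "__main__":
    # hypothetical parquet files with an `aov` column plus the text inputs
    # expected by FullModelDatasetTorch
    train_df = pd.read_parquet("data/train.parquet")
    test_df = pd.read_parquet("data/test.parquet")

    train_full_model_PL(
        train=train_df,
        test=test_df,
        run_name="sanity",  # trains on a ~16-row subset in a fresh W&B run
        max_epochs=2,
    )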