# ctr-ll4/src/regression/PL/DecoderPL.py
from typing import Optional

import pytorch_lightning as pl
import torch
from loguru import logger
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import R2Score

from config import DEVICE
from src.utils.neural_networks import set_layer
torch.set_default_dtype(torch.float32)
class DecoderPL(pl.LightningModule):
    """MLP regression head that maps an embedding of size `input_dim` to a single CTR prediction."""

    def __init__(
        self,
        input_dim: int = 774,
        layer_norm: bool = True,
        layer_dict: Optional[dict] = None,
        device=DEVICE,  # stored in hparams; Lightning itself handles device placement
        T_max: int = 10,
        start_lr: float = 5e-4,
    ):
        super().__init__()
        # avoid the shared mutable-default pitfall for layer_dict
        layer_dict = layer_dict if layer_dict is not None else {}
        # layers
self.linear1 = set_layer(
layer_dict=layer_dict,
name="linear1",
alternative=nn.Linear(in_features=input_dim, out_features=512),
)
self.linear2 = set_layer(
layer_dict=layer_dict,
name="linear2",
alternative=nn.Linear(in_features=512, out_features=264),
)
self.linear3 = set_layer(
layer_dict=layer_dict,
name="linear3",
alternative=nn.Linear(in_features=264, out_features=64),
)
self.linear4 = set_layer(
layer_dict=layer_dict,
name="linear4",
alternative=nn.Linear(in_features=64, out_features=1),
)
self.activation = nn.LeakyReLU(negative_slope=0.01)
if not layer_norm:
self.layers = [
self.linear1,
self.activation,
self.linear2,
self.activation,
self.linear3,
self.activation,
self.linear4,
]
else:
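            # forward() feeds tensors shaped (batch, 1, features), so normalize over the trailing (1, features) dims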
self.layernorm1 = nn.LayerNorm(normalized_shape=(1, self.linear1.out_features))
self.layernorm2 = nn.LayerNorm(normalized_shape=(1, self.linear2.out_features))
self.layernorm3 = nn.LayerNorm(normalized_shape=(1, self.linear3.out_features))
self.layers = [
self.linear1,
self.layernorm1,
self.activation,
self.linear2,
self.layernorm2,
self.activation,
self.linear3,
self.layernorm3,
self.activation,
self.linear4,
]
        # initialize weights
        for layer in self.layers:
            self.initialize_weights(layer)
        # optimizer and scheduler (handed to Lightning in configure_optimizers)
        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=start_lr)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)
        # loss, metrics and hyperparameters
        self.save_hyperparameters(ignore=["model"])
        self.MSE = nn.MSELoss()
        self.R2 = R2Score()
    def initialize_weights(self, module):
        """Xavier-uniform weights and a small positive bias for every linear layer."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            module.bias.data.fill_(0.01)
            logger.debug("linear weights initialized")
def forward(self, x: torch.Tensor):
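        # add a length-1 middle dim so LayerNorm sees (batch, 1, features)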
if x.dim() == 2:
x = x.unsqueeze(dim=1)
for layer in self.layers:
x = layer(x)
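        # collapse back to a 1-D prediction tensor, keeping a batch dim even for a single sample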
x = x.squeeze()
if x.dim() == 0:
x = x.unsqueeze(dim=0)
return x.to(torch.float32)
def training_step(self, batch):
loss_and_metrics = self._get_loss(batch, get_metrics=True)
pred = loss_and_metrics["pred"]
act = loss_and_metrics["act"]
loss = loss_and_metrics["loss"]
self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
return {"loss": loss, "pred": pred, "act": act}
    def configure_optimizers(self):
        # reuse the optimizer/scheduler pair built in __init__
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)
def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
logger.debug(scheduler)
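        # metric is only set when the scheduler config declares a monitor; CosineAnnealingLR does not use one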
if metric is None:
scheduler.step()
else:
scheduler.step(metric)
    def validation_step(self, batch, batch_idx):
        """Computes the validation loss and logs it once per epoch."""
loss_and_metrics = self._get_loss(batch, get_metrics=True)
loss = loss_and_metrics["loss"]
# Log loss and metric
self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
    def training_epoch_end(self, training_step_outputs):
        """Compute an epoch-level R2 from the per-batch predictions collected in training_step."""
        training_step_outputs = list(training_step_outputs)
        # drop the last output: a smaller final batch would break torch.stack below
        training_step_outputs.pop()
        output_dict = {k: [dic[k] for dic in training_step_outputs] for k in training_step_outputs[0]}
        pred = torch.stack(output_dict["pred"])
        act = torch.stack(output_dict["act"])
        residuals = torch.sub(pred, act)
        # R2 = 1 - MSE / Var(act), equivalent to 1 - RSS / TSS
        RSS = float(torch.mean(torch.square(residuals)))
        TSS = float(torch.var(act, unbiased=False))
        R2 = 1 - RSS / TSS
        self.log("train_R2", R2, prog_bar=True, logger=True)
    def _get_loss(self, batch, get_metrics: bool = False):
        """Shared forward/loss computation for the train and validation steps.

        Note: `get_metrics` is currently unused; MSE is always computed when "ctr" labels are present.
        """
        pred = self.forward(x=batch["embedding"]).to(torch.float32)
        act, loss = None, None
        if "ctr" in batch:
            act = batch["ctr"].to(torch.float32)
            loss = self.MSE(pred, act).to(torch.float32)
        return {"loss": loss, "pred": pred, "act": act}