from typing import Optional

import pytorch_lightning as pl
import torch
from loguru import logger
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import R2Score

from src.utils.neural_networks import set_layer
from config import DEVICE

torch.set_default_dtype(torch.float32)


class DecoderPL(pl.LightningModule):
    def __init__(
        self,
        input_dim: int = 774,
        layer_norm: bool = True,
        layer_dict: Optional[dict] = None,
        device=DEVICE,  # not used directly; Lightning manages device placement
        T_max: int = 10,
        start_lr: float = 5e-4,
    ):
        super().__init__()
        # Avoid the mutable-default-argument pitfall.
        layer_dict = layer_dict if layer_dict is not None else {}

        # layers
        self.linear1 = set_layer(
            layer_dict=layer_dict,
            name="linear1",
            alternative=nn.Linear(in_features=input_dim, out_features=512),
        )
        self.linear2 = set_layer(
            layer_dict=layer_dict,
            name="linear2",
            alternative=nn.Linear(in_features=512, out_features=264),
        )
        self.linear3 = set_layer(
            layer_dict=layer_dict,
            name="linear3",
            alternative=nn.Linear(in_features=264, out_features=64),
        )
        self.linear4 = set_layer(
            layer_dict=layer_dict,
            name="linear4",
            alternative=nn.Linear(in_features=64, out_features=1),
        )
        self.activation = nn.LeakyReLU(negative_slope=0.01)

        if not layer_norm:
            self.layers = [
                self.linear1,
                self.activation,
                self.linear2,
                self.activation,
                self.linear3,
                self.activation,
                self.linear4,
            ]
        else:
            # Normalize over the (1, features) shape produced by the
            # unsqueeze in forward().
            self.layernorm1 = nn.LayerNorm(normalized_shape=(1, self.linear1.out_features))
            self.layernorm2 = nn.LayerNorm(normalized_shape=(1, self.linear2.out_features))
            self.layernorm3 = nn.LayerNorm(normalized_shape=(1, self.linear3.out_features))
            self.layers = [
                self.linear1,
                self.layernorm1,
                self.activation,
                self.linear2,
                self.layernorm2,
                self.activation,
                self.linear3,
                self.layernorm3,
                self.activation,
                self.linear4,
            ]

        # initialize weights
        for layer in self.layers:
            self.initialize_weights(layer)

        # optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=start_lr
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)

        # hyperparameters, loss, and metrics
        self.save_hyperparameters(ignore=["model"])
        self.MSE = nn.MSELoss()
        self.R2 = R2Score()

    def initialize_weights(self, module):
        # Only linear layers are initialized; activations and norms are skipped.
        if isinstance(module, nn.Linear):
            logger.debug("linear weights initialized")
            torch.nn.init.xavier_uniform_(module.weight)
            module.bias.data.fill_(0.01)

    def forward(self, x: torch.Tensor):
        # Add the length-1 "sequence" dim that the LayerNorm shapes expect.
        if x.dim() == 2:
            x = x.unsqueeze(dim=1)
        for layer in self.layers:
            x = layer(x)
        x = x.squeeze()
        # squeeze() turns a single-sample batch into a 0-d tensor; restore the batch dim.
        if x.dim() == 0:
            x = x.unsqueeze(dim=0)
        return x.to(torch.float32)

    def training_step(self, batch, batch_idx):
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        pred = loss_and_metrics["pred"]
        act = loss_and_metrics["act"]
        loss = loss_and_metrics["loss"]
        self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        return {"loss": loss, "pred": pred, "act": act}

    def configure_optimizers(self):
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        # pytorch_lightning 1.x hook; metric is only set for metric-driven
        # schedulers such as ReduceLROnPlateau.
        logger.debug(scheduler)
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def validation_step(self, batch, batch_idx):
        """Used for logging metrics."""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        # Log loss and metric
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def training_epoch_end(self, training_step_outputs):
        # Drop the last batch, which may be smaller than the rest and would
        # break the torch.stack calls below.
        training_step_outputs = list(training_step_outputs)
        training_step_outputs.pop()
        output_dict = {
            k: [dic[k] for dic in training_step_outputs]
            for k in training_step_outputs[0]
        }
        pred = torch.stack(output_dict["pred"])
        act = torch.stack(output_dict["act"])
        # Epoch-level R^2 = 1 - RSS / TSS over all retained training batches.
        residual = torch.sub(pred, act)
        residual_sq = torch.square(residual)
        TSS = float(torch.var(act, unbiased=False))
        RSS = float(torch.mean(residual_sq))
        R2 = 1 - RSS / TSS
        self.log("train_R2", R2, prog_bar=True, logger=True)

    def _get_loss(self, batch, get_metrics: bool = False):
        """Convenience function since train/valid/test steps are similar."""
        pred = self.forward(x=batch["embedding"]).to(torch.float32)
        act, loss = None, None
        if "ctr" in batch.keys():
            act = batch["ctr"].to(torch.float32)
            loss = self.MSE(pred, act).to(torch.float32)
        return {"loss": loss, "pred": pred, "act": act}
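

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module above). It
# assumes pytorch_lightning 1.x, since training_epoch_end and the
# (scheduler, optimizer_idx, metric) form of lr_scheduler_step were removed
# in 2.0. _ToyEmbeddingDataset is a hypothetical stand-in for whatever
# produces the {"embedding", "ctr"} dict batches that _get_loss expects;
# drop_last=True keeps every batch the same size so the torch.stack calls
# in training_epoch_end line up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class _ToyEmbeddingDataset(Dataset):
        """Random embeddings and CTR targets in the expected dict format."""

        def __init__(self, n: int = 256, dim: int = 774):
            self.embeddings = torch.randn(n, dim)
            self.ctr = torch.rand(n)

        def __len__(self):
            return len(self.embeddings)

        def __getitem__(self, idx):
            return {"embedding": self.embeddings[idx], "ctr": self.ctr[idx]}

    model = DecoderPL(input_dim=774, layer_norm=True, T_max=10)
    train_loader = DataLoader(
        _ToyEmbeddingDataset(), batch_size=32, shuffle=True, drop_last=True
    )
    val_loader = DataLoader(_ToyEmbeddingDataset(n=64), batch_size=32, drop_last=True)

    trainer = pl.Trainer(max_epochs=2, accelerator="auto", devices=1)
    trainer.fit(model, train_loader, val_loader)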