import emoji
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import R2Score
from transformers import BertModel, BertTokenizerFast

from config import DEVICE
from src.utils import get_sentiment, vectorise_dict
from src.utils.neural_networks import set_layer

from .DecoderPL import DecoderPL
from .EncoderPL import EncoderPL

torch.set_default_dtype(torch.float32)


class FullModelPL(pl.LightningModule):
    """Full text + non-text regression model.

    A frozen BERT-based encoder turns the batch's texts into sentence
    embeddings; per-text sentiment scores and the numeric non-text features
    are concatenated onto the embedding and passed through an MLP decoder.
    Trained with MSE against the ``"ctr"`` target; R2 is logged per epoch.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        nontext_features: list[str] | None = None,
        encoder: EncoderPL | None = None,
        decoder: DecoderPL | None = None,
        layer_norm: bool = True,
        device=DEVICE,
        T_max: int = 10,
    ):
        """Build encoder, decoder, loss, metric, optimizer and LR scheduler.

        Args:
            model_name: HuggingFace model name used when no encoder is given.
            nontext_features: names of the numeric features concatenated onto
                the text embedding. Defaults to ``["aov"]``.
            encoder: pre-built encoder; a fresh ``EncoderPL`` is created if None.
            decoder: pre-built decoder; a fresh ``DecoderPL`` is created if None.
            layer_norm: forwarded to the default ``DecoderPL``.
            device: target device for the submodules.
            T_max: period of the cosine-annealing LR schedule.
        """
        super().__init__()

        # Avoid a shared mutable default argument; ["aov"] stays the
        # effective default for callers that pass nothing.
        if nontext_features is None:
            nontext_features = ["aov"]

        # layers
        self.encoder = (
            encoder.to(self.device)
            if encoder is not None
            else EncoderPL(model_name=model_name, device=device).to(self.device)
        )
        self.decoder = (
            decoder.to(self.device)
            if decoder is not None
            else DecoderPL(
                # 768 = BERT base hidden size; +5 presumably the number of
                # sentiment scores produced by get_sentiment — TODO confirm.
                input_dim=768 + len(nontext_features) + 5,
                layer_norm=layer_norm,
                device=device,
            ).to(self.device)
        )

        # loss & metric
        self.MSE = nn.MSELoss()
        self.R2 = R2Score()

        # NOTE(review): this optimizer is rebuilt in configure_optimizers()
        # after the BERT weights are frozen; the copy created here is kept so
        # code that reads self.optimizer / self.scheduler before training
        # still works.
        self.optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=3e-4
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)

    def forward(self, input_dict: dict):
        """Encode text, append sentiment + non-text features, and decode.

        Args:
            input_dict: batch dict containing ``"text"`` (list of strings),
                optionally ``"ctr"`` (the target, ignored here), and one
                tensor per non-text feature.

        Returns:
            The decoder's output for the batch.
        """
        input_dict = input_dict.copy()
        text = input_dict.pop("text")
        # "ctr" is the regression target, not an input feature.
        if "ctr" in input_dict:
            input_dict.pop("ctr")

        # encode
        sentence_embedding = self.encoder.forward(text=text)

        # sentiment: one score-dict per text, collated into per-key tensors
        sentiment = get_sentiment_for_list_of_texts(text)
        input_dict = input_dict | sentiment
        input_dict = {k: v.to(self.device) for k, v in input_dict.items()}

        # concat non-text features to the embedding.
        # stack -> (features, batch); .T -> (batch, features);
        # unsqueeze -> (batch, 1, features) to match the embedding's
        # (batch, 1, 768) layout — assumed from the dim=2 cat; TODO confirm.
        nontext_vec = vectorise_dict(input_dict)
        nontext_tensor = torch.stack(nontext_vec).T.unsqueeze(1).to(torch.float32)
        x = torch.cat((sentence_embedding, nontext_tensor), 2)

        # decode
        return self.decoder.forward(x)

    def training_step(self, batch):
        """One optimization step; returns pred/act so the epoch-end R2 hook can use them."""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        pred = loss_and_metrics["pred"]
        act = loss_and_metrics["act"]
        loss = loss_and_metrics["loss"]
        self.log(
            "train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True
        )
        return {"loss": loss, "pred": pred, "act": act}

    def configure_optimizers(self):
        """Freeze BERT and return the optimizer/scheduler for Lightning.

        The optimizer is (re)built *after* freezing so the frozen BERT
        parameters are excluded; the one created in __init__ captured them
        because they still required grad at that point.
        """
        for name, param in self.named_parameters():
            if "bert" in name:
                param.requires_grad = False
        t_max = self.scheduler.T_max  # reuse the T_max chosen in __init__
        self.optimizer = torch.optim.AdamW(
            (p for p in self.parameters() if p.requires_grad), lr=3e-4
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=t_max)
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        """Manual scheduler stepping; forwards the monitored metric when present."""
        logger.debug(scheduler)
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def validation_step(self, batch, batch_idx):
        """used for logging metrics"""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        # Log loss and metric
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def training_epoch_end(self, training_step_outputs):
        """Compute and log an epoch-level R2 from the collected train outputs."""
        outputs = list(training_step_outputs)
        # The final batch may be a different size, which would break
        # torch.stack below — presumably why it is dropped; TODO confirm.
        # Guard against 0/1 outputs, where pop()/indexing would raise.
        if len(outputs) < 2:
            return
        outputs.pop()
        collated = {k: [d[k] for d in outputs] for k in outputs[0]}
        pred = torch.stack(collated["pred"])
        act = torch.stack(collated["act"])
        residual_sq = torch.square(torch.sub(pred, act))
        tss = float(torch.var(act, unbiased=False))
        rss = float(torch.mean(residual_sq))
        # Degenerate epoch (constant targets) would divide by zero.
        if tss == 0:
            return
        self.log("train_R2", 1 - rss / tss, prog_bar=True, logger=True)

    def _get_loss(self, batch, get_metrics: bool = False):
        """convenience function since train/valid/test steps are similar

        Returns a dict with ``pred`` (always), and ``act``/``loss`` (None when
        the batch carries no ``"ctr"`` target, e.g. at inference).

        NOTE(review): ``get_metrics`` is currently unused; kept for signature
        compatibility with existing callers.
        """
        pred = self.forward(input_dict=batch).to(torch.float32)
        act, loss = None, None
        if "ctr" in batch:
            act = batch["ctr"].to(torch.float32).to(self.device)
            loss = self.MSE(pred, act).to(torch.float32)
        return {"loss": loss, "pred": pred, "act": act}


def get_sentiment_for_list_of_texts(texts: list[str]) -> dict:
    """Score each text with get_sentiment and collate into per-key tensors.

    Returns a dict mapping each sentiment key to a 1-D tensor with one value
    per input text; empty dict for an empty input.
    """
    if not texts:
        return {}
    scores = [get_sentiment(text) for text in texts]
    return {k: torch.Tensor([s[k] for s in scores]) for k in scores[0]}