ctr-ll4 / src /regression /PL /FullModelPL.py
sanjin7's picture
Upload src/ with huggingface_hub
cea4a4b
import emoji
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import R2Score
from transformers import BertModel, BertTokenizerFast
from src.utils import get_sentiment, vectorise_dict
from src.utils.neural_networks import set_layer
from config import DEVICE
from .DecoderPL import DecoderPL
from .EncoderPL import EncoderPL
# Import-time side effect: make float32 the default floating-point dtype for
# tensors created afterwards in this process.
torch.set_default_dtype(torch.float32)
class FullModelPL(pl.LightningModule):
    """Regression model chaining a text encoder, extra features and a decoder.

    ``forward`` encodes the ``"text"`` entry of the input dict with the
    encoder, appends the sentiment scores of the texts plus every remaining
    entry of the dict, and passes the concatenated tensor to the decoder.
    Training minimises the MSE between that prediction and the ``"ctr"``
    entry of the batch.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        nontext_features: list[str] | None = None,
        encoder: EncoderPL | None = None,
        decoder: DecoderPL | None = None,
        layer_norm: bool = True,
        device=DEVICE,
        T_max: int = 10,
    ):
        """Build (or adopt) the encoder and decoder and set up optimisation.

        Args:
            model_name: Forwarded to ``EncoderPL`` when no encoder is given.
            nontext_features: Names of the non-text input features; only the
                count is used, to size the decoder input. ``None`` means
                ``["aov"]``.
            encoder: Ready-made encoder; a new ``EncoderPL`` is built if None.
            decoder: Ready-made decoder; a new ``DecoderPL`` is built if None.
            layer_norm: Forwarded to ``DecoderPL`` when no decoder is given.
            device: Forwarded to the ``EncoderPL``/``DecoderPL`` constructors.
                Note the sub-modules themselves are moved to ``self.device``
                (the LightningModule property), not to this argument.
            T_max: ``T_max`` of the ``CosineAnnealingLR`` scheduler.
        """
        super().__init__()
        # A None sentinel avoids sharing one mutable default list across calls.
        if nontext_features is None:
            nontext_features = ["aov"]

        # layers
        if encoder is None:
            encoder = EncoderPL(model_name=model_name, device=device)
        self.encoder = encoder.to(self.device)

        if decoder is None:
            decoder = DecoderPL(
                # NOTE(review): 768 is presumably the encoder embedding width
                # and 5 the number of sentiment scores added in forward() --
                # confirm against EncoderPL / get_sentiment.
                input_dim=768 + len(nontext_features) + 5,
                layer_norm=layer_norm,
                device=device,
            )
        self.decoder = decoder.to(self.device)

        # loss, metric and optimisation objects
        self.MSE = nn.MSELoss()
        self.R2 = R2Score()
        # NOTE(review): the trainable-parameter filter runs here, i.e. before
        # configure_optimizers() switches off requires_grad on the "bert"
        # parameters; whether those are already frozen at this point depends
        # on EncoderPL -- confirm.
        self.optimizer = torch.optim.AdamW(
            (p for p in self.parameters() if p.requires_grad), lr=3 * 1e-4
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=T_max)

    def forward(self, input_dict: dict):
        """Predict from a dict holding ``"text"`` plus non-text feature tensors.

        Args:
            input_dict: Must contain ``"text"``; an optional ``"ctr"`` entry is
                ignored. Every other value must support ``.to(device)``. The
                dict itself is not modified (a shallow copy is used).

        Returns:
            Whatever the decoder returns for the concatenated input.
        """
        features = input_dict.copy()
        text = features.pop("text")
        # TRACE is below loguru's default level, so this stays silent unless
        # a sink explicitly asks for it (previously an unconditional print).
        logger.trace("text: {}", text)
        # The target must not leak into the model input.
        features.pop("ctr", None)

        # encode (call the module, not .forward, so nn.Module hooks run)
        sentence_embedding = self.encoder(text=text)

        # sentiment scores are appended to the non-text features
        sentiment = get_sentiment_for_list_of_texts(text)
        features = features | sentiment
        features = {k: v.to(self.device) for k, v in features.items()}

        # concat nontext features to embedding
        nontext_vec = vectorise_dict(features)
        # stack -> transpose -> add a middle axis, then concatenate with the
        # embedding along dim 2.
        nontext_tensor = torch.stack(nontext_vec).T.unsqueeze(1).to(torch.float32)
        logger.trace(
            "embedding device: {}, nontext device: {}",
            sentence_embedding.device,
            nontext_tensor.device,
        )
        x = torch.cat((sentence_embedding, nontext_tensor), 2)

        # decode
        return self.decoder(x)

    def training_step(self, batch):
        """Compute and log the training loss for one batch.

        Returns:
            Dict with ``"loss"`` (used for back-propagation) and ``"pred"`` /
            ``"act"``, which ``training_epoch_end`` consumes.
        """
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        # Only the loss needs the autograd graph. The prediction is kept for
        # the epoch-end R2, so detach it to avoid holding every step's graph
        # in memory for the whole epoch.
        pred = loss_and_metrics["pred"].detach()
        act = loss_and_metrics["act"]
        return {"loss": loss, "pred": pred, "act": act}

    def configure_optimizers(self):
        """Freeze every parameter whose name contains "bert" and return the
        optimizer and scheduler created in ``__init__``."""
        for name, param in self.named_parameters():
            if "bert" in name:
                param.requires_grad = False
        return dict(optimizer=self.optimizer, lr_scheduler=self.scheduler)

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        """Step the scheduler, passing ``metric`` through when one is given."""
        logger.debug(scheduler)
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def validation_step(self, batch, batch_idx):
        """used for logging metrics"""
        loss_and_metrics = self._get_loss(batch, get_metrics=True)
        loss = loss_and_metrics["loss"]
        # Log loss and metric
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)

    def training_epoch_end(self, training_step_outputs):
        """Log the epoch-level R2 computed from the collected step outputs.

        R2 is ``1 - mean((pred - act)^2) / var(act)``. Nothing is logged when
        there are no step outputs or when the targets have zero variance
        (R2 is undefined in that case).
        """
        outputs = list(training_step_outputs)
        # The trailing batch is dropped -- presumably because it can be
        # smaller than the others, which torch.stack cannot handle. Keep it
        # when it is the only batch, otherwise nothing would be left.
        if len(outputs) > 1:
            outputs.pop()
        if not outputs:
            return

        pred = torch.stack([out["pred"] for out in outputs])
        act = torch.stack([out["act"] for out in outputs])
        squared_error = torch.square(torch.sub(pred, act))
        TSS = float(torch.var(act, unbiased=False))
        RSS = float(torch.mean(squared_error))
        if TSS == 0:
            logger.warning("train_R2 skipped: targets have zero variance")
            return
        R2 = 1 - RSS / TSS
        self.log("train_R2", R2, prog_bar=True, logger=True)

    def _get_loss(self, batch, get_metrics: bool = False):
        """convenience function since train/valid/test steps are similar

        Args:
            batch: Input dict for ``forward``; ``"ctr"`` is the target.
            get_metrics: Currently unused; kept so existing calls still work.

        Returns:
            Dict with ``"pred"`` (float32 prediction) and ``"act"`` / ``"loss"``,
            which are both None when the batch has no ``"ctr"`` entry.
        """
        pred = self(input_dict=batch).to(torch.float32)
        act, loss = None, None
        if "ctr" in batch:
            act = batch["ctr"].to(torch.float32).to(self.device)
            loss = self.MSE(pred, act).to(torch.float32)
        return {"loss": loss, "pred": pred, "act": act}
def get_sentiment_for_list_of_texts(texts: list[str]) -> dict:
    """Score every text with ``get_sentiment`` and collate the scores by name.

    Args:
        texts: Texts to score; each one is passed to ``get_sentiment``.

    Returns:
        Dict mapping each key of the first text's sentiment dict to a float32
        tensor holding that score for every text, in input order. An empty
        ``texts`` yields an empty dict, since there is no row to take the
        score names from.
    """
    per_text = [get_sentiment(text) for text in texts]
    if not per_text:
        return {}
    return {
        name: torch.tensor([scores[name] for scores in per_text], dtype=torch.float32)
        for name in per_text[0]
    }