import pytorch_lightning as pl
import torch
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DistilBertModel

from src.utils import add_emoji_tokens, add_new_line_token, vectorise_dict
from config import DEVICE

torch.set_default_dtype(torch.float32)


class EncoderPL(pl.LightningModule):
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        tokenizer: AutoTokenizer | None = None,
        bert: AutoModel | None = None,
        cls: bool = False,
        layer_dict: dict | None = None,
        device=DEVICE,
    ):
        super().__init__()
        self._device = device
        self.cls = cls
        self.model_name = model_name
        # optional per-layer configuration; accepted so callers can pass it, but not applied in this module
        self.layer_dict = layer_dict or {}

        # layers
        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(model_name)
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)
        if tokenizer is None:
            # extend the default vocabulary with emoji and newline tokens,
            # then resize the embedding matrix to match the new tokenizer size
            self.tokenizer = add_emoji_tokens(self.tokenizer)
            self.tokenizer = add_new_line_token(self.tokenizer)
            self.bert.resize_token_embeddings(len(self.tokenizer))

        # optimizer
        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)

        # config tweaking
        self.bert.config.torch_dtype = "float32"

    def forward(self, text: str):
        # run text through BERT and pool the output to get a single embedding
        encoded = self.tokenizer(
            text, return_tensors="pt", padding="max_length", truncation=True
        ).to(self._device)
        if isinstance(self.bert, DistilBertModel):
            # DistilBERT does not accept token type ids
            encoded.pop("token_type_ids", None)
        bert_output = self.bert(**encoded)

        if self.cls:
            # use the pooled [CLS] representation when available,
            # otherwise fall back to the first token of the last hidden state
            if getattr(bert_output, "pooler_output", None) is not None:
                embedding = bert_output.pooler_output.unsqueeze(dim=1)
            else:
                embedding = bert_output.last_hidden_state[0, 0, :].unsqueeze(dim=0).unsqueeze(dim=0)
        else:
            # sum the token embeddings, masking out padding positions;
            # the view assumes the tokenizer pads to a 512-token max length
            last_hidden_state = bert_output.last_hidden_state
            if last_hidden_state.dim() == 2:
                last_hidden_state = last_hidden_state.unsqueeze(dim=0)
            embedding = torch.matmul(
                encoded["attention_mask"].type(torch.float32).view(-1, 1, 512),
                last_hidden_state,
            )
        return embedding

    def configure_optimizers(self):
        return self.optimizer


def get_bert_embedding(
    text: str, as_list: bool = True, cls: bool = False, device=DEVICE, layer_dict: dict | None = None
) -> list:
    encoder = EncoderPL(cls=cls, layer_dict=layer_dict).to(device)
    embedding = encoder.forward(text)
    if as_list:
        embedding = embedding.tolist()[0][0]
    return embedding


def get_concat_embedding(
    text: str | None = None,
    bert_embedding: list | None = None,
    other_features: dict | None = None,
    cls: bool = False,
    device=DEVICE,
    layer_dict: dict | None = None,
) -> list:
    if not bert_embedding:
        if text is None:
            raise ValueError("both text and embedding are empty!")
        bert_embedding = get_bert_embedding(text=text, cls=cls, device=device, layer_dict=layer_dict)
    other_features = vectorise_dict(other_features or {}, as_list=True)
    return bert_embedding + other_features
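

# Usage sketch (illustrative, not part of the original module): shows how a caller
# might combine a BERT text embedding with hand-crafted features. The feature names
# below ("follower_count", "post_length") are made up for the example; vectorise_dict
# is assumed to flatten the dict into a list of floats that can be concatenated.
if __name__ == "__main__":
    sample_text = "hello world \n 😀"
    extra_features = {"follower_count": 120, "post_length": 17}

    # plain BERT embedding (masked sum over token embeddings by default)
    bert_vec = get_bert_embedding(sample_text, cls=False)

    # BERT embedding concatenated with the extra features
    concat_vec = get_concat_embedding(text=sample_text, other_features=extra_features, cls=False)
    print(len(bert_vec), len(concat_vec))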