# src/regression/PL/EncoderPL.py
import pytorch_lightning as pl
import torch
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DistilBertModel

from src.utils import add_emoji_tokens, add_new_line_token, vectorise_dict
from config import DEVICE

torch.set_default_dtype(torch.float32)

class EncoderPL(pl.LightningModule):
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        tokenizer: AutoTokenizer | None = None,
        bert: AutoModel | None = None,
        cls: bool = False,
        device=DEVICE,
        layer_dict: dict | None = None,
    ):
        super().__init__()
        self._device = device
        self.cls = cls
        self.model_name = model_name
        # layer_dict is accepted because the helper functions below pass it;
        # it is stored as-is and not used to modify any layers in this class
        self.layer_dict = layer_dict or {}
        # layers: fall back to the default BERT tokenizer/model when none are injected
        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(model_name)
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)
        if tokenizer is None:
            # extend the vocabulary with emoji and newline tokens, then resize the embeddings to match
            self.tokenizer = add_emoji_tokens(self.tokenizer)
            self.tokenizer = add_new_line_token(self.tokenizer)
            self.bert.resize_token_embeddings(len(self.tokenizer))
        # optimizer over trainable parameters only
        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)
        # config tweaking
        self.bert.config.torch_dtype = "float32"
    def forward(self, text: str):
        # tokenize the text, run it through BERT, and pool the hidden states into one embedding
        encoded = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True).to(self._device)
        if isinstance(self.bert, DistilBertModel):
            # DistilBERT does not accept token_type_ids
            encoded.pop("token_type_ids", None)
        bert_output = self.bert(**encoded)
        if self.cls:
            # use the pooled [CLS] output when the model provides one, otherwise the raw [CLS] hidden state
            if getattr(bert_output, "pooler_output", None) is not None:
                embedding = bert_output.pooler_output.unsqueeze(dim=1)
            else:
                embedding = bert_output.last_hidden_state[0, 0, :].unsqueeze(dim=0).unsqueeze(dim=0)
        else:
            # attention-mask-weighted sum of the token embeddings:
            # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, 1, hidden)
            last_hidden_state = bert_output.last_hidden_state
            if last_hidden_state.dim() == 2:
                last_hidden_state = last_hidden_state.unsqueeze(dim=0)
            embedding = torch.matmul(
                encoded["attention_mask"].type(torch.float32).unsqueeze(dim=1),
                last_hidden_state,
            )
        return embedding

    def configure_optimizers(self):
        return self.optimizer
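
# Usage sketch (illustrative only, not part of the original upload): embed one string with
# the [CLS] pooling head. "some text" is a placeholder input; the default
# "bert-base-uncased" weights are downloaded on first use.
#
#     encoder = EncoderPL(cls=True).to(DEVICE)
#     with torch.no_grad():
#         emb = encoder("some text")  # tensor of shape (1, 1, hidden_size)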


def get_bert_embedding(
    text: str, as_list: bool = True, cls: bool = False, device=DEVICE, layer_dict: dict | None = None
) -> list:
    # build a fresh encoder, embed the text, and optionally flatten the result to a plain list
    encoder = EncoderPL(cls=cls, layer_dict=layer_dict).to(device)
    embedding = encoder.forward(text)
    if as_list:
        embedding = embedding.tolist()[0][0]
    return embedding
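
# Example (sketch): get_bert_embedding("a short review") returns a plain Python list of
# hidden_size floats (the mask-weighted sum of token embeddings); pass as_list=False to
# keep the raw torch tensor, or cls=True to use the pooled [CLS] representation instead.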


def get_concat_embedding(
    text: str | None = None,
    bert_embedding: list | None = None,
    other_features: dict | None = None,
    cls: bool = False,
    device=DEVICE,
    layer_dict: dict | None = None,
) -> list:
    # concatenate a BERT text embedding with vectorised hand-crafted features
    if not bert_embedding:
        if text is None:
            raise ValueError("both text and bert_embedding are empty!")
        bert_embedding = get_bert_embedding(text=text, cls=cls, device=device, layer_dict=layer_dict)
    other_features = vectorise_dict(other_features or {}, as_list=True)
    concat_vec = bert_embedding + other_features
    return concat_vec
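

if __name__ == "__main__":
    # Minimal end-to-end sketch added for illustration; the demo text and feature dict are
    # placeholder inputs, and the exact output of vectorise_dict depends on src.utils.
    demo_text = "great product, would buy again"
    demo_features = {"num_words": 5, "has_emoji": 0}
    vec = get_concat_embedding(text=demo_text, other_features=demo_features, cls=True)
    print(f"concatenated feature vector length: {len(vec)}")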