import torch
import pytorch_lightning as pl
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DistilBertModel

from src.utils import add_emoji_tokens, add_new_line_token, vectorise_dict
from src.utils.neural_networks import set_layer
from config import DEVICE

torch.set_default_dtype(torch.float32)


class EncoderPL(pl.LightningModule):
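    """Lightning wrapper around a BERT-style encoder; forward() turns raw text into a (batch, 1, hidden) embedding."""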

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        tokenizer: AutoTokenizer | None = None,
        bert: AutoModel | None = None,
        cls: bool = False,
        device=DEVICE,
        layer_dict: dict | None = None,
    ):
        super().__init__()

        self._device = device
        self.cls = cls
        self.model_name = model_name

        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(model_name)
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)

        # Only extend the vocabulary when the tokenizer was created here;
        # a caller-supplied tokenizer is assumed to be final.
        if tokenizer is None:
            self.tokenizer = add_emoji_tokens(self.tokenizer)
            self.tokenizer = add_new_line_token(self.tokenizer)
            self.bert.resize_token_embeddings(len(self.tokenizer))

        # Optional per-layer configuration (e.g. freezing), applied before the
        # optimizer is built so frozen parameters are excluded from it. Assumes
        # the project helper takes (model, layer_dict) and returns the model.
        if layer_dict:
            self.bert = set_layer(self.bert, layer_dict)

        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)

        self.bert.config.torch_dtype = "float32"

    def forward(self, text: str | list[str]):
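        """Tokenise `text` and return either the [CLS] embedding or a mask-weighted sum of hidden states."""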
        encoded = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True).to(self._device)

        # DistilBERT has no segment embeddings and rejects token_type_ids.
        if isinstance(self.bert, DistilBertModel):
            encoded.pop("token_type_ids", None)

        bert_output = self.bert(**encoded)

        if self.cls:
            if getattr(bert_output, "pooler_output", None) is not None:
                embedding = bert_output.pooler_output.unsqueeze(dim=1)
            else:
                # Fall back to the raw [CLS] hidden state for models without a pooler.
                embedding = bert_output.last_hidden_state[:, 0, :].unsqueeze(dim=1)
        else:
            last_hidden_state = bert_output.last_hidden_state

            if last_hidden_state.dim() == 2:
                last_hidden_state = last_hidden_state.unsqueeze(dim=0)

            # Sum the hidden states of the non-padding tokens:
            # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden).
            embedding = torch.matmul(
                encoded["attention_mask"].to(torch.float32).unsqueeze(dim=1),
                last_hidden_state,
            )

        return embedding

    def configure_optimizers(self):
        return self.optimizer


def get_bert_embedding(
    text: str, as_list: bool = True, cls: bool = False, device=DEVICE, layer_dict: dict | None = None
) -> list:
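    """Embed `text` with a freshly built EncoderPL; return a flat list when `as_list`, else the raw tensor."""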
    encoder = EncoderPL(cls=cls, layer_dict=layer_dict).to(device)
    embedding = encoder(text)

    if as_list:
        # Unwrap the (1, 1, hidden) tensor into a flat Python list.
        embedding = embedding.tolist()[0][0]

    return embedding


def get_concat_embedding(
    text: str | None = None,
    bert_embedding: list | None = None,
    other_features: dict | None = None,
    cls: bool = False,
    device=DEVICE,
    layer_dict: dict | None = None,
) -> list:
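    """Concatenate a BERT embedding with the vectorised `other_features` into one flat list."""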
    if not bert_embedding:
        if text is None:
            raise ValueError("both text and embedding are empty!")
        bert_embedding = get_bert_embedding(text=text, cls=cls, device=device, layer_dict=layer_dict)

    other_features = vectorise_dict(other_features or {}, as_list=True)

    concat_vec = bert_embedding + other_features

    return concat_vec
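

if __name__ == "__main__":
    # Minimal usage sketch. Assumes the "bert-base-uncased" weights are available
    # (locally cached or downloadable) and that config.DEVICE names a valid device;
    # the feature keys below are hypothetical, purely for illustration.
    vec = get_bert_embedding("Hello world 🙂", as_list=True)
    print(len(vec))  # hidden size, e.g. 768 for bert-base-uncased

    combined = get_concat_embedding(
        bert_embedding=vec,
        other_features={"n_words": 2, "has_emoji": 1},
    )
    print(len(combined))  # hidden size plus the number of vectorised features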