import pytorch_lightning as pl
import torch
from transformers import AutoModel, AutoTokenizer, BertModel, BertTokenizer, DistilBertModel

from src.utils import add_emoji_tokens, add_new_line_token, vectorise_dict
from config import DEVICE

torch.set_default_dtype(torch.float32)


class EncoderPL(pl.LightningModule):
    """Wraps a (Distil)BERT encoder and tokenizer to produce text embeddings."""

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        tokenizer: AutoTokenizer | None = None,
        bert: AutoModel | None = None,
        cls: bool = False,
        device=DEVICE,
        layer_dict: dict | None = None,
    ):
        super().__init__()

        self._device = device
        self.cls = cls
        self.model_name = model_name
        # accepted for compatibility with get_bert_embedding / get_concat_embedding;
        # stored but not applied here
        self.layer_dict = layer_dict or {}

        # layers
        self.tokenizer = tokenizer if tokenizer is not None else BertTokenizer.from_pretrained(model_name)
        self.bert = bert if bert is not None else BertModel.from_pretrained(model_name)

        # when the default tokenizer is used, extend its vocabulary with emoji and
        # new-line tokens and resize the embedding matrix to match
        if tokenizer is None:
            self.tokenizer = add_emoji_tokens(self.tokenizer)
            self.tokenizer = add_new_line_token(self.tokenizer)
            self.bert.resize_token_embeddings(len(self.tokenizer))

        # optimizer (only trainable parameters)
        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)

        # config tweaking
        self.bert.config.torch_dtype = "float32"

    def forward(self, text: str):
        # tokenise the text, padding/truncating to the model's maximum sequence length
        encoded = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True).to(self._device)

        # DistilBERT does not use token type ids
        if isinstance(self.bert, DistilBertModel):
            encoded.pop("token_type_ids", None)

        bert_output = self.bert(**encoded)

        if self.cls:
            # prefer the pooled [CLS] representation; fall back to the first token
            # of the last hidden state for models without a pooler
            if getattr(bert_output, "pooler_output", None) is not None:
                embedding = bert_output.pooler_output.unsqueeze(dim=1)
            else:
                embedding = bert_output.last_hidden_state[:, 0, :].unsqueeze(dim=1)
        else:
            last_hidden_state = bert_output.last_hidden_state

            if last_hidden_state.dim() == 2:
                last_hidden_state = last_hidden_state.unsqueeze(dim=0)

            # mask-weighted sum over tokens: (batch, 1, seq_len) @ (batch, seq_len, hidden)
            embedding = torch.matmul(
                encoded["attention_mask"].type(torch.float32).unsqueeze(dim=1),
                last_hidden_state,
            )

        return embedding

    def configure_optimizers(self):
        return self.optimizer


def get_bert_embedding(
    text: str, as_list: bool = True, cls: bool = False, device=DEVICE, layer_dict: dict | None = None
) -> list | torch.Tensor:
    """Embed a single text with a freshly initialised EncoderPL."""
    encoder = EncoderPL(cls=cls, layer_dict=layer_dict).to(device)
    embedding = encoder.forward(text)

    if as_list:
        # embedding has shape (1, 1, hidden); flatten it to a plain list of floats
        embedding = embedding.tolist()[0][0]

    return embedding


def get_concat_embedding(
    text: str | None = None,
    bert_embedding: list | None = None,
    other_features: dict | None = None,
    cls: bool = False,
    device=DEVICE,
    layer_dict: dict | None = None,
) -> list:
    """Concatenate a BERT embedding with a vectorised dictionary of extra features."""

    if not bert_embedding:
        if text is None:
            raise ValueError("both text and bert_embedding are empty!")
        bert_embedding = get_bert_embedding(text=text, cls=cls, device=device, layer_dict=layer_dict)

    other_features = vectorise_dict(other_features or {}, as_list=True)

    concat_vec = bert_embedding + other_features

    return concat_vec
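

# Illustrative usage sketch, not part of the module's public API: it assumes the
# "bert-base-uncased" weights can be downloaded and that src.utils.vectorise_dict
# turns a flat dict of numeric features into a list. The feature names below are
# hypothetical examples.
if __name__ == "__main__":
    sample_text = "hello world 🌍"

    # mask-weighted embedding as a plain list of floats
    vec = get_bert_embedding(sample_text, as_list=True, cls=False)
    print(f"embedding length: {len(vec)}")

    # embedding concatenated with extra hand-crafted features
    concat = get_concat_embedding(
        text=sample_text,
        other_features={"num_tokens": 4, "has_emoji": 1},
        cls=True,
    )
    print(f"concatenated length: {len(concat)}")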