from pytorch_lightning.loggers import WandbLogger
from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel

from config import DEVICE
from src.regression.HF.configs import FullModelConfigHF
from src.regression.PL import DecoderPL, EncoderPL, FullModelPL


class FullModelHF(PreTrainedModel):
    """Hugging Face wrapper that assembles the Lightning FullModelPL regressor."""

    config_class = FullModelConfigHF

    def __init__(self, config):
        super().__init__(config)

        # Load the tokenizer and the DistilBERT backbone from the pretrained
        # masked-language-modelling checkpoint; only the base encoder is kept.
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_ckpt)
        mlm_bert = AutoModelForMaskedLM.from_pretrained(config.bert_ckpt)
        self.bert = mlm_bert.distilbert

        encoder = EncoderPL(tokenizer=self.tokenizer, bert=self.bert).to(DEVICE)

        # Pull the trained decoder checkpoint down from Weights & Biases and
        # restore it from the downloaded artifact directory.
        wandb_logger = WandbLogger(
            project="transformers",
            entity="sanjin_juric_fot",
        )
        artifact = wandb_logger.use_artifact(config.decoder_ckpt)
        artifact_dir = artifact.download()
        decoder = DecoderPL.load_from_checkpoint(artifact_dir + "/model.ckpt").to(DEVICE)

        # Combine the pretrained encoder and the restored decoder into the
        # full regression model.
        self.model = FullModelPL(
            encoder=encoder,
            decoder=decoder,
            layer_norm=config.layer_norm,
            nontext_features=config.nontext_features,
        ).to(DEVICE)

    def forward(self, input):
        # Delegate to the Lightning module's loss computation, so calling the
        # wrapper on a batch returns the training loss.
        return self.model._get_loss(input)
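
# Usage sketch (hypothetical checkpoint names and values; the real ones live
# in the project's configs): build a FullModelConfigHF and wrap it.
#
#     config = FullModelConfigHF(
#         tokenizer_ckpt="distilbert-base-uncased",
#         bert_ckpt="distilbert-base-uncased",
#         decoder_ckpt="sanjin_juric_fot/transformers/decoder:latest",
#         layer_norm=True,
#         nontext_features=8,
#     )
#     model = FullModelHF(config)
#     loss = model(batch)  # forwards the batch to FullModelPL._get_loss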