import os

from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel
from pytorch_lightning.loggers import WandbLogger

from src.regression.PL import FullModelPL, EncoderPL, DecoderPL
from src.regression.HF.configs import FullModelConfigHF
from config import DEVICE


class FullModelHF(PreTrainedModel):
    """HuggingFace-compatible wrapper around the project's ``FullModelPL``.

    On construction this loads the tokenizer and a masked-LM BERT from the
    checkpoints named in ``config``, downloads the decoder checkpoint as a
    W&B artifact, and assembles the full encoder/decoder model on ``DEVICE``.

    NOTE(review): ``__init__`` performs network I/O (W&B artifact download)
    and instantiates a ``WandbLogger`` as a side effect — confirm this is
    intended for every model instantiation.
    """

    # Tells transformers' from_pretrained machinery which config class pairs
    # with this model.
    config_class = FullModelConfigHF

    def __init__(self, config):
        """Build the full model from the checkpoints referenced in ``config``.

        Args:
            config: A ``FullModelConfigHF`` providing ``tokenizer_ckpt``,
                ``bert_ckpt``, ``decoder_ckpt`` (W&B artifact name),
                ``layer_norm`` and ``nontext_features``.
        """
        super().__init__(config)

        # Tokenizer and masked-LM backbone come from HF checkpoints; only the
        # underlying distilbert encoder of the MLM model is kept.
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_ckpt)
        mlm_bert = AutoModelForMaskedLM.from_pretrained(config.bert_ckpt)
        self.bert = mlm_bert.distilbert

        encoder = EncoderPL(tokenizer=self.tokenizer, bert=self.bert).to(DEVICE)

        # The logger is used only to resolve and download the decoder artifact.
        wandb_logger = WandbLogger(
            project="transformers",
            entity="sanjin_juric_fot",
        )
        artifact = wandb_logger.use_artifact(config.decoder_ckpt)
        artifact_dir = artifact.download()
        # os.path.join instead of manual "/" concatenation — robust to
        # trailing separators in the artifact directory.
        decoder = DecoderPL.load_from_checkpoint(
            os.path.join(artifact_dir, "model.ckpt")
        ).to(DEVICE)

        self.model = FullModelPL(
            encoder=encoder,
            decoder=decoder,
            layer_norm=config.layer_norm,
            nontext_features=config.nontext_features,
        ).to(DEVICE)

    def forward(self, input):
        """Delegate to the wrapped model's loss computation.

        NOTE(review): the parameter name ``input`` shadows the builtin; it is
        kept unchanged so existing keyword callers remain compatible.
        Returns whatever ``FullModelPL._get_loss`` returns for ``input``.
        """
        return self.model._get_loss(input)