# ctr-ll4 / FullModelConfigHF.py
# Uploaded by sanjin7 ("Upload model", commit 82c0c38).
from typing import List, Optional

from transformers import PretrainedConfig

from src.regression.PL import EncoderPL, DecoderPL
class FullModelConfigHF(PretrainedConfig):
    """Hugging Face configuration object for the "full_model" architecture.

    Stores the settings passed to the constructor as attributes; any extra
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        tokenizer_ckpt: Tokenizer checkpoint identifier (presumably a hub id
            or local path — confirm against the model code that reads it).
        bert_ckpt: Encoder (BERT) checkpoint identifier; same caveat as above.
        decoder_ckpt: Decoder checkpoint identifier; same caveat as above.
        layer_norm: Boolean flag stored as ``self.layer_norm``.
        nontext_features: Names of non-text input features. Defaults to
            ``["aov"]``; a fresh list is created on every call.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier under which this config type is known to transformers.
    model_type = "full_model"

    def __init__(
        self,
        tokenizer_ckpt: str = "",
        bert_ckpt: str = "",
        decoder_ckpt: str = "",
        layer_norm: bool = True,
        nontext_features: Optional[List[str]] = None,
        **kwargs,
    ):
        self.tokenizer_ckpt = tokenizer_ckpt
        self.bert_ckpt = bert_ckpt
        self.decoder_ckpt = decoder_ckpt
        # Build the default here rather than in the signature: a list literal
        # default is evaluated once and would be shared (and mutable) across
        # every config instance created without this argument.
        if nontext_features is None:
            nontext_features = ["aov"]
        self.nontext_features = nontext_features
        self.layer_norm = layer_norm
        super().__init__(**kwargs)