from typing import List, Optional

from transformers import PretrainedConfig

from src.regression.PL import EncoderPL, DecoderPL
class FullModelConfigHF(PretrainedConfig):
    """Hugging Face configuration for the ``"full_model"`` architecture.

    Every constructor argument is stored as an attribute of the same name,
    so it is included when the config is saved and reloaded. Any extra
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        tokenizer_ckpt: Tokenizer checkpoint identifier (presumably a hub
            name or local path -- confirm against the model-loading code).
        bert_ckpt: Encoder ("bert") checkpoint identifier.
        decoder_ckpt: Decoder checkpoint identifier.
        layer_norm: Boolean flag stored as ``self.layer_norm``; its effect
            is defined by whichever model consumes this config.
        nontext_features: Names of the non-text features. When omitted or
            ``None`` this is ``["aov"]``. A new list is stored on every
            instance, so mutating one config never affects another.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "full_model"

    def __init__(
        self,
        tokenizer_ckpt: str = "",
        bert_ckpt: str = "",
        decoder_ckpt: str = "",
        layer_norm: bool = True,
        nontext_features: Optional[List[str]] = None,
        **kwargs,
    ):
        self.tokenizer_ckpt = tokenizer_ckpt
        self.bert_ckpt = bert_ckpt
        self.decoder_ckpt = decoder_ckpt
        # A list literal used as a default value is evaluated once and shared
        # by every call. Build a fresh default per instance, and copy any
        # caller-supplied sequence so later edits on either side don't leak.
        self.nontext_features = (
            ["aov"] if nontext_features is None else list(nontext_features)
        )
        self.layer_norm = layer_norm
        # Custom attributes are assigned first; the base initializer then
        # handles the generic config kwargs.
        super().__init__(**kwargs)