| import tensorflow as tf | |
| from transformers.modeling_tf_utils import TFPreTrainedModel | |
| from configuration_my_model import MyModelConfig | |
class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Common TensorFlow base class for the MyModel family.

    Its only job here is to bind ``MyModelConfig`` as the ``config_class``.
    The inherited ``TFPreTrainedModel`` save/load machinery presumably uses
    it to build configurations for subclasses.
    """

    # Configuration type associated with every model deriving from this base.
    config_class = MyModelConfig
class TFMyModel(TFMyModelPretrainedModel):
    """TensorFlow MyModel holding a single dense projection layer.

    Args:
        config: Model configuration. Its ``n_layers`` and ``hidden_dim``
            fields are copied onto the instance.
        *inputs: Extra positional arguments, forwarded unchanged to
            ``TFPreTrainedModel.__init__`` (``from_pretrained`` passes its
            ``model_args`` this way).
        **kwargs: Extra keyword arguments such as ``name``, forwarded
            unchanged to ``TFPreTrainedModel.__init__``.
    """

    def __init__(self, config: MyModelConfig, *inputs, **kwargs):
        # Forward extra arguments instead of rejecting them. Otherwise
        # ``from_pretrained(..., name=...)`` fails with TypeError here.
        super().__init__(config, *inputs, **kwargs)
        self.config = config
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim
        # Give the layer an explicit name. Keras auto-generated names
        # ("dense", "dense_1", ...) depend on how many layers already exist
        # in the session, which makes saved weight names unstable between
        # save_pretrained / from_pretrained round trips.
        # NOTE(review): units is config.n_layers rather than hidden_dim;
        # confirm this is intended.
        # NOTE(review): this class defines no ``call`` method, so the model
        # cannot be run and this layer is presumably never built (no weights).
        self.linear = tf.keras.layers.Dense(units=config.n_layers, name="linear")
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse TFMyModel) does
    # not build a model and write files to disk as a side effect.
    config = MyModelConfig()
    model = TFMyModel(config)
    print(model)
    # Save the model into the local "my_model" directory.
    model.save_pretrained("my_model")