|
import tensorflow as tf
|
|
|
|
from transformers.modeling_tf_utils import TFPreTrainedModel
|
|
|
|
from configuration_my_model import MyModelConfig
|
|
|
|
|
|
class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Shared TensorFlow base class for the ``MyModel`` family.

    Its only job here is to associate :class:`MyModelConfig` with every
    subclass through the ``config_class`` attribute.
    """

    # Configuration type for this architecture; presumably consulted by the
    # Transformers loading machinery (e.g. ``from_pretrained``) — see the
    # ``TFPreTrainedModel`` docs.
    config_class = MyModelConfig
|
|
|
|
|
|
class TFMyModel(TFMyModelPretrainedModel):
    """TensorFlow ``MyModel`` holding a single dense layer.

    NOTE(review): no ``call`` method is defined, so the model cannot yet be
    invoked on inputs — presumably still a stub; confirm before use.
    """

    def __init__(self, config: MyModelConfig, *inputs, **kwargs):
        """Build the model's layers from ``config``.

        Args:
            config: Model configuration; ``n_layers`` and ``hidden_dim`` are
                read from it.
            *inputs: Extra positional arguments, forwarded unchanged to
                ``TFPreTrainedModel.__init__`` (and from there to Keras).
            **kwargs: Extra keyword arguments (e.g. ``name``), forwarded the
                same way so ``from_pretrained(..., **model_kwargs)`` works.
        """
        # TFPreTrainedModel.__init__ stores ``config`` as ``self.config``, so
        # it is not reassigned here.
        super().__init__(config, *inputs, **kwargs)

        # Cache frequently used hyper-parameters as plain attributes.
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim

        # Give the layer an explicit name. Keras auto-names ("dense",
        # "dense_1", ...) depend on how many Dense layers were created earlier
        # in the process, and Transformers' TF weight loading matches layers by
        # name, so an auto-name can silently prevent weights from loading.
        # NOTE(review): units=config.n_layers is unusual for a projection
        # (config.hidden_dim may have been intended) — confirm.
        self.linear = tf.keras.layers.Dense(units=config.n_layers, name="linear")
|
|
|
|
|
|
# Create a default configuration and a fresh (untrained) model from it.
config = MyModelConfig()
model = TFMyModel(config=config)

# Display the model object, then write it out to the "my_model" directory.
print(model)
model.save_pretrained(save_directory="my_model")