File size: 322 Bytes
8b17fbc |
1 2 3 4 5 6 7 8 9 10 |
from transformers import PretrainedConfig
# Define the configuration class
class SimpleNNConfig(PretrainedConfig):
    """Configuration for the ``simple_nn`` model.

    Stores the two hyperparameters the model needs (input width and number
    of output classes). All other keyword arguments are forwarded unchanged
    to ``PretrainedConfig.__init__``.

    Attributes:
        input_size: Number of input features. The default of 784 presumably
            corresponds to a flattened 28x28 image (e.g. MNIST); confirm
            against the model that consumes this config.
        num_classes: Number of output classes.
    """

    # Identifier for this config type; transformers uses ``model_type`` to
    # associate a config class with its model architecture.
    model_type: str = "simple_nn"

    def __init__(self, input_size: int = 784, num_classes: int = 10, **kwargs):
        """Create the config.

        Args:
            input_size: Number of input features.
            num_classes: Number of output classes.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.input_size = input_size
        self.num_classes = num_classes