File size: 463 Bytes
e8adde1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import PretrainedConfig
from typing import List
class MLPConfig(PretrainedConfig):
    """Configuration holding the layer sizes of an MLP model.

    Only the three size fields below are defined here; any other keyword
    argument is handed unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_size: Size of the model input. Defaults to 784.
        output_size: Size of the model output. Defaults to 10.
        hidden_size: Size of the hidden representation. Defaults to 256.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Identifier for this config type; per transformers convention this is
    # the "model_type" value associated with the config (e.g. in config.json).
    model_type = "mlp"

    def __init__(
        self,
        input_size: int = 784,
        output_size: int = 10,
        hidden_size: int = 256,
        **kwargs,
    ) -> None:
        # Record the custom size fields first, then let the base class
        # process the remaining generic configuration options.
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        super().__init__(**kwargs)
# Register MLPConfig with the "AutoConfig" auto class (the documented default
# for config classes, spelled out here). Per the transformers custom-model
# docs, this lets saved configs reference this custom class; confirm against
# the installed transformers version.
MLPConfig.register_for_auto_class("AutoConfig")
|