File size: 463 Bytes
e8adde1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import PretrainedConfig
from typing import List


class MLPConfig(PretrainedConfig):
    """Configuration object holding the layer sizes of an MLP model.

    The three size arguments are stored as same-named instance attributes;
    every other keyword argument is handed to ``PretrainedConfig.__init__``.

    Attributes:
        model_type: Class-level identifier string, ``"mlp"``.
        input_size: Value of the ``input_size`` constructor argument.
        output_size: Value of the ``output_size`` constructor argument.
        hidden_size: Value of the ``hidden_size`` constructor argument.
    """

    model_type = "mlp"

    def __init__(
        self,
        input_size: int = 784,
        output_size: int = 10,
        hidden_size: int = 256,
        **kwargs,
    ):
        """Store the size settings, then initialise the base config.

        Args:
            input_size: Stored as ``self.input_size`` (default 784);
                presumably the width of the model's input layer -- the
                consuming model is not visible in this file.
            output_size: Stored as ``self.output_size`` (default 10);
                presumably the width of the model's output layer.
            hidden_size: Stored as ``self.hidden_size`` (default 256);
                presumably the width of the hidden layer(s).
            **kwargs: Passed through unchanged to
                ``PretrainedConfig.__init__``.
        """
        # The model-specific fields are assigned first and the base-class
        # initialiser runs last; keep this order when editing.
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        super().__init__(**kwargs)

# Import-time side effect: call the `register_for_auto_class` classmethod
# (not defined in this file; presumably inherited from PretrainedConfig) with
# its default arguments.
# NOTE(review): per the transformers docs this ties the custom config to an
# auto class (AutoConfig by default) so it can be shared as custom code on the
# Hub -- confirm against the installed transformers version.
MLPConfig.register_for_auto_class()