juliagsy commited on
Commit
e2328d8
1 Parent(s): 94ef816

Create __main__.py

Browse files
Files changed (1) hide show
  1. linear/__main__.py +30 -0
linear/__main__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ import torch
3
+
4
+
5
+ class LinearConfig(PretrainedConfig):
6
+ model_type = "linear"
7
+
8
+ def __init__(
9
+ self,
10
+ in_dims=768,
11
+ out_dims=1024,
12
+ **kwargs,
13
+ ):
14
+ self.in_dims = in_dims
15
+ self.out_dims = out_dims
16
+ super().__init__(**kwargs)
17
+
18
+
19
+ class LinearModel(PreTrainedModel):
20
+ config_class = LinearConfig
21
+
22
+ def __init__(self, config):
23
+ super().__init__(config)
24
+ self.model = torch.nn.Linear(
25
+ config.in_dims,
26
+ config.out_dims,
27
+ )
28
+
29
+ def forward(self, x):
30
+ return self.model.forward(x)