inwaves committed
Commit: 4e1467d
Parent: dc40649

Skeleton for classes

Files changed (5)
  1. README.MD +3 -1
  2. main.py +23 -0
  3. model.py +46 -0
  4. requirements.txt +12 -0
  5. utils.py +8 -0
README.MD CHANGED
@@ -1 +1,3 @@
- # Open-source Softmax Linear Unit
+ # Open-source Softmax Linear Unit
+
+ Replicating the results of the paper [Softmax Linear Units](https://transformer-circuits.pub/2022/solu/index.html), recently published by Anthropic.
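
For reference, the SoLU activation from the linked paper multiplies the MLP's pre-activation elementwise by its own softmax, and the paper then applies a LayerNorm to the result. A minimal PyTorch sketch, not part of this commit (the module name and the choice of `dim=-1` are assumptions):

```python
import torch as t
import torch.nn as nn

class SoLU(nn.Module):
    """Softmax Linear Unit: multiply the input elementwise by its softmax."""
    def forward(self, x: t.Tensor) -> t.Tensor:
        # Softmax is taken over the hidden (last) dimension.
        return x * t.softmax(x, dim=-1)

# In the paper, SoLU replaces GELU inside the MLP and is followed by a LayerNorm, roughly:
# nn.Sequential(nn.Linear(d_model, d_mlp), SoLU(), nn.LayerNorm(d_mlp), nn.Linear(d_mlp, d_model))
```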
main.py ADDED
@@ -0,0 +1,23 @@
+ import torch as t
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+
+
+ def parse_args():
+     # TODO: command-line args for hparams
+     pass
+
+ def train():
+     # TODO: training loop
+     pass
+
+ def eval():
+     pass
+
+ def setup():
+     # TODO: wandb logging, load configs, all that stuff
+     pass
+
+ if __name__ == "__main__":
+     parse_args()
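
To make the intent of these stubs concrete, here is one way the training loop might eventually be filled in. This is a sketch only: the Adam optimiser, the `(tokens, targets)` batch layout, and the language-modelling cross-entropy loss are assumptions rather than anything this commit specifies; only `config.learning_rate` corresponds to a field in `OsSoluConfig`.

```python
import torch.nn.functional as F
import torch.optim as optim

def train(model, dataloader, config):
    optimiser = optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()
    for tokens, targets in dataloader:          # assumed (input, target) batches
        logits = model(tokens)                  # assumed shape (batch, seq, vocab_size)
        loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                               targets.reshape(-1))
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
```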
model.py ADDED
@@ -0,0 +1,46 @@
+ import torch as t
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import wandb
+ import fancy_einsum
+ from einops import rearrange, repeat, reduce
+
+
+ class OsSoluModel(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.config = config
+         self.transformer_block = TransformerBlock(config)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         pass
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.config = config
+
+         # Embed,
+         self.embed = nn.Embedding(config.num_embeddings, config.d_model)
+         # One MLP, one attention,
+         # one layernorm, one dropout (?)
+         # Unembed
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         pass
+
+
+ class RotaryAttention(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+
+     def forward(self, x: t.Tensor, attention_mask: t.Tensor) -> t.Tensor:
+         # Compute pre-softmax attention scores
+         # Apply attention mask
+         # Compute softmax
+         # Apply final einsum
+         # Return attention output
+
+         pass
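
The comment outline in `RotaryAttention.forward` maps naturally onto an einsum-style implementation. The sketch below is illustrative only: the `q_proj`/`k_proj`/`v_proj` layers and `num_heads` attribute do not exist in this commit, the convention that 0 in `attention_mask` means "masked" is an assumption, and the rotary position embedding that gives the class its name is omitted.

```python
import torch as t
from einops import rearrange
from fancy_einsum import einsum

def attention_forward(self, x: t.Tensor, attention_mask: t.Tensor) -> t.Tensor:
    # Project to per-head queries, keys and values (hypothetical q_proj/k_proj/v_proj layers).
    q = rearrange(self.q_proj(x), "b s (h d) -> b h s d", h=self.num_heads)
    k = rearrange(self.k_proj(x), "b s (h d) -> b h s d", h=self.num_heads)
    v = rearrange(self.v_proj(x), "b s (h d) -> b h s d", h=self.num_heads)

    # Compute pre-softmax attention scores, scaled by sqrt(d_head).
    scores = einsum("b h q d, b h k d -> b h q k", q, k) / q.shape[-1] ** 0.5
    # Apply attention mask (assumed convention: 0 = masked position).
    scores = scores.masked_fill(attention_mask == 0, float("-inf"))
    # Compute softmax over the key dimension.
    pattern = scores.softmax(dim=-1)
    # Apply final einsum and merge heads back into d_model.
    out = einsum("b h q k, b h k d -> b h q d", pattern, v)
    return rearrange(out, "b h s d -> b s (h d)")
```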
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ wandb
+ einops
+ fancy_einsum
+ tqdm
+ ipykernel
+ notebook
+ ipywidgets
+ jupyter
+ matplotlib
+ numpy-stl
+ plotly
utils.py ADDED
@@ -0,0 +1,8 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class OsSoluConfig:
+     d_model: int = 512
+     vocab_size: int = 65536  # Unsure about this.
+     learning_rate: float = 1e-3
+     num_embeddings: int = 1024  # Unsure about this.
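
For context, the config is presumably constructed once and handed to the model, along the lines below. This wiring is an assumption based on the class and module names in this commit; it does not exist in the code yet.

```python
from utils import OsSoluConfig
from model import OsSoluModel

config = OsSoluConfig()      # fields default to the values above; override as needed
model = OsSoluModel(config)
```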