Refactor config class, add argparser
.gitignore
ADDED
@@ -0,0 +1 @@
+*__pycache__/
main.py
CHANGED
@@ -2,22 +2,43 @@ import torch as t
 import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
+import argparse
+from utils import OsSoluConfig
+from model import OsSoluModel
+from typing import Tuple
 
-
-def parse_args():
+def parse_arguments() -> argparse.Namespace:
     # TODO: command-line args for hparams
-
+    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
+    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
+    parser.add_argument("--vocab_size", type=int, default=65536, help="Vocabulary size of the input sequence.")
+    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
+    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
+    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
+    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
+    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
+    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
+    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
+    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
+    args = parser.parse_args()
+    return args
 
-def train():
+def train(config: OsSoluConfig, model: OsSoluModel) -> OsSoluModel:
     # TODO: training loop
-
+
+    return model
 
 def eval():
     pass
 
-def setup():
-    # TODO: wandb logging
-
+def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
+    # TODO: wandb logging
+    args = parse_arguments()
+    config = OsSoluConfig(args)
+    model = OsSoluModel(config)
+    return config, model
 
 if __name__=="__main__":
-
+    config, model = setup()
+    trained_model = train(config, model)
+    eval()
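A quick aside on the pattern `parse_arguments` uses: every hyperparameter becomes a typed flag with a default, so running the script with no flags still yields a fully populated namespace. The sketch below is not the project's parser, only a minimal illustration with two of the flags above; note that `parse_args` also accepts an explicit argument list, which makes the parser easy to exercise without a real command line.

# Minimal sketch of the argparse pattern used in parse_arguments (illustrative, not the repo's code).
import argparse

parser = argparse.ArgumentParser(description="Toy parser mirroring two of the flags above.")
parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")

args = parser.parse_args(["--num_blocks", "4"])   # simulates `python main.py --num_blocks 4`
print(args.d_model, args.num_blocks)              # 512 4  (d_model falls back to its default)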
model.py
CHANGED
@@ -15,7 +15,8 @@ class OsSoluModel(nn.Module):
         self.config = config
         self.embed_positions = nn.Embedding(config.max_positional_embeddings, config.d_model)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
-        self.
+        self.dropout = nn.Dropout(config.dropout)
+        self.transformer_blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.num_blocks)])
         self.final_ln = nn.LayerNorm(normalized_shape, config.ln_eps)
         self.unembed = nn
 
@@ -23,23 +24,36 @@ class OsSoluModel(nn.Module):
         positional_embeddings = self.embed_positions(t.arange(x.size(1)))
         token_embeddings = self.embed_tokens(x)
         embeddings = positional_embeddings + token_embeddings
+        out = self.dropout(embeddings)
+        out = self.transformer_blocks(out)
 
+class SoLU(nn.Module):
+    def __init__(self):
+        pass
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        return x * x.softmax(dim=-1)
 
-class
+class GPT2Block(nn.Module):
     def __init__(self, config: OsSoluConfig) -> None:
         super().__init__()
         self.config = config
 
+        self.layer_norm1 = nn.LayerNorm(normalized_shape, config.ln_eps)
         self.attention = UnidirectionalAttention(config) if config.self_attention_type == "unidirectional" else RotaryAttention(config)
-        self.
-        nn.
+        self.MLP = nn.Sequential(
+            nn.LayerNorm(normalized_shape, config.ln_eps),
+            nn.Linear(config.d_model, 4*config.d_model),
             SoLU(),
+            nn.Linear(4*config.d_model, config.d_model),
+            nn.Dropout(config.dropout)
         )
-        self.layer_norm = nn.LayerNorm(normalized_shape, config.ln_eps)
-        self.unembed = nn.Embedding(config.num_embeddings, config.d_model)
 
     def forward(self, x: t.Tensor) -> t.Tensor:
-
+        x = x + self.attention(self.layer_norm1(x))
+        x = x + self.MLP(x)
+        return x
+
 
 
 class UnidirectionalAttention(nn.Module):
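For reference, the SoLU (Softmax Linear Unit) added above is simply the elementwise product of a tensor with its own softmax over the last dimension. A tiny self-contained sketch (standalone function, not the repo's module) with a numeric sanity check:

# Standalone sketch of the SoLU activation; values shown are approximate.
import torch as t

def solu(x: t.Tensor) -> t.Tensor:
    # SoLU(x) = x * softmax(x), taken over the last dimension
    return x * x.softmax(dim=-1)

x = t.tensor([[1.0, 2.0, 3.0]])
print(solu(x))                                          # ~tensor([[0.0900, 0.4895, 1.9957]])
print(t.allclose(solu(x), x * t.softmax(x, dim=-1)))    # True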
@@ -96,4 +110,5 @@ class RotaryAttention(nn.Module):
         self.config = config
 
     def forward(self, x: t.Tensor) -> t.Tensor:
+        # TODO: implement rotary self-attention
         pass
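The `GPT2Block` added above is a pre-norm residual block: a layer norm feeds the attention sublayer, and the MLP stack (layer norm, up-projection, SoLU, down-projection, dropout) is added back onto the residual stream. The sketch below shows that wiring with stock PyTorch modules only; it is an assumption-laden stand-in, not the repo's block: `normalized_shape` is taken to mean `config.d_model`, `nn.MultiheadAttention` stands in for the repo's attention classes, and `nn.GELU` stands in for `SoLU`.

# Illustrative pre-norm residual block (stand-in modules, not the repo's GPT2Block).
import torch as t
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, d_model: int = 512, num_heads: int = 4, dropout: float = 0.1, ln_eps: float = 1e-3):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(d_model, eps=ln_eps)
        # Stand-in for UnidirectionalAttention / RotaryAttention.
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.MLP = nn.Sequential(
            nn.LayerNorm(d_model, eps=ln_eps),
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),                        # stand-in for SoLU()
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        normed = self.layer_norm1(x)
        x = x + self.attention(normed, normed, normed, need_weights=False)[0]
        x = x + self.MLP(x)                   # MLP carries its own pre-norm inside the Sequential
        return x

x = t.randn(2, 16, 512)                       # (batch, seq, d_model)
print(PreNormBlockSketch()(x).shape)          # torch.Size([2, 16, 512])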
utils.py
CHANGED
@@ -1,12 +1,27 @@
-
+import argparse
+
 class OsSoluConfig:
-    d_model: int
-    vocab_size: int
-    learning_rate: float
-    num_embeddings: int
-    num_blocks: int
-    dropout: float
-    ln_eps: float
-    num_heads: int
-    self_attention_type: str
-    max_positional_embeddings: int
+    d_model: int                    # Hidden size of the model.
+    vocab_size: int                 # Vocabulary size of the input sequence. Unsure about this.
+    learning_rate: float            # Learning rate for the optimiser.
+    num_embeddings: int             # Number of embeddings. Unsure about this.
+    num_blocks: int                 # Number of transformer blocks.
+    dropout: float                  # Probability of dropout.
+    ln_eps: float                   # Layer norm epsilon.
+    num_heads: int                  # Number of attention heads in each attention layer.
+    self_attention_type: str        # What type of attention to use: rotary or unidirectional.
+    max_positional_embeddings: int  # Maximum number of positional embeddings.
+
+    def __init__(self, args: argparse.Namespace) -> None:
+        """Initialise this config class with values provided by a command-line argument parser.
+        Values are never None here, as we provide suitable defaults in the parser call."""
+        self.d_model = args.d_model
+        self.vocab_size = args.vocab_size
+        self.learning_rate = args.learning_rate
+        self.num_embeddings = args.num_embeddings
+        self.num_blocks = args.num_blocks
+        self.dropout = args.dropout
+        self.ln_eps = args.ln_eps
+        self.num_heads = args.num_heads
+        self.self_attention_type = args.self_attention_type
+        self.max_positional_embeddings = args.max_positional_embeddings
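Since `OsSoluConfig.__init__` only copies attributes off an `argparse.Namespace`, a config can also be built without touching the command line, for example in a notebook or unit test. A hedged usage sketch (it assumes `utils.py` is importable and reuses the parser defaults from `main.py`):

# Usage sketch: build the config from a hand-made Namespace instead of the CLI.
from argparse import Namespace
from utils import OsSoluConfig

args = Namespace(
    d_model=512, vocab_size=65536, learning_rate=1e-3, num_embeddings=1024,
    num_blocks=1, dropout=0.1, ln_eps=1e-3, num_heads=4,
    self_attention_type="unidirectional", max_positional_embeddings=1024,
)
config = OsSoluConfig(args)
print(config.d_model, config.self_attention_type)   # 512 unidirectional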