First successful run, added checkpoints
main.py
CHANGED
@@ -11,12 +11,15 @@ from typing import Tuple
 from torch.utils.data.dataloader import DataLoader
 from datasets import load_dataset
 from transformers import AutoTokenizer
-from utils import OsSoluConfig, tokenise
+from utils import OsSoluConfig, tokenise, loss_fn, count_parameters
 from model import OsSoluModel
 
 WANDB_PROJECT_NAME = "os_solu"
 DEVICE = "cuda" if t.cuda.is_available() else "cpu"
 
+# TODO: Add support for distributed training.
+# TODO: Use only book data from dataset.
+
 def parse_arguments() -> dict:
     """Parses command-line arguments for this model run. Arguments of type string have allowed values,
     which are enforced. Default parameter values are provided such that fields in the config are never None.
@@ -29,7 +32,8 @@ def parse_arguments() -> dict:
         dict: a dictionary containing the command-line arguments parsed by this function.
     """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
-    parser.add_argument("--batch_size", type=int, default=
+    parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
+    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=50_000, help="Save a checkpoint of the model every n tokens processed.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
@@ -38,7 +42,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
     parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
-    parser.add_argument("--num_epochs", type=int, default=
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to run for.")
     parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
     parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
     parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
@@ -69,8 +73,7 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     Returns:
         OsSoluModel: The trained model.
     """
-
-    wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
+    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)
 
     # Initialise optimiser.
     opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
@@ -82,18 +85,32 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     for epoch in range(config.num_epochs):
         for i, batch in enumerate(tqdm(train_data_iterator
         )):
-
-
+            start_time = time.time()
+            batch = batch["text"]
+            batch = batch.to(DEVICE)
 
-
-            accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
+            logits = model(batch)
             optimiser.zero_grad()
-
+            loss = loss_fn(logits, batch)
             loss.backward()
             optimiser.step()
 
-            wandb.log(dict(train_loss=loss,
-            examples_seen += len(
+            wandb.log(dict(train_loss=loss, elapsed=time.time() - start_time), step=examples_seen)
+            examples_seen += len(batch)
+
+            # Save a checkpoint of the model.
+            if examples_seen % config.checkpoint_every_n_tokens == 0:
+                # Save the model's state on disk, then upload to wandb.
+                filename = f"{wandb.run.dir}/os_solu_model_ckpt_step_{examples_seen}.pt"
+                t.save({
+                    "step": examples_seen,
+                    "model_state_dict": model.state_dict(),
+                    "optimiser_state_dict": optimiser.state_dict(),
+                    "loss": loss.item()
+                }, filename)
+                wandb.save(filename)
+                print(f"Checkpointing model at {examples_seen} tokens seen.")
+
 
     return model
 
@@ -112,15 +129,14 @@ def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
     model.eval()
     with t.inference_mode():
         test_data_iterator = iter(test_dataloader)
-        for i,
-
-
-
-
-
-
-
-            wandb.log(dict(test_loss=total_loss, test_accuracy=num_correct / examples_seen, elapsed=time.time() - start_time), step=examples_seen)
+        for i, batch in enumerate(tqdm(test_data_iterator)):
+            batch = batch["text"]
+            batch = batch.to(DEVICE)
+
+            logits = model(batch)
+            total_loss += loss_fn(logits, batch).item()
+            examples_seen += len(batch)
+            wandb.log(dict(test_loss=total_loss, elapsed=time.time() - start_time), step=examples_seen)
 
     # Save the model's state on disk, then upload to wandb.
     filename = f"{wandb.run.dir}/model_state_dict.pt"
@@ -135,9 +151,10 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
         Tuple[OsSoluConfig, OsSoluModel, datasets.iterable_dataset.IterableDataset, datasets.iterable_dataset.IterableDataset]: A tuple containing a config, a model, a training dataset and a test dataset.
     """
     args = parse_arguments()
-    wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
     model = OsSoluModel(config).to(DEVICE)
+    args["num_parameters"] = count_parameters(model)
+    wandb.init(project=WANDB_PROJECT_NAME, config=args)
 
     start_data_time = time.time()
     # Load and prep data.
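Note on the new checkpoints: each file written in train() bundles the examples_seen counter, the model and optimiser state dicts, and the last training loss, and is uploaded to the wandb run via wandb.save. Below is a minimal sketch of how such a file could be reloaded to resume training. It assumes the same OsSoluConfig/OsSoluModel classes as above; the load_checkpoint helper, the checkpoint path, and the optimiser construction (Adam/SGD with config.learning_rate) are illustrative assumptions, not code from this repository.

import torch as t
from torch import optim

from model import OsSoluModel
from utils import OsSoluConfig

def load_checkpoint(path: str, config: OsSoluConfig):
    """Rebuild the model and optimiser from a checkpoint written by train().

    `path` is a hypothetical local copy of a file such as
    os_solu_model_ckpt_step_50000.pt downloaded from the wandb run directory.
    """
    checkpoint = t.load(path, map_location="cpu")

    model = OsSoluModel(config)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Assumed to mirror how train() builds its optimiser.
    opt_cls = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
    optimiser = opt_cls(model.parameters(), lr=config.learning_rate)
    optimiser.load_state_dict(checkpoint["optimiser_state_dict"])

    return model, optimiser, checkpoint["step"], checkpoint["loss"]

Saving the optimiser state alongside the weights is what makes resumption exact for Adam, whose moment estimates would otherwise restart from zero.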
model.py
CHANGED
@@ -7,7 +7,8 @@ from fancy_einsum import einsum
 from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig
 
-
+# TODO: Add hooks to the model.
+# TODO: Add support for mixing dense and sparse attention.
 
 class OsSoluModel(nn.Module):
     """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
@@ -128,4 +129,13 @@ class RotaryAttention(nn.Module):
 
     def forward(self, x: t.Tensor) -> t.Tensor:
         # TODO: implement rotary self-attention
+        pass
+
+class LayerNorm(nn.Module):
+    def __init__(self, config: OsSoluConfig) -> None:
+        super().__init__()
+        self.config = config
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        # TODO: implement layernorm with hooks on normalisation only.
         pass
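The new LayerNorm class above is only a stub; its forward is a TODO about hooking the normalisation step. For reference, a minimal sketch of the computation it will need is shown here, assuming a learned elementwise scale and bias applied after normalisation over the last dimension. The class name, the d_model argument, and the eps default are illustrative assumptions, not this repository's final API.

import torch as t
import torch.nn as nn

class SimpleLayerNorm(nn.Module):
    """Plain LayerNorm over the last (d_model) dimension; a hypothetical stand-in for the stub above."""

    def __init__(self, d_model: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(t.ones(d_model))
        self.bias = nn.Parameter(t.zeros(d_model))

    def forward(self, x: t.Tensor) -> t.Tensor:
        # Normalise each position's vector to zero mean and unit variance.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        normalised = (x - mean) / t.sqrt(var + self.eps)
        # The TODO asks for a hook on `normalised` only, i.e. before this learned affine map.
        return normalised * self.scale + self.bias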
utils.py
CHANGED
@@ -1,10 +1,14 @@
 import numpy as np
+import torch as t
+import torch.nn.functional as F
 from einops import rearrange
 
+# TODO: Add functionality to load this from a config file as an alternative to command-line args.
 class OsSoluConfig:
     """A class to hold hyperparameters for the model itself and for the training process."""
 
     batch_size: int  # Training data batch size.
+    checkpoint_every_n_tokens: int  # Save a checkpoint of the model every n tokens processed.
     d_model: int  # Hidden size of the model.
     dropout: float  # Probability of dropout.
     learning_rate: float  # Learning rate for the optimiser.
@@ -23,6 +27,7 @@ class OsSoluConfig:
         """Initialise this config class with values provided by a command-line argument parser.
         Values are never None here, as we provide suitable defaults in the parser call."""
         self.batch_size = args["batch_size"]
+        self.checkpoint_every_n_tokens = args["checkpoint_every_n_tokens"]
         self.d_model = args["d_model"]
         self.dropout = args["dropout"]
         self.learning_rate = args["learning_rate"]
@@ -38,7 +43,7 @@ class OsSoluConfig:
         self.vocab_size = args["vocab_size"]
 
 def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
-    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4.
+    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4. Code from Neel.
 
     Args:
         batch (dict): The batch of text, as a dict with a 'text' field.
@@ -70,7 +75,25 @@ def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
 
     tokenised_text = np.concatenate([prefix, all_tokens], axis=1)
     assert tokenised_text.shape == (current_batch_size, context_length)
-    print(f"{current_batch_size=}, {context_length=}")
     return {"text": tokenised_text}
 
-
+def loss_fn(logits, batch):
+    """Loss function to train an autoregressive model. It compares the token logits predicted by the model with the actual next token. Code from Neel.
+
+    Args:
+        logits (t.Tensor): A tensor containing logits, has shape (batch_size, sequence_length, vocab_size)
+        batch (t.Tensor): A tensor containing token IDs, has shape (batch_size, sequence_length, vocab_size)
+
+    Returns:
+        loss (t.Tensor): A tensor containing the loss value.
+    """
+
+    # Log-softmax to get log-probabilities.
+    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
+
+    # Match up the probabilities of the actual words.
+    pred_log_probs = t.gather(log_probs, -1, batch[:, 1:, None])[..., 0]
+    return -pred_log_probs.mean()
+
+def count_parameters(model):
+    return sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)