Added tokenise method for streamed data, fixed issues with einsums
main.py
CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import time
 import torch as t
 import torch.nn as nn
 import torch.functional as F
@@ -9,7 +10,8 @@ import wandb
 from typing import Tuple
 from torch.utils.data.dataloader import DataLoader
 from datasets import load_dataset
-from utils import OsSoluConfig
+from transformers import AutoTokenizer
+from utils import OsSoluConfig, tokenise
 from model import OsSoluModel
 
 WANDB_PROJECT_NAME = "os_solu"
@@ -32,7 +34,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
     parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
-    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
+    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
     parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
     parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
@@ -40,7 +42,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
     parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
     parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
-    parser.add_argument("--vocab_size", type=int, default=
+    parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")
     args = vars(parser.parse_args())
 
     # Parse string arguments.
@@ -67,7 +69,6 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     Returns:
         OsSoluModel: The trained model.
     """
-    # TODO: training loop
     train_loss_fn = t.nn.CrossEntropyLoss()
     wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
 
@@ -77,16 +78,17 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
 
     # Train loop.
     examples_seen = 0
+    train_data_iterator = iter(train_dataloader)
     for epoch in range(config.num_epochs):
-        for i, (data, target) in enumerate(train_dataloader):
-
+        for i, batch in enumerate(tqdm(train_data_iterator
+        )):
+            data = batch["text"]
             data = data.to(DEVICE)
-            target = target.to(DEVICE)
 
             predictions = model(data)
             accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
             optimiser.zero_grad()
-            loss = train_loss_fn(data, predictions)
+            # loss = train_loss_fn(data, predictions)
             loss.backward()
             optimiser.step()
 
@@ -109,9 +111,10 @@ def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
     total_loss, num_correct = 0, 0
     model.eval()
     with t.inference_mode():
-        for i, (data, target) in enumerate(test_dataloader):
+        test_data_iterator = iter(test_dataloader)
+        for i, (data, target) in enumerate(tqdm(test_data_iterator)):
+            data = batch["text"]
             data = data.to(DEVICE)
-            target = target.to(DEVICE)
 
             predictions = model(data)
             num_correct += (predictions.argmax(dim=-1) == target).sum().item()
@@ -134,15 +137,31 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
     args = parse_arguments()
     wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
-    model = OsSoluModel(config)
+    model = OsSoluModel(config).to(DEVICE)
 
+    start_data_time = time.time()
     # Load and prep data.
     ds = load_dataset("the_pile", streaming=True)
-    train_dataset = ds["train"].with_format("torch")
-    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
 
-    test_dataset = ds["test"].with_format("torch")
+    try:
+        ds = ds.remove_columns("meta")
+    except:
+        print("Dataset did not contain 'meta' column.")
+
+    train_dataset = ds["train"]
+    test_dataset = ds["test"]
+
+    # TODO: tokenise the data before sending it to the model.
+    tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+    tokeniser.add_special_tokens({"pad_token": "<PAD>"})
+
+    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser), batched=True).with_format("torch")
+    test_dataset = test_dataset.map(tokenise, batched=True).with_format("torch")
+
+    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
     test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
+    print(f"Data loaded in {time.time() - start_data_time:.1f}s.")
+
     return config, model, (train_dataloader, test_dataloader)
 
 if __name__=="__main__":
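For reference, a minimal sketch of the streamed-tokenisation path added in setup(), run against a small in-memory dataset instead of the Pile so it executes without downloading data. The toy generator, the context length of 32, and the loader batch size of 4 are illustrative assumptions, not values from the repo.

# Sketch (assumed toy data): stream text batches through the new tokenise()
# helper and collate the resulting token tensors, mirroring setup().
import torch as t
from datasets import IterableDataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from utils import tokenise  # helper added in this commit


def text_stream():
    # Stand-in for ds["train"] from load_dataset("the_pile", streaming=True).
    for i in range(64):
        yield {"text": f"Document {i}: some example text for the tokeniser to chop into context windows."}


tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokeniser.add_special_tokens({"pad_token": "<PAD>"})

ds = IterableDataset.from_generator(text_stream)
# batched=True hands tokenise() a dict of lists, matching its expected input;
# with_format("torch") turns the returned rows into tensors for the DataLoader.
ds = ds.map(lambda x: tokenise(x, tokeniser, context_length=32), batched=True).with_format("torch")

loader = DataLoader(ds, batch_size=4)
first = next(iter(loader))
print(first["text"].shape)  # expected: (4, 32) token ids, each row starting with BOS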
model.py
CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.functional as F
 import torch.optim as optim
 import wandb
-
+from fancy_einsum import einsum
 from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig
 
@@ -22,7 +22,7 @@ class OsSoluModel(nn.Module):
         self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)
 
     def forward(self, x: t.Tensor) -> t.Tensor:
-        positional_embeddings = self.embed_positions(t.arange(x.size(1)))
+        positional_embeddings = self.embed_positions(t.arange(x.size(1), device=x.device))
         token_embeddings = self.embed_tokens(x)
         embeddings = positional_embeddings + token_embeddings
         out = self.dropout(embeddings)
@@ -69,9 +69,9 @@ class UnidirectionalAttention(nn.Module):
         super().__init__()
         self.num_heads = config.num_heads
         self.d_model = config.d_model
-        self.project_q = nn.Linear(config.
-        self.project_k = nn.Linear(config.
-        self.project_v = nn.Linear(config.
+        self.project_q = nn.Linear(config.d_model, config.d_model)
+        self.project_k = nn.Linear(config.d_model, config.d_model)
+        self.project_v = nn.Linear(config.d_model, config.d_model)
         self.project_out = nn.Linear(config.d_model, config.d_model)
         self.LARGE_NEGATIVE_VALUE = -1e5
 
@@ -84,7 +84,11 @@ class UnidirectionalAttention(nn.Module):
 
         Q = self.hidden_to_heads(Q)
         K = self.hidden_to_heads(K)
-        attention_pattern = einsum(
+        attention_pattern = einsum(
+            "batch num_heads seqlen_q head_size, "
+            "batch num_heads seqlen_k head_size ->"
+            "batch num_heads seqlen_q seqlen_k",
+            Q, K)
 
         return attention_pattern
 
@@ -95,18 +99,23 @@
 
         # Masking attention. Since GPT is unidirectional, it should only attend to previous tokens.
         if seqlen > 1:
-            fst_range = t.arange(seqlen, device=
-            snd_range = t.arange(seqlen, device=
+            fst_range = t.arange(seqlen, device=x.device).unsqueeze(0).T
+            snd_range = t.arange(seqlen, device=x.device).unsqueeze(0)
             bool_array = fst_range < snd_range
-
+            attention_pattern[..., bool_array] = self.LARGE_NEGATIVE_VALUE
 
 
         attention_pattern = attention_pattern / t.sqrt(t.tensor(self.d_model // self.num_heads))
         attention_score = attention_pattern.softmax(dim=-1)
 
         V = self.hidden_to_heads(V)
-        out = einsum(
-
+        out = einsum(
+            "batch num_heads seqlen_q seqlen_k,"
+            "batch num_heads seqlen_k head_size ->"
+            "batch num_heads seqlen_q head_size",
+            attention_score, V)
+
+        out = rearrange(out, "b nh s hs -> b s (nh hs)")
         out = self.project_out(out)
 
 
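Since the commit message calls out the einsum fixes, here is a minimal, self-contained check of what the two named-dimension einsums in UnidirectionalAttention compute. The shapes are illustrative assumptions, and masked_fill stands in for the boolean-index assignment used in the module; otherwise it follows the same QK^T / mask / softmax / weighted-sum-of-V / merge-heads sequence.

# Sketch: verify the fancy_einsum calls against plain matrix multiplication
# on random tensors (batch=2, heads=4, seq=5, head_size=8 are assumed shapes).
import torch as t
from fancy_einsum import einsum
from einops import rearrange

batch, num_heads, seqlen, head_size = 2, 4, 5, 8
Q = t.randn(batch, num_heads, seqlen, head_size)
K = t.randn(batch, num_heads, seqlen, head_size)
V = t.randn(batch, num_heads, seqlen, head_size)

# QK^T -> (batch, heads, seq_q, seq_k), as in the attention_pattern einsum.
pattern = einsum(
    "batch num_heads seqlen_q head_size, "
    "batch num_heads seqlen_k head_size -> "
    "batch num_heads seqlen_q seqlen_k",
    Q, K)
assert t.allclose(pattern, Q @ K.transpose(-1, -2), atol=1e-5)

# Causal mask: query position q may only attend to key positions k <= q.
mask = t.arange(seqlen).unsqueeze(1) < t.arange(seqlen).unsqueeze(0)
pattern = pattern.masked_fill(mask, -1e5)
scores = (pattern / (head_size ** 0.5)).softmax(dim=-1)

# Weighted sum of values, then merge heads back into the model dimension.
out = einsum(
    "batch num_heads seqlen_q seqlen_k, "
    "batch num_heads seqlen_k head_size -> "
    "batch num_heads seqlen_q head_size",
    scores, V)
assert t.allclose(out, scores @ V, atol=1e-5)
out = rearrange(out, "b nh s hs -> b s (nh hs)")
print(out.shape)  # torch.Size([2, 5, 32])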
requirements.txt
CHANGED
@@ -9,6 +9,7 @@ notebook
 numpy-stl
 plotly
 torch
+transformers
 tqdm
 wandb
 zstandard
utils.py
CHANGED
@@ -1,3 +1,6 @@
+import numpy as np
+from einops import rearrange
+
 class OsSoluConfig:
     """A class to hold hyperparameters for the model itself and for the training process."""
 
@@ -32,4 +35,42 @@ class OsSoluConfig:
         self.num_heads = args["num_heads"]
         self.optimiser_type = args["optimiser_type"]
         self.self_attention_type = args["self_attention_type"]
-        self.vocab_size = args["vocab_size"]
+        self.vocab_size = args["vocab_size"]
+
+def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
+    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4.
+
+    Args:
+        batch (dict): The batch of text, as a dict with a 'text' field.
+        tokeniser (-): A huggingface-API tokeniser, of type returned by AutoTokenizer.from_pretrained (depends on model chosen).
+        num_gpus (int, optional): The number of GPUs available for data parallel training. Defaults to 1.
+        context_length (int, optional): The context length of the model that will be trained on this data. Defaults to 1024.
+
+    Returns:
+        dict: A single field dictionary, 'text', whose value is a tensor of shape (batch_size, sequence_length) containing tokenised sequences.
+    """
+    batch = batch["text"]
+    full_text = tokeniser.eos_token.join(batch)
+
+    # Divide entire batch among all GPUs available.
+    seq_len = len(full_text)//num_gpus
+    sequence_list = [full_text[i*seq_len:(i+1)*seq_len] for i in range(num_gpus)]
+
+    # Tokenise sequences, removing padding tokens.
+    all_tokens = tokeniser(sequence_list, return_tensors="pt", padding=True)["input_ids"].flatten()
+    all_tokens = all_tokens[all_tokens != tokeniser.pad_token_id]
+
+    # Reshape all_tokens to be (batch_size x sequence_length) where each sequence has
+    # a "beginning of sequence" token prepended to it.
+    num_tokens = len(all_tokens)
+    current_batch_size = num_tokens // (context_length-1)
+    all_tokens = all_tokens[:(context_length-1)*current_batch_size]
+    all_tokens = rearrange(all_tokens, "(batch_size seq_len) -> batch_size seq_len", batch_size=current_batch_size, seq_len=context_length-1)
+    prefix = np.full((current_batch_size, 1), tokeniser.bos_token_id, dtype=np.int64)
+
+    tokenised_text = np.concatenate([prefix, all_tokens], axis=1)
+    assert tokenised_text.shape == (current_batch_size, context_length)
+    print(f"{current_batch_size=}, {context_length=}")
+    return {"text": tokenised_text}
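A minimal sketch of calling the new tokenise() helper directly, to illustrate the shape contract in its docstring. The two toy documents and context_length=8 are illustrative assumptions; the number of rows depends on how the text tokenises.

# Sketch: tokenise a toy batch (not Pile data) and check the output shape.
from transformers import AutoTokenizer
from utils import tokenise

tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokeniser.add_special_tokens({"pad_token": "<PAD>"})

batch = {"text": ["The first toy document.", "A second, slightly longer toy document."]}
out = tokenise(batch, tokeniser, num_gpus=1, context_length=8)

rows, cols = out["text"].shape
assert cols == 8                                             # every row is one full context window
assert (out["text"][:, 0] == tokeniser.bos_token_id).all()   # each row is prefixed with the BOS token
print(rows, cols)                                            # e.g. (2, 8), depending on tokenisation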