inwaves committed
Commit 1bcfe48
1 Parent(s): d97c361

First successful run, added checkpoints

Files changed (3)
  1. main.py +39 -22
  2. model.py +11 -1
  3. utils.py +26 -3
main.py CHANGED
@@ -11,12 +11,15 @@ from typing import Tuple
 from torch.utils.data.dataloader import DataLoader
 from datasets import load_dataset
 from transformers import AutoTokenizer
-from utils import OsSoluConfig, tokenise
+from utils import OsSoluConfig, tokenise, loss_fn, count_parameters
 from model import OsSoluModel
 
 WANDB_PROJECT_NAME = "os_solu"
 DEVICE = "cuda" if t.cuda.is_available() else "cpu"
 
+# TODO: Add support for distributed training.
+# TODO: Use only book data from dataset.
+
 def parse_arguments() -> dict:
     """Parses command-line arguments for this model run. Arguments of type string have allowed values,
     which are enforced. Default parameter values are provided such that fields in the config are never None.
@@ -29,7 +32,8 @@ def parse_arguments() -> dict:
         dict: a dictionary containing the command-line arguments parsed by this function.
     """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
-    parser.add_argument("--batch_size", type=int, default=256, help="Batch size used in training.")
+    parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
+    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=50_000, help="Save a checkpoint of the model every n tokens processed.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
@@ -38,7 +42,7 @@ def parse_arguments() -> dict:
     parser.add_argument("--nonlinearity", type=str, default="solu", help="Nonlinearity to use inside MLP block: must be relu or solu.")
     parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
     parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
-    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs to run for.")
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to run for.")
     parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
     parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
     parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
@@ -69,8 +73,7 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     Returns:
         OsSoluModel: The trained model.
     """
-    train_loss_fn = t.nn.CrossEntropyLoss()
-    wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
+    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)
 
     # Initialise optimiser.
     opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
@@ -82,18 +85,32 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
     for epoch in range(config.num_epochs):
         for i, batch in enumerate(tqdm(train_data_iterator
         )):
-            data = batch["text"]
-            data = data.to(DEVICE)
+            start_time = time.time()
+            batch = batch["text"]
+            batch = batch.to(DEVICE)
 
-            predictions = model(data)
-            accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
+            logits = model(batch)
             optimiser.zero_grad()
-            # loss = train_loss_fn(data, predictions)
+            loss = loss_fn(logits, batch)
             loss.backward()
             optimiser.step()
 
-            wandb.log(dict(train_loss=loss, train_accuracy=accuracy, elapsed=time.time() - start_time), step=examples_seen)
-            examples_seen += len(data)
+            wandb.log(dict(train_loss=loss, elapsed=time.time() - start_time), step=examples_seen)
+            examples_seen += len(batch)
+
+            # Save a checkpoint of the model.
+            if examples_seen % config.checkpoint_every_n_tokens == 0:
+                # Save the model's state on disk, then upload to wandb.
+                filename = f"{wandb.run.dir}/os_solu_model_ckpt_step_{examples_seen}.pt"
+                t.save({
+                    "step": examples_seen,
+                    "model_state_dict": model.state_dict(),
+                    "optimiser_state_dict": optimiser.state_dict(),
+                    "loss": loss.item()
+                }, filename)
+                wandb.save(filename)
+                print(f"Checkpointing model at {examples_seen} tokens seen.")
+
 
     return model
 
@@ -112,15 +129,14 @@ def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
     model.eval()
     with t.inference_mode():
         test_data_iterator = iter(test_dataloader)
-        for i, (data, target) in enumerate(tqdm(test_data_iterator)):
-            data = batch["text"]
-            data = data.to(DEVICE)
-
-            predictions = model(data)
-            num_correct += (predictions.argmax(dim=-1) == target).sum().item()
-            total_loss += test_loss_fn(target, predictions).item()
-            examples_seen += len(data)
-            wandb.log(dict(test_loss=total_loss, test_accuracy=num_correct / examples_seen, elapsed=time.time() - start_time), step=examples_seen)
+        for i, batch in enumerate(tqdm(test_data_iterator)):
+            batch = batch["text"]
+            batch = batch.to(DEVICE)
+
+            logits = model(batch)
+            total_loss += loss_fn(logits, batch).item()
+            examples_seen += len(batch)
+            wandb.log(dict(test_loss=total_loss, elapsed=time.time() - start_time), step=examples_seen)
 
     # Save the model's state on disk, then upload to wandb.
     filename = f"{wandb.run.dir}/model_state_dict.pt"
@@ -135,9 +151,10 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
         Tuple[OsSoluConfig, OsSoluModel, datasets.iterable_dataset.IterableDataset, datasets.iterable_dataset.IterableDataset]: A tuple containing a config, a model, a training dataset and a test dataset.
     """
     args = parse_arguments()
-    wandb.init(project=WANDB_PROJECT_NAME, config=args)
     config = OsSoluConfig(args)
     model = OsSoluModel(config).to(DEVICE)
+    args["num_parameters"] = count_parameters(model)
+    wandb.init(project=WANDB_PROJECT_NAME, config=args)
 
     start_data_time = time.time()
     # Load and prep data.
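For orientation, here is a minimal sketch of how one of the checkpoints written inside train could be loaded to resume a run. The load_checkpoint helper and the example path are illustrative additions, not part of this commit; they only assume the keys saved above ("step", "model_state_dict", "optimiser_state_dict", "loss"), the existing OsSoluModel/OsSoluConfig classes, and the Adam branch of the optimiser setup.

import torch as t
import torch.optim as optim

from model import OsSoluModel
from utils import OsSoluConfig


def load_checkpoint(path: str, config: OsSoluConfig, device: str = "cpu"):
    """Rebuild the model and optimiser from a checkpoint dict saved by train()."""
    checkpoint = t.load(path, map_location=device)

    model = OsSoluModel(config).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])

    # The optimiser must be constructed before its state can be restored.
    optimiser = optim.Adam(model.parameters(), lr=config.learning_rate)
    optimiser.load_state_dict(checkpoint["optimiser_state_dict"])

    return model, optimiser, checkpoint["step"], checkpoint["loss"]


# Hypothetical usage, following the filename scheme used in train():
# model, optimiser, examples_seen, last_loss = load_checkpoint(
#     "wandb/run-xyz/files/os_solu_model_ckpt_step_50000.pt", config, device="cuda")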
model.py CHANGED
@@ -7,7 +7,8 @@ from fancy_einsum import einsum
 from einops import rearrange, repeat, reduce
 from utils import OsSoluConfig
 
-
+# TODO: Add hooks to the model.
+# TODO: Add support for mixing dense and sparse attention.
 
 class OsSoluModel(nn.Module):
     """An open-source implementation of a SoLU-based transformer. This is a GPT-style architecture model
@@ -128,4 +129,13 @@ class RotaryAttention(nn.Module):
 
     def forward(self, x: t.Tensor) -> t.Tensor:
         # TODO: implement rotary self-attention
+        pass
+
+class LayerNorm(nn.Module):
+    def __init__(self, config: OsSoluConfig) -> None:
+        super().__init__()
+        self.config = config
+
+    def forward(self, x: t.Tensor) -> t.Tensor:
+        # TODO: implement layernorm with hooks on normalisation only.
         pass
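The new LayerNorm class is still a stub (its forward pass is a TODO). As a point of reference only, a plain layer normalisation without the planned hooks could look like the sketch below; SimpleLayerNorm and its d_model argument are assumptions for illustration (d_model would presumably come from the config), not what this commit implements.

import torch as t
import torch.nn as nn


class SimpleLayerNorm(nn.Module):
    """Plain layer normalisation over the last dimension, without hooks."""

    def __init__(self, d_model: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(t.ones(d_model))
        self.bias = nn.Parameter(t.zeros(d_model))

    def forward(self, x: t.Tensor) -> t.Tensor:
        # Normalise each position's activations to zero mean and unit variance...
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / t.sqrt(var + self.eps)
        # ...then apply the learned elementwise affine transform.
        return x_norm * self.scale + self.bias

Keeping the normalisation step separate from the affine transform, as above, is what would make it possible to hook the normalised activations on their own, which is what the TODO in this commit asks for.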
utils.py CHANGED
@@ -1,10 +1,14 @@
 import numpy as np
+import torch as t
+import torch.nn.functional as F
 from einops import rearrange
 
+# TODO: Add functionality to load this from a config file as an alternative to command-line args.
 class OsSoluConfig:
     """A class to hold hyperparameters for the model itself and for the training process."""
 
     batch_size: int  # Training data batch size.
+    checkpoint_every_n_tokens: int  # Save a checkpoint of the model every n tokens processed.
     d_model: int  # Hidden size of the model.
     dropout: float  # Probability of dropout.
     learning_rate: float  # Learning rate for the optimiser.
@@ -23,6 +27,7 @@ class OsSoluConfig:
         """Initialise this config class with values provided by a command-line argument parser.
         Values are never None here, as we provide suitable defaults in the parser call."""
         self.batch_size = args["batch_size"]
+        self.checkpoint_every_n_tokens = args["checkpoint_every_n_tokens"]
         self.d_model = args["d_model"]
         self.dropout = args["dropout"]
         self.learning_rate = args["learning_rate"]
@@ -38,7 +43,7 @@ class OsSoluConfig:
         self.vocab_size = args["vocab_size"]
 
 def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
-    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4.
+    """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4. Code from Neel.
 
     Args:
         batch (dict): The batch of text, as a dict with a 'text' field.
@@ -70,7 +75,25 @@ def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
 
     tokenised_text = np.concatenate([prefix, all_tokens], axis=1)
     assert tokenised_text.shape == (current_batch_size, context_length)
-    print(f"{current_batch_size=}, {context_length=}")
     return {"text": tokenised_text}
 
-
+def loss_fn(logits, batch):
+    """Loss function to train an autoregressive model. It compares the token logits predicted by the model with the actual next token. Code from Neel.
+
+    Args:
+        logits (t.Tensor): A tensor containing logits, has shape (batch_size, sequence_length, vocab_size).
+        batch (t.Tensor): A tensor containing token IDs, has shape (batch_size, sequence_length).
+
+    Returns:
+        loss (t.Tensor): A tensor containing the loss value.
+    """
+
+    # Log-softmax to get log-probabilities.
+    log_probs = F.log_softmax(logits[:, :-1], dim=-1)
+
+    # Match up the probabilities of the actual words.
+    pred_log_probs = t.gather(log_probs, -1, batch[:, 1:, None])[..., 0]
+    return -pred_log_probs.mean()
+
+def count_parameters(model):
+    return sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
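To make the gather-based loss_fn concrete, the toy check below (not part of the commit; the shapes are arbitrary and chosen only for illustration) runs it on random logits and token IDs and compares it against the equivalent cross-entropy on shifted next-token targets.

import torch as t
import torch.nn.functional as F

from utils import loss_fn

batch_size, seq_len, vocab_size = 2, 10, 50
logits = t.randn(batch_size, seq_len, vocab_size)
tokens = t.randint(0, vocab_size, (batch_size, seq_len))

# loss_fn scores the prediction at position i against the token at position i + 1,
# so the last position's logits and the first token are never used.
loss = loss_fn(logits, tokens)

# Equivalent formulation: cross-entropy between shifted logits and shifted targets.
reference = F.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),
    tokens[:, 1:].reshape(-1),
)
assert t.allclose(loss, reference, atol=1e-5)
print(f"next-token loss on random logits: {loss.item():.4f}")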