Fixed example counting glitch
main.py CHANGED

@@ -33,7 +33,7 @@ def parse_arguments() -> dict:
     """
     parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
     parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
-    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=
+    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=500_000_000, help="Save a checkpoint of the model every n tokens processed.")
     parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")

@@ -96,7 +96,9 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
         optimiser.step()

         wandb.log(dict(train_loss=loss, elapsed=time.time() - start_time), step=examples_seen)
-
+
+        # Number of tokens processed is batch_size * sequence_length.
+        examples_seen += batch.numel()

         # Save a checkpoint of the model.
         if examples_seen % config.checkpoint_every_n_tokens == 0:

@@ -168,11 +170,10 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
     train_dataset = ds["train"]
     test_dataset = ds["test"]

-    # TODO: tokenise the data before sending it to the model.
     tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
     tokeniser.add_special_tokens({"pad_token": "<PAD>"})

-    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser), batched=True).with_format("torch")
+    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser, 1, config.max_positional_embeddings), batched=True).with_format("torch")
     test_dataset = test_dataset.map(tokenise, batched=True).with_format("torch")

     train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
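For context on the fix in train(): a minimal sketch (not code from the repository) of how the new counter behaves, assuming `batch` is a tensor of token ids with shape (batch_size, sequence_length), so that `batch.numel()` equals batch_size * sequence_length. The sequence length used below is a hypothetical placeholder; batch_size and checkpoint_every_n_tokens reuse the argparse defaults from the hunk above.

    import torch

    # batch_size and checkpoint_every_n_tokens mirror the argparse defaults;
    # seq_len is a hypothetical placeholder, not a value from the repository.
    batch_size, seq_len = 40, 1024
    checkpoint_every_n_tokens = 500_000_000

    examples_seen = 0
    # One batch of token ids with shape (batch_size, sequence_length).
    batch = torch.zeros(batch_size, seq_len, dtype=torch.long)

    # Number of tokens processed is batch_size * sequence_length.
    examples_seen += batch.numel()  # 40 * 1024 = 40_960 tokens after one step

    # The checkpoint condition from the training loop fires only when the running
    # token count lands exactly on a multiple of checkpoint_every_n_tokens.
    if examples_seen % checkpoint_every_n_tokens == 0:
        print("would save a checkpoint here")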
utils.py CHANGED

@@ -42,7 +42,7 @@ class OsSoluConfig:
         self.self_attention_type = args["self_attention_type"]
         self.vocab_size = args["vocab_size"]

-def tokenise(batch, tokeniser, num_gpus: int
+def tokenise(batch, tokeniser, num_gpus: int, context_length: int):
     """Tokenise a batch of text data. This implementation is idiosyncratic to the Pile dataset, but can be easily modified to work with e.g. C4. Code from Neel.

     Args:
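The body of tokenise is not part of this diff, so the following is only a rough sketch of how the widened signature might be used, loosely following the common concatenate-and-chunk recipe for the Pile. Joining documents on the EOS token, dropping the ragged tail, and returning an "input_ids" column are all assumptions here, not the repository's implementation ("Code from Neel").

    import numpy as np

    def tokenise(batch, tokeniser, num_gpus: int, context_length: int):
        # Sketch only: join the batch's documents into one token stream and
        # slice it into fixed-length rows of `context_length` tokens.
        full_text = tokeniser.eos_token.join(batch["text"])
        tokens = np.asarray(tokeniser(full_text)["input_ids"])

        # Drop the ragged tail so the stream divides into whole contexts; rounding
        # the count down to a multiple of num_gpus keeps per-GPU shards equal.
        num_contexts = (len(tokens) // context_length) // num_gpus * num_gpus
        tokens = tokens[: num_contexts * context_length]
        return {"input_ids": tokens.reshape(num_contexts, context_length)}

Under these assumptions, the call in setup(), train_dataset.map(lambda x: tokenise(x, tokeniser, 1, config.max_positional_embeddings), batched=True), would emit rows of exactly config.max_positional_embeddings tokens each.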