import argparse
import time

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

import wandb

from typing import Tuple
from torch.utils.data.dataloader import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

from utils import OsSoluConfig, tokenise, loss_fn, count_parameters
from model import OsSoluModel

WANDB_PROJECT_NAME = "os_solu"
DEVICE = "cuda" if t.cuda.is_available() else "cpu"

def parse_arguments() -> dict:
    """Parses command-line arguments for this model run. Arguments of type string have a fixed set of
    allowed values, which is enforced below. Default values are provided for every argument so that no
    field in the config is ever None.

    Raises:
        ValueError: optimiser_type must be adam or sgd.
        ValueError: self_attention_type must be rotary or unidirectional.
        ValueError: nonlinearity must be relu or solu.

    Returns:
        dict: a dictionary containing the command-line arguments parsed by this function.
    """
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=500_000_000, help="Save a checkpoint of the model every n tokens processed.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
    parser.add_argument("--nonlinearity", type=str, default="solu", help="Nonlinearity to use inside MLP block: must be relu or solu.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to run for.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type: must be adam or sgd.")
    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: must be rotary or unidirectional.")
    parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")
    args = vars(parser.parse_args())

    # Enforce the allowed values for string-typed arguments.
    allowed_values = {
        "optimiser_type": ["adam", "sgd"],
        "self_attention_type": ["unidirectional", "rotary"],
        "nonlinearity": ["relu", "solu"],
    }
    for key, values in allowed_values.items():
        if args[key] not in values:
            raise ValueError(f"{key} should be one of {values}.")

    return args

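# Example invocation (the script filename is hypothetical and the flag values are
# illustrative, not recommendations):
#   python train.py --batch_size 32 --num_blocks 6 --nonlinearity solu --optimiser_type adam
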
def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader) -> OsSoluModel:
    """Trains a model using the config and training dataloader provided.

    Args:
        config (OsSoluConfig): The config object.
        model (OsSoluModel): The model to train.
        train_dataloader (DataLoader): The training dataset provided as a torch DataLoader.

    Returns:
        OsSoluModel: The trained model.
    """
    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)
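    # loss_fn (imported from utils) is assumed to compute the next-token
    # cross-entropy between the model's logits and the input token ids.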
    opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
    optimiser = opt(model.parameters(), lr=config.learning_rate)
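    # Only the learning rate is configurable here; Adam's betas/eps (or SGD's
    # momentum) are left at their PyTorch defaults.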
    examples_seen = 0
    tokens_at_last_checkpoint = 0
    model.train()  # Ensure dropout etc. are active during training.
    for epoch in range(config.num_epochs):
        # Rebuild the iterator each epoch: a single iterator created up front
        # would be exhausted after the first pass over the data.
        train_data_iterator = iter(train_dataloader)
        for i, batch in enumerate(tqdm(train_data_iterator)):
            start_time = time.time()
            batch = batch["text"].to(DEVICE)

            logits = model(batch)
            optimiser.zero_grad()
            loss = loss_fn(logits, batch)
            loss.backward()
            optimiser.step()

            wandb.log(dict(train_loss=loss.item(), elapsed=time.time() - start_time), step=examples_seen)
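            # batch.numel() is batch_size * sequence_length, i.e. the number of
            # tokens processed, so examples_seen is really a token count (it also
            # serves as the wandb step and the checkpointing counter).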
            examples_seen += batch.numel()

            # Checkpoint when the token count crosses the next threshold; an
            # exact-divisibility check would almost never fire, because
            # examples_seen advances in batch-sized jumps.
            if examples_seen - tokens_at_last_checkpoint >= config.checkpoint_every_n_tokens:
                tokens_at_last_checkpoint = examples_seen
                filename = f"{wandb.run.dir}/os_solu_model_ckpt_step_{examples_seen}.pt"
                t.save({
                    "step": examples_seen,
                    "model_state_dict": model.state_dict(),
                    "optimiser_state_dict": optimiser.state_dict(),
                    "loss": loss.item(),
                }, filename)
                wandb.save(filename)
                print(f"Checkpointing model at {examples_seen} tokens seen.")

    return model

def evaluate(model: OsSoluModel, test_dataloader: DataLoader) -> None:
    """Evaluates a trained model on the test dataset provided.

    Args:
        model (OsSoluModel): The trained model.
        test_dataloader (DataLoader): The test dataset provided as a torch DataLoader.
    """
    examples_seen = 0
    total_loss = 0
    model.eval()
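    # inference_mode() disables autograd tracking (more aggressively than
    # no_grad()), saving memory and time during evaluation.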
    with t.inference_mode():
        test_data_iterator = iter(test_dataloader)
        for i, batch in enumerate(tqdm(test_data_iterator)):
            start_time = time.time()
            batch = batch["text"].to(DEVICE)

            logits = model(batch)
            total_loss += loss_fn(logits, batch).item()
            # len(batch) counts sequences here, not tokens as in train().
            examples_seen += len(batch)
            # test_loss is the running mean of the per-batch losses.
            wandb.log(dict(test_loss=total_loss / (i + 1), elapsed=time.time() - start_time), step=examples_seen)

    filename = f"{wandb.run.dir}/model_state_dict.pt"
    t.save(model.state_dict(), filename)
    wandb.save(filename)

def setup() -> Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]:
    """This function delegates the setup to various helper functions.

    Returns:
        Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]: A tuple containing a config,
        a model, and a pair of (train, test) dataloaders.
    """
    args = parse_arguments()
    config = OsSoluConfig(args)
    model = OsSoluModel(config).to(DEVICE)
    args["num_parameters"] = count_parameters(model)
    wandb.init(project=WANDB_PROJECT_NAME, config=args)

    start_data_time = time.time()
    ds = load_dataset("the_pile", streaming=True)
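    # With streaming=True the Pile is not downloaded up front; ds is a dict of
    # IterableDatasets, so the dataloaders built below have no len() and tqdm
    # cannot show a progress total.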
    try:
        ds = ds.remove_columns("meta")
    except Exception:
        print("Dataset did not contain 'meta' column.")

    train_dataset = ds["train"]
    test_dataset = ds["test"]

    tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokeniser.add_special_tokens({"pad_token": "<PAD>"})
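    # Adding <PAD> extends the GPT-NeoX tokeniser by one token; the default
    # vocab_size of 50_278 is assumed to account for this already, otherwise the
    # model's embedding matrix would need resizing to len(tokeniser).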
    # Tokenise both splits with the same arguments.
    train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser, 1, config.max_positional_embeddings), batched=True).with_format("torch")
    test_dataset = test_dataset.map(lambda x: tokenise(x, tokeniser, 1, config.max_positional_embeddings), batched=True).with_format("torch")

    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
    print(f"Data loaded in {time.time() - start_data_time:.1f}s.")

    return config, model, (train_dataloader, test_dataloader)

if __name__ == "__main__":
    config, model, (train_dataloader, test_dataloader) = setup()
    trained_model = train(config, model, train_dataloader)
    evaluate(trained_model, test_dataloader)