os-solu / main.py
inwaves's picture
Fixed example counting glitch
2c547b1
import argparse
import time
from typing import Optional, Sequence, Tuple

import torch as t
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from model import OsSoluModel
from utils import OsSoluConfig, tokenise, loss_fn, count_parameters
# Name of the Weights & Biases project that runs of this script are logged to.
WANDB_PROJECT_NAME = "os_solu"

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if t.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# TODO: Add support for distributed training.
# TODO: Use only book data from dataset.
def parse_arguments(argv: Optional[Sequence[str]] = None) -> dict:
    """Parses command-line arguments for this model run.

    Arguments of type string have allowed values, which are enforced. They are lower-cased
    before validation, so e.g. ``--optimiser_type Adam`` is accepted and stored as ``"adam"``.
    Default parameter values are provided such that fields in the config are never None.

    Args:
        argv (Optional[Sequence[str]]): The argument strings to parse. Defaults to None, in
            which case argparse reads them from ``sys.argv[1:]``.

    Raises:
        ValueError: optimiser type must be adam or sgd.
        ValueError: attention type must be rotary or unidirectional.
        ValueError: nonlinearity must be relu or solu.

    Returns:
        dict: a dictionary containing the command-line arguments parsed by this function.
    """
    parser = argparse.ArgumentParser(description="Parse command-line arguments for this model.")
    parser.add_argument("--batch_size", type=int, default=40, help="Batch size used in training.")
    parser.add_argument("--checkpoint_every_n_tokens", type=int, default=500_000_000, help="Save a checkpoint of the model every n tokens processed.")
    parser.add_argument("--d_model", type=int, default=512, help="Hidden size of the model.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
    parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
    parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
    parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
    parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
    parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to run for.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
    parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
    parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
    parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")

    # With argv=None argparse falls back to sys.argv[1:], i.e. the previous behaviour.
    args = vars(parser.parse_args(argv))

    # Parse string arguments.
    allowed_values = {
        "optimiser_type": ["adam", "sgd"],
        "self_attention_type": ["unidirectional", "rotary"],
        "nonlinearity": ["relu", "solu"],
    }

    for key, values in allowed_values.items():
        # Normalise case first so that validation and all later comparisons see one canonical
        # spelling.
        args[key] = args[key].lower()
        if args[key] not in values:
            raise ValueError(f"{key} should be one of {values}.")

    return args
def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader) -> OsSoluModel:
    """Trains a model using the config and training dataset provided.

    Runs ``config.num_epochs`` passes over ``train_dataloader``, logging the loss and the
    per-batch wall-clock time to wandb after every optimiser step. Progress is measured in
    tokens (``batch.numel()`` per batch), and a checkpoint containing the model and optimiser
    state is written to the wandb run directory every time the running token count crosses a
    multiple of ``config.checkpoint_every_n_tokens``.

    Args:
        config (OsSoluConfig): The config object.
        model (OsSoluModel): The model to train.
        train_dataloader (t.utils.data.DataLoader): The training dataset provided as a torch DataLoader object.

    Returns:
        OsSoluModel: The trained model.
    """
    wandb.watch(model, criterion=loss_fn, log="all", log_freq=10, log_graph=True)

    # Initialise optimiser.
    opt = optim.Adam if config.optimiser_type.lower() == "adam" else optim.SGD
    optimiser = opt(model.parameters(), lr=config.learning_rate)

    checkpoint_interval = config.checkpoint_every_n_tokens

    # Make sure dropout etc. are active even if the caller handed over a model in eval mode.
    model.train()

    # Train loop.
    examples_seen = 0
    for _ in range(config.num_epochs):
        # Iterate over the dataloader itself, not over one iterator created up front: a
        # shared iterator is exhausted after the first epoch, leaving later epochs empty.
        for batch in tqdm(train_dataloader):
            start_time = time.time()
            batch = batch["text"].to(DEVICE)

            logits = model(batch)
            optimiser.zero_grad()
            loss = loss_fn(logits, batch)
            loss.backward()
            optimiser.step()

            # Log a plain float rather than the loss tensor.
            wandb.log(dict(train_loss=loss.item(), elapsed=time.time() - start_time), step=examples_seen)

            # Number of tokens processed is batch_size * sequence_length.
            tokens_before = examples_seen
            examples_seen += batch.numel()

            # Save a checkpoint whenever this batch crossed a multiple of the interval. An
            # exact `examples_seen % interval == 0` test almost never fires, because the
            # token count advances a whole batch at a time. A non-positive interval disables
            # checkpointing instead of dividing by zero.
            crossed_boundary = (
                checkpoint_interval > 0
                and examples_seen // checkpoint_interval > tokens_before // checkpoint_interval
            )
            if crossed_boundary:
                # Save the model's state on disk, then upload to wandb.
                filename = f"{wandb.run.dir}/os_solu_model_ckpt_step_{examples_seen}.pt"
                t.save({
                    "step": examples_seen,
                    "model_state_dict": model.state_dict(),
                    "optimiser_state_dict": optimiser.state_dict(),
                    "loss": loss.item()
                }, filename)
                wandb.save(filename)
                print(f"Checkpointing model at {examples_seen} tokens seen.")

    return model
def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
    """Evaluates a trained model on the test dataset provided.

    Puts the model in eval mode, computes ``loss_fn`` on every batch of ``test_dataloader``
    under ``t.inference_mode()`` and logs the mean per-batch loss to wandb, together with the
    number of sequences evaluated and the elapsed wall-clock time. Finally the model's state
    dict is written to the wandb run directory and uploaded.

    NOTE: this function shadows the builtin ``eval``; the name is kept so existing callers
    keep working.

    Args:
        model (OsSoluModel): The trained model.
        test_dataloader (t.utils.data.DataLoader): The dataset on which to evaluate the model, provided as a torch DataLoader object.
    """
    start_time = time.time()

    # Eval loop.
    examples_seen = 0
    num_batches = 0
    total_loss = 0.0
    model.eval()
    with t.inference_mode():
        for batch in tqdm(test_dataloader):
            batch = batch["text"].to(DEVICE)
            logits = model(batch)
            total_loss += loss_fn(logits, batch).item()
            examples_seen += len(batch)
            num_batches += 1

    # Report the mean so the metric does not scale with the number of test batches; an empty
    # dataloader yields NaN rather than a misleading 0.
    mean_loss = total_loss / num_batches if num_batches else float("nan")

    # No explicit `step` here: this counter is in sequences and is far below the token-based
    # step used during training, and wandb ignores rows logged with a decreasing step.
    wandb.log(dict(test_loss=mean_loss, test_examples_seen=examples_seen, elapsed=time.time() - start_time))

    # Save the model's state on disk, then upload to wandb.
    filename = f"{wandb.run.dir}/model_state_dict.pt"
    t.save(model.state_dict(), filename)
    wandb.save(filename)
def setup() -> Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]:
    """Builds everything needed for a run: config, model, wandb run and dataloaders.

    Parses the command-line arguments, constructs the config and the model (moved to
    ``DEVICE``), initialises the wandb run with the arguments plus the model's parameter
    count, then streams the dataset, tokenises both splits and wraps them in dataloaders.

    Returns:
        Tuple[OsSoluConfig, OsSoluModel, Tuple[DataLoader, DataLoader]]: The config, the
        model, and a ``(train_dataloader, test_dataloader)`` pair.
    """
    args = parse_arguments()
    config = OsSoluConfig(args)
    model = OsSoluModel(config).to(DEVICE)
    args["num_parameters"] = count_parameters(model)
    wandb.init(project=WANDB_PROJECT_NAME, config=args)

    start_data_time = time.time()

    # Load and prep data.
    ds = load_dataset("the_pile", streaming=True)
    try:
        ds = ds.remove_columns("meta")
    except Exception:
        # Best effort: a missing column is not fatal. Catching Exception rather than using a
        # bare except lets KeyboardInterrupt / SystemExit propagate.
        print("Dataset did not contain 'meta' column.")

    train_dataset = ds["train"]
    test_dataset = ds["test"]

    tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokeniser.add_special_tokens({"pad_token": "<PAD>"})

    def tokenise_batch(examples):
        # One wrapper for both splits, so the test data gets the same tokeniser and maximum
        # length as the training data. The positional arguments are passed exactly as in the
        # original training call; their meaning is defined by utils.tokenise.
        return tokenise(examples, tokeniser, 1, config.max_positional_embeddings)

    train_dataset = train_dataset.map(tokenise_batch, batched=True).with_format("torch")
    test_dataset = test_dataset.map(tokenise_batch, batched=True).with_format("torch")

    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
    print(f"Data loaded in {time.time() - start_data_time:.1f}s.")

    return config, model, (train_dataloader, test_dataloader)
if __name__ == "__main__":
    # Script entry point: build the run, train the model, then evaluate it on the test split.
    config, model, dataloaders = setup()
    train_dataloader, test_dataloader = dataloaders
    trained_model = train(config, model, train_dataloader)
    eval(trained_model, test_dataloader)