train-llama / train.py
nroggendorff's picture
Update train.py
c62bc4a verified
raw
history blame
12 kB
import os
from sys import exit
import torch
import trl
from transformers import (
AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
TrainingArguments, PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
)
from datasets import load_dataset, Dataset
from tokenizers import ByteLevelBPETokenizer
from huggingface_hub import HfApi
from trl import SFTConfig, SFTTrainer
from torch.utils.data import DataLoader
from itertools import islice
class Config:
def __init__(self):
# Model and training hyperparameters
self.BATCH_SIZE = 16
self.EPOCHS = 3
self.LEARNING_RATE = 2e-4
self.MAX_SEQ_LENGTH = 512
self.VOCAB_SIZE = 32000
self.FP16 = True
self.WEIGHT_DECAY = 1e-3
self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4
# Dataset configurations
self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
self.INSTRUCT_DATASET = "nroggendorff/elephant"
self.SHARD_SIZE = int(2e+5)
# Output and repo settings
self.OUTPUT_REPO = "nroggendorff/smallama"
self.PUSH_TO_HUB = True
self.INSTRUCT_FINETUNE_BOOL = False
# Training steps and warmup
self.FACTOR = 12 ** 3 // 2
self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS)
self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)
# Initial state for shard offset
self.INIT = 0
# ignore
self.getConfig = lambda: self._args()
# @staticmethod
def _args(self):
return SFTConfig(
output_dir="model",
num_train_epochs=self.EPOCHS,
per_device_train_batch_size=self.BATCH_SIZE,
learning_rate=self.LEARNING_RATE,
warmup_steps=self.WARMUP_STEPS,
weight_decay=self.WEIGHT_DECAY,
gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
fp16=self.FP16,
save_steps=int(self.WARMUP_STEPS * 5),
logging_steps=int(self.WARMUP_STEPS),
save_total_limit=2,
report_to="none",
)
config = Config()
class Space:
def __init__(self):
self.api = HfApi()
self.pause = lambda: self.api.pause_space("nroggendorff/train-llama")
class FineError(Exception):
def __init__(self, message="Script execution has completed."):
self.message = message
super().__init__(self.message)
def load_data():
if not config.INSTRUCT_FINETUNE_BOOL:
dataset = load_dataset(config.INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
else:
dataset = load_dataset(config.INSTRUCT_DATASET, split="train", streaming=True)
start = config.INIT * config.SHARD_SIZE
data_list = list(islice(dataset, start, start + config.SHARD_SIZE))
dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
return dataset
def encode_decode(texts, tok):
if tok.pad_token is None:
tok.pad_token = tok.eos_token
tokenized_texts = tok(
texts,
padding="max_length",
truncation=True,
max_length=config.MAX_SEQ_LENGTH,
return_tensors="pt"
).input_ids
if tokenized_texts.dim() >= 1:
decoded_texts = tok.batch_decode(tokenized_texts)
else:
print('Found invalid entry in examples. Returning dummy..')
decoded_texts = [tok.pad_token * config.MAX_SEQ_LENGTH]
islist = not len(decoded_texts) == 1
return decoded_texts if islist else decoded_texts[0]
def create_tokenizer(training_corpus):
tokenizer = ByteLevelBPETokenizer()
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
tokenizer.train_from_iterator(
training_corpus,
vocab_size=config.VOCAB_SIZE,
min_frequency=2,
special_tokens=special_tokens
)
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
return fast_tokenizer
def load_tokenizer():
return AutoTokenizer.from_pretrained(config.OUTPUT_REPO + '-it' if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO)
def get_training_corpus(dataset):
for i in range(0, len(dataset['text']), 1000):
yield dataset['text'][i : i + 1000]
def format_prompts(examples, tokenizer, isinst):
texts = []
for text in examples['text']:
if text and len(text.strip()) > 0:
if isinst:
conversation = []
parts = text.split('<|end|>')
for i in range(0, len(parts) - 1, 2):
prompt = parts[i].replace("<|user|>", "").strip()
response = parts[i + 1].replace("<|bot|>", "").strip()
conversation.append({"role": "user", "content": prompt})
conversation.append({"role": "assistant", "content": response})
formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
coded_text = tokenizer.code(formatted_conversation)
texts.append(coded_text)
else:
texts.append(tokenizer.bos_token + tokenizer.code(text) + tokenizer.eos_token)
else:
print('Found empty entry in examples. Moving on..')
continue
if len(texts) == 0:
raise ValueError("No valid texts found in examples for formatting.")
coded_texts = tokenizer.code(texts)
return {'text': coded_texts}
def create_model(tokenizer):
model_config = LlamaConfig(
vocab_size=tokenizer.vocab_size,
hidden_size=config.FACTOR,
intermediate_size=config.FACTOR * 4,
num_hidden_layers=config.FACTOR // 2 ** 4,
num_attention_heads=config.FACTOR // 2 ** 5,
max_position_embeddings=config.MAX_SEQ_LENGTH,
rms_norm_eps=1e-5,
initializer_range=2e-2,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
tie_word_embeddings=False,
)
return LlamaForCausalLM(model_config)
def load_model():
return AutoModelForCausalLM.from_pretrained(config.OUTPUT_REPO + '-it' if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO)
def configure_tokenizer(tokenizer):
special_tokens = {
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
"mask_token": "<mask>",
"additional_special_tokens": []
}
if config.INSTRUCT_FINETUNE_BOOL:
special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
tokenizer.add_special_tokens(special_tokens)
if config.INSTRUCT_FINETUNE_BOOL:
tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
tokenizer.chat_template = chat_template
tokenizer.code = lambda example: encode_decode(example, tokenizer)
def update_tokenizer(tokenizer, dataset, batch_size=1000):
existing_vocab = tokenizer.get_vocab()
oov_tokens = set()
for i in range(0, len(dataset['text']), batch_size):
batch = dataset['text'][i:i + batch_size]
for text in batch:
token_ids = tokenizer.encode(text, add_special_tokens=False)
for token_id in token_ids:
token = tokenizer.decode([token_id])
if token.strip() and token not in existing_vocab:
oov_tokens.add(token)
if oov_tokens:
num_added = tokenizer.add_tokens(list(oov_tokens))
return num_added
return 0
def train_model(model, tokenizer, dataset, push, isinst):
args = config.getConfig()
optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=config.WEIGHT_DECAY)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_steps
)
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
if 'text' not in dataset.column_names:
raise ValueError("Dataset transformation failed: 'text' column missing after mapping.")
print("Mapped dataset sample length:", len(dataset[0]['text']))
try:
test_input = tokenizer(
["This is a test input."],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=MAX_SEQ_LENGTH
)
test_output = model(**test_input)
print("Model test output shape:", test_output.logits.shape)
except RuntimeError as e:
print(f"Error processing test batch: {e}")
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=args,
train_dataset=dataset,
# dataset_text_field='text',
max_seq_length=config.MAX_SEQ_LENGTH,
optimizers=(optimizer, scheduler)
)
train = trainer.train()
trained_model = trainer.model
trained_tokenizer = trainer.tokenizer
if push:
repo_id = config.OUTPUT_REPO + "-it" if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO
msg = f"Training loss: {train.training_loss:.4f}"
trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
else:
trained_model.save_pretrained("model")
trained_tokenizer.save_pretrained("tokenizer")
def main(push_to_hub=True, is_inst_finetune=config.INSTRUCT_FINETUNE_BOOL):
print("Loading Data..")
dataset = load_data()
print("Loaded data.")
if is_inst_finetune and config.INIT > 0:
print("Loading Tokenizer..")
tokenizer = load_tokenizer()
print("Loaded Tokenizer.")
else:
print("Making Corpus..")
training_corpus = get_training_corpus(dataset)
print("Made Corpus.")
print("Making Tokenizer..")
tokenizer = create_tokenizer(training_corpus)
print(f"Made Tokenizer with size {len(tokenizer)}.")
# print("Adding Tokens..")
# num_new_tokens = update_tokenizer(tokenizer, dataset)
# print(f"Added {num_new_tokens} new tokens to the vocabulary")
if config.INIT == 0:
print("Adding Special Tokens..")
configure_tokenizer(tokenizer)
print("Added Tokens.")
if is_inst_finetune or config.INIT > 0:
print("Loading Model..")
model = load_model()
print("Loaded Model.")
else:
print("Creating Model..")
model = create_model(tokenizer)
print("Created Model.")
print(f"Tokenizer vocabulary size: {len(tokenizer)}")
print(f"Special tokens: {tokenizer.special_tokens_map}")
print("Resizing Token Embeddings..")
try:
model.resize_token_embeddings(len(tokenizer))
except RuntimeError as e:
raise RuntimeError(f"Error resizing token embeddings: {e}")
print("Resized Embeddings.")
print("Training Model..")
train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
raise FineError("Trained Model.")
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f'{type(e).__name__}: {e}')
Space().pause()