import argparse |
import bitsandbytes as bnb |
from datasets import load_dataset |
from functools import partial |
import os |
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, AutoPeftModelForCausalLM |
import torch |
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, Trainer, TrainingArguments, BitsAndBytesConfig, \ |
DataCollatorForLanguageModeling, Trainer, TrainingArguments |
from datasets import load_dataset |
def load_model(model_name, bnb_config, max_memory_mb=40960):
    """
    Load a quantized causal language model together with its tokenizer.

    :param model_name: Model identifier or path handed to ``from_pretrained``
    :param bnb_config: Quantization config forwarded as ``quantization_config``
    :param max_memory_mb: Memory cap, in megabytes, applied to every visible
        CUDA device when building the ``max_memory`` map. Defaults to 40960,
        the value that used to be hard-coded here.
    :return: Tuple ``(model, tokenizer)``
    """
    n_gpus = torch.cuda.device_count()
    per_device_limit = f"{max_memory_mb}MB"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        # One identical cap per CUDA device index reported by torch.
        max_memory={i: per_device_limit for i in range(n_gpus)},
    )

    # NOTE(review): recent transformers releases deprecate `use_auth_token`
    # in favour of `token`; kept as-is to stay compatible with the version
    # this script was written against -- confirm before upgrading.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)

    # Padding reuses the end-of-sequence token (presumably because the
    # tokenizer ships without a dedicated pad token -- verify for other models).
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
from datasets import load_dataset

# Load the "train" split of the Dolly 15k dataset and report its size and
# column layout before any preprocessing happens.
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"Number of prompts: {len(dataset)}")
print(f"Column names are: {dataset.column_names}")
def create_prompt_formats(sample):
    """
    Build the training prompt for one sample and store it under ``"text"``.

    The prompt consists of an intro blurb, the instruction, the optional
    context, the response and an end marker, joined by two newline
    characters. The context section is omitted when ``sample["context"]``
    is empty.

    :param sample: Sample dictionary with 'instruction', 'context' and
        'response' entries
    :return: The same ``sample`` object, with the 'text' entry added
    """
    sections = [
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        "### Instruction:\n{}".format(sample["instruction"]),
    ]

    # The context is optional: only emit its section when there is content.
    context = sample["context"]
    if context:
        sections.append("Input:\n{}".format(context))

    sections.append("### Response:\n{}".format(sample["response"]))
    sections.append("### End")

    sample["text"] = "\n\n".join(sections)
    return sample
def get_max_length(model):
    """
    Look up the maximum sequence length advertised by a model's config.

    The config attributes 'n_positions', 'max_position_embeddings' and
    'seq_length' are probed in that order; the first truthy value wins.
    When none of them is set, 1024 is used.

    :param model: Object exposing a ``config`` attribute
    :return: The maximum length that was found, or 1024 as a fallback
    """
    config = model.config

    for attribute in ("n_positions", "max_position_embeddings", "seq_length"):
        candidate = getattr(config, attribute, None)
        if candidate:
            print(f"Found max lenth: {candidate}")
            return candidate

    # Nothing usable on the config: fall back to the default.
    fallback = 1024
    print(f"Using default max length: {fallback}")
    return fallback
def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenize the 'text' entry of a batch.

    :param batch: Mapping whose 'text' entry holds the prompts to tokenize
    :param tokenizer: Callable invoked with the texts plus ``max_length``
        and ``truncation=True``
    :param max_length: Value forwarded to the tokenizer as ``max_length``
    :return: Whatever the tokenizer call returns
    """
    texts = batch["text"]
    encoded = tokenizer(texts, max_length=max_length, truncation=True)
    return encoded
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed, dataset):
    """Format & tokenize the dataset so it is ready for training

    :param tokenizer (AutoTokenizer): Model Tokenizer
    :param max_length (int): Maximum number of tokens to emit from tokenizer
    :param seed: Seed handed to ``dataset.shuffle``
    :param dataset: Dataset object to preprocess; it must provide ``map``,
        ``filter`` and ``shuffle`` and carry the 'instruction', 'context',
        'response' and 'category' columns (it is not a string, despite the
        annotation this parameter used to have)
    :return: The formatted, tokenized, filtered and shuffled dataset
    """
    print("Preprocessing dataset...")

    # Add the "text" prompt field to every sample.
    dataset = dataset.map(create_prompt_formats)

    # Tokenize in batches and drop the raw columns, leaving only what the
    # tokenizer produced.
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "context", "response", "text", "category"],
    )

    # Keep only samples strictly shorter than max_length. Tokenization already
    # truncates to max_length, so this discards every sample that hit the cap.
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)

    dataset = dataset.shuffle(seed=seed)

    return dataset
def create_bnb_config():
    """
    Build the quantization config used when loading the model.

    The settings enable 4-bit loading with double quantization, the "nf4"
    quantization type and ``torch.bfloat16`` as compute dtype.

    :return: A ``BitsAndBytesConfig`` instance
    """
    quantization_options = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    }
    return BitsAndBytesConfig(**quantization_options)
def create_peft_config(modules):
    """
    Create the LoRA (Parameter-Efficient Fine-Tuning) config for the model.

    :param modules: Names of the modules to apply LoRA to
    :return: A ``LoraConfig`` with rank 16, alpha 64, dropout 0.1, no bias
        training and the "CAUSAL_LM" task type
    """
    lora_options = {
        "r": 16,
        "lora_alpha": 64,
        "target_modules": modules,
        "lora_dropout": 0.1,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**lora_options)
def find_all_linear_names(model):
    """
    Collect the names of the 4-bit linear layers found in the model.

    Every sub-module that is a ``bnb.nn.Linear4bit`` contributes the last
    dot-separated component of its qualified name; 'lm_head' is excluded
    from the result.

    :param model: Object exposing ``named_modules()``
    :return: List of unique module names (order is not specified)
    """
    linear_cls = bnb.nn.Linear4bit

    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }

    # The output head is never reported as a target module.
    leaf_names.discard("lm_head")
    return list(leaf_names)
def print_trainable_parameters(model, use_4bit=False):
    """
    Prints the number of trainable parameters in the model.

    :param model: Object exposing ``named_parameters()``; each parameter must
        provide ``numel()`` and ``requires_grad``
    :param use_4bit: When true, the trainable parameter count is halved
        before it is reported
    :return: None (the summary is written to stdout)
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # Parameters reporting zero elements may carry their real size in
        # ``ds_numel``; use it when the attribute is present.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    if use_4bit:
        # Floor division keeps the count an int: true division produced a
        # float, which the ``,d`` format spec below rejects with ValueError.
        trainable_params //= 2

    # Guard against a model without parameters instead of dividing by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"all params: {all_param:,d} || trainable params: {trainable_params:,d} || trainable%: {trainable_pct}"
    )
# Run settings.
seed = 98345
model_name = "meta-llama/Llama-2-7b-hf"

# Load the quantized base model and its tokenizer, then show the architecture.
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)
print(model)

# Sequence length limit taken from the model config (or the default).
max_length = get_max_length(model)
print(max_length)

# Replace the raw dataset with its formatted, tokenized and shuffled version.
dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)
def train(model, tokenizer, dataset, output_dir):
    """
    Fine-tune the model with LoRA adapters and save the result.

    The model gets gradient checkpointing enabled, is prepared for k-bit
    training and is wrapped with LoRA adapters on the modules reported by
    ``find_all_linear_names``. A ``Trainer`` then runs for at most 20 steps,
    writing its own artifacts to "outputs", and the trained model is saved
    to ``output_dir``.

    :param model: Quantized base model to fine-tune
    :param tokenizer: Tokenizer handed to the language-modeling data collator
    :param dataset: Tokenized training dataset
    :param output_dir: Directory that receives the final ``save_pretrained``
        output (created if missing)
    :return: None
    """
    # Prepare the base model for training.
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # Attach LoRA adapters to the discovered linear modules.
    target_modules = find_all_linear_names(model)
    lora_config = create_peft_config(target_modules)
    model = get_peft_model(model, lora_config)
    print_trainable_parameters(model)

    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=20,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    )
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        data_collator=collator,
    )

    model.config.use_cache = False

    # Report how many parameter elements are stored in each dtype.
    dtype_counts = {}
    for _, param in model.named_parameters():
        dtype_counts[param.dtype] = dtype_counts.get(param.dtype, 0) + param.numel()
    total_elements = sum(dtype_counts.values())
    for dtype, count in dtype_counts.items():
        print(dtype, count, count / total_elements)

    do_train = True

    print("Training...")
    if do_train:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        print(metrics)

    print("Saving last checkpoint of the model...")
    os.makedirs(output_dir, exist_ok=True)
    trainer.model.save_pretrained(output_dir)

    # Drop the local references and release cached CUDA memory.
    del model
    del trainer
    torch.cuda.empty_cache()
# Output locations for the adapter checkpoint and the merged model.
output_dir = "results/llama2/final_checkpoint"
output_merged_dir = "results/llama2/final_merged_checkpoint"

print("Run train ...")
train(model, tokenizer, dataset, output_dir)

# Reload the saved adapter checkpoint in bfloat16 and merge it into the model.
model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = model.merge_and_unload()

# Save the merged weights next to a fresh copy of the tokenizer.
os.makedirs(output_merged_dir, exist_ok=True)
model.save_pretrained(output_merged_dir, safe_serialization=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(output_merged_dir)