|
import os |
|
import sys |
|
|
|
import torch |
|
import torch.nn as nn |
|
import bitsandbytes as bnb |
|
from datasets import load_dataset |
|
import transformers |
|
import argparse |
|
import warnings |
|
from huggingface_hub import snapshot_download |
|
|
|
# Fail early, with an actionable message, when the installed transformers has no
# LLaMA support. Importing directly (instead of probing the private
# `transformers._import_structure` dict, which raises a bare KeyError on releases
# without LLaMA) keeps the friendly message, and unlike `assert` it is not
# stripped under `python -O`.
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer
except ImportError as err:
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
    ) from err
|
from peft import ( |
|
prepare_model_for_int8_training, |
|
LoraConfig, |
|
get_peft_model, |
|
get_peft_model_state_dict, |
|
set_peft_model_state_dict, |
|
) |
|
|
|
def maybe_zero_3(param):
    """Return a detached CPU copy of ``param``.

    Parameters that carry a ``ds_id`` attribute (presumably DeepSpeed ZeRO-3
    partitioned parameters -- TODO confirm) are first gathered with
    ``deepspeed.zero.GatheredParameters`` so their full data can be copied.
    """
    if hasattr(param, "ds_id"):
        # Only reachable when DeepSpeed created the parameter, so importing it here
        # does not make deepspeed a hard dependency of the script.
        from deepspeed import zero
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.cpu().clone().detach()
    else:
        param = param.detach().cpu().clone()
    return param


def get_peft_state_maybe_zero_3(state_dict, bias):
    """Select the LoRA entries of ``state_dict`` and return CPU copies of them.

    Args:
        state_dict: mapping of parameter name -> tensor.
        bias: which bias entries to keep alongside the ``lora_`` entries, using the
            same option names as ``LoraConfig(bias=...)``:
            ``"none"`` keeps only names containing ``"lora_"``;
            ``"all"`` also keeps every name containing ``"bias"``;
            ``"lora_only"`` also keeps the ``<prefix>bias`` entry of each module
            that has LoRA weights.

    Returns:
        A new dict mapping the selected names to detached CPU copies
        (see ``maybe_zero_3``).

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
    elif bias == "all":
        to_return = {
            k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                # e.g. "layer.lora_A.weight" -> "layer.bias"
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError(f"unsupported bias option: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in to_return.items()}
|
|
|
def _str2bool(value):
    """Convert a textual boolean command-line value to ``bool``.

    Accepts (case-insensitively) true/false, t/f, yes/no, y/n and 1/0; a value
    that is already a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports a
            normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options for data/model locations, evaluation/saving cadence and resuming.
parser = argparse.ArgumentParser()
parser.add_argument("--wandb", action="store_true", default=False)
parser.add_argument("--data_path", type=str, default="merge.json")
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
parser.add_argument("--eval_steps", type=int, default=200)
parser.add_argument("--save_steps", type=int, default=200)
parser.add_argument("--test_size", type=int, default=200)
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
# Parsed to a real bool: with the former `type=str`, the default "False" was a
# truthy string by the time it reached TrainingArguments(ignore_data_skip=...).
# "True"/"False" are still accepted on the command line.
parser.add_argument("--ignore_data_skip", type=_str2bool, default=False)
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)

# Injected by distributed launchers; accepted so parsing does not fail.
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true", default=False)

args = parser.parse_args()
|
|
|
# Weights & Biases is opt-in: without --wandb, tell any wandb client that gets
# imported to stay inactive. The valid value is "disabled" ("disable" is not one
# of the accepted WANDB_MODE values).
if not args.wandb:
    os.environ["WANDB_MODE"] = "disabled"
|
|
|
# --- Batch size / schedule ---------------------------------------------------
MICRO_BATCH_SIZE = 2  # examples per device per forward/backward pass
BATCH_SIZE = 128  # target number of examples per optimizer step
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
MAX_STEPS = None  # placeholder; may be replaced later in the script
EPOCHS = 3
LEARNING_RATE = 3e-4
CUTOFF_LEN = 256  # max_length used when tokenizing each example

# --- LoRA adapter settings ---------------------------------------------------
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "v_proj"]  # module names that receive LoRA adapters

# --- Data / output locations (from the command line) -------------------------
VAL_SET_SIZE = args.test_size
DATA_PATH = args.data_path
OUTPUT_DIR = args.output_path
|
|
|
# Place the whole model on a single GPU: GPU 0 normally, or this process's
# LOCAL_RANK when launched with more than one process (WORLD_SIZE != 1).
device_map = {"": 0}
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # Every rank contributes its own micro-batches, so accumulate fewer steps per
    # rank to keep the global batch near BATCH_SIZE. Clamp to 1: with more ranks
    # than accumulation steps the integer division would yield 0, which is not a
    # usable accumulation count for the Trainer.
    GRADIENT_ACCUMULATION_STEPS = max(1, GRADIENT_ACCUMULATION_STEPS // world_size)
|
print(args.model_path)

# Base LLaMA weights, requested in float16, placed according to device_map.
model = LlamaForCausalLM.from_pretrained(
    args.model_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = model.half()  # explicit fp16 cast, kept from the original flow

# add_eos_token=True makes the tokenizer append an EOS token to every encoding.
tokenizer = LlamaTokenizer.from_pretrained(args.model_path, add_eos_token=True)
|
|
|
|
|
|
|
# Wrap the base model with LoRA adapters on TARGET_MODULES (bias terms excluded).
config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
    lora_alpha=LORA_ALPHA,
    r=LORA_R,
)
model = get_peft_model(model, config)

# Pad with token id 0 (presumably <unk> in the LLaMA vocabulary -- TODO confirm).
tokenizer.pad_token_id = 0
|
|
|
|
|
# Load the instruction data; `datasets` exposes a single-file JSON load as the "train" split.
data = load_dataset("json", data_files=DATA_PATH)

# Optimizer steps for EPOCHS passes over the non-validation examples at the global
# BATCH_SIZE, never fewer than EPOCHS so tiny datasets still train.
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
# Default step budget. Only a resumed run with a longer recorded schedule overrides
# it below (previously MAX_STEPS could stay None, which the Trainer cannot use).
MAX_STEPS = now_max_steps

if args.resume_from_checkpoint:
    resume_dir = args.resume_from_checkpoint
    if args.lora_remote_checkpoint is not None:
        # Fetch the LoRA weights/configs from the Hub into the local resume directory.
        snapshot_download(repo_id=args.lora_remote_checkpoint, allow_patterns=["*.pt", "*.bin", "*.json"], local_dir=resume_dir)

    # Checkpoints written by the Trainer in this script are named pytorch_model.bin;
    # save_pretrained writes adapter_model.bin. Both hold the LoRA weights.
    checkpoint_name = os.path.join(resume_dir, "pytorch_model.bin")
    if not os.path.exists(checkpoint_name):
        pytorch_bin_path = checkpoint_name
        checkpoint_name = os.path.join(resume_dir, "adapter_model.bin")
        if os.path.exists(checkpoint_name):
            # NOTE(review): under multi-process launch every rank performs this
            # check-and-rename; consider restricting it to one rank.
            os.rename(checkpoint_name, pytorch_bin_path)
            warnings.warn("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
            # The adapter weights now exist only under the new name.
            checkpoint_name = pytorch_bin_path
        else:
            # Nothing to resume from: keep the Trainer from loading a checkpoint state.
            args.resume_from_checkpoint = None

    if os.path.exists(checkpoint_name):
        print(f"Restarting from {checkpoint_name}")
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints,
        # including anything fetched via --lora_remote_checkpoint.
        # map_location="cpu" avoids materialising the tensors on the device they were
        # saved from; load_state_dict copies them onto the model's devices anyway.
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        # Loads in place; the return value is not the model in every peft version.
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")

    train_args_path = os.path.join(resume_dir, "trainer_state.json")
    # Only consult the previous schedule when we are actually resuming.
    if args.resume_from_checkpoint and os.path.exists(train_args_path):
        import json

        with open(train_args_path, "r") as state_file:
            base_train_args = json.load(state_file)
        base_max_steps = base_train_args["max_steps"]
        if base_max_steps > now_max_steps:
            # Keep the longer schedule of the interrupted run; steps take precedence over epochs.
            warnings.warn("epoch {} replace to the base_max_steps {}".format(EPOCHS, base_max_steps))
            EPOCHS = None
            MAX_STEPS = base_max_steps

model.print_trainable_parameters()
|
|
|
def generate_prompt(data_point):
    """Build the full Alpaca-style text: prompt sections followed by the answer.

    ``data_point`` must provide "instruction", "input" and "output". When "input"
    is falsy, the shorter instruction-only preamble is used and the
    "### Input:" section is left out. Sections are separated by a blank line.
    """
    has_context = bool(data_point["input"])
    if has_context:
        preamble = (
            "Below is an instruction that describes a task, paired with an input that "
            "provides further context. Write a response that appropriately completes the request."
        )
    else:
        preamble = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
        )

    sections = [preamble, f"### Instruction:\n{data_point['instruction']}"]
    if has_context:
        sections.append(f"### Input:\n{data_point['input']}")
    sections.append(f"### Response:\n{data_point['output']}")
    return "\n\n".join(sections)
|
|
|
|
|
def tokenize(prompt):
    """Encode ``prompt`` padded/truncated to CUTOFF_LEN + 1 tokens, then drop the last one.

    Returns a dict with "input_ids" and "attention_mask" lists of length CUTOFF_LEN.
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    # Trim the final position from both sequences so they end up CUTOFF_LEN long.
    return {key: encoded[key][:-1] for key in ("input_ids", "attention_mask")}
|
|
|
|
|
def generate_and_tokenize_prompt(data_point):
    """Tokenize one example for causal-LM training, masking the prompt out of the loss.

    Args:
        data_point: mapping with "instruction", "input" and "output"; a falsy
            "input" selects the instruction-only template.

    Returns:
        dict with "input_ids", "labels" and "attention_mask", each CUTOFF_LEN long.
        Prompt positions and padding positions get label -100, so only the
        response tokens contribute to the loss. Padding positions get
        attention_mask 0 (taken from the tokenizer) instead of 1.
    """
    if data_point["input"]:
        user_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point['instruction']}

### Input:
{data_point['input']}

### Response:
"""
    else:
        user_prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point['instruction']}

### Response:
"""

    # Prompt length in tokens, minus one for the trailing token the tokenizer adds
    # (presumably the EOS from add_eos_token=True -- TODO confirm).
    len_user_prompt_tokens = (
        len(
            tokenizer(
                user_prompt,
                truncation=True,
                max_length=CUTOFF_LEN + 1,
            )["input_ids"]
        )
        - 1
    )
    full = tokenizer(
        user_prompt + data_point["output"],
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    # Drop the last position so every sequence is exactly CUTOFF_LEN long.
    full_tokens = full["input_ids"][:-1]
    attention_mask = full["attention_mask"][:-1]

    labels = [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:]
    # Padding is not target text: exclude it from the loss. NOTE(review): the
    # prompt-masking offset above assumes the tokenizer pads on the right.
    labels = [tok if keep else -100 for tok, keep in zip(labels, attention_mask)]

    return {
        "input_ids": full_tokens,
        "labels": labels,
        "attention_mask": attention_mask,
    }
|
|
|
|
|
# Optionally hold out VAL_SET_SIZE examples (fixed split seed) for evaluation,
# then shuffle and tokenize each split.
if VAL_SET_SIZE <= 0:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None
else:
    train_val = data["train"].train_test_split(test_size=VAL_SET_SIZE, shuffle=True, seed=42)
    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
def _collate_keep_prompt_mask(features):
    """Stack pre-tokenized examples into a batch, preserving their prompt-masked labels.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids``, which silently discarded the -100 prompt mask produced by the
    dataset map. This collator keeps the dataset's ``labels`` (falling back to a
    copy of ``input_ids`` when an example has none) and only applies that
    collator's padding rule: label == pad id -> -100.

    Assumes every example already has the same length (the map functions pad to
    ``max_length``) -- TODO revisit if the tokenization stops padding.
    """
    batch = transformers.default_data_collator(features, return_tensors="pt")
    if "labels" not in batch:
        batch["labels"] = batch["input_ids"].clone()
    if tokenizer.pad_token_id is not None:
        batch["labels"][batch["labels"] == tokenizer.pad_token_id] = -100
    return batch


trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        # Evaluation, and choosing the best checkpoint, only when a validation split exists.
        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
        save_strategy="steps",
        eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
        save_steps=args.save_steps,
        output_dir=OUTPUT_DIR,
        save_total_limit=30,
        load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
        ddp_find_unused_parameters=False if ddp else None,
        report_to="wandb" if args.wandb else [],
        ignore_data_skip=args.ignore_data_skip,
        deepspeed="sample/zero_config.json" if args.deepspeed else None,
    ),
    data_collator=_collate_keep_prompt_mask,
)
|
# Turn off the model's key/value cache for training.
model.config.use_cache = False

# Make model.state_dict() return only the PEFT entries (presumably so that
# checkpoints saved through it hold just the adapter weights). The original bound
# method is captured first and wrapped; any arguments passed in are ignored.
old_state_dict = model.state_dict


def _peft_only_state_dict(self, *_, **__):
    return get_peft_model_state_dict(self, old_state_dict())


model.state_dict = _peft_only_state_dict.__get__(model, type(model))

# NOTE(review): `trainer` already holds the uncompiled model, so this compiled
# wrapper is only used by save_pretrained below.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

print("\n If there's a warning about missing keys above, please disregard :)")

trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

model.save_pretrained(OUTPUT_DIR)
|
|
|
|