How to fine-tune with PEFT, QLoRA and SFTTrainer?
Load the model with QLoRA
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
    GenerationConfig,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_use_double_quant=True,
    bnb_8bit_quant_type="nf4",
    bnb_8bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=['decoder', 'lm_head', 'wo'],
)
import logging  # note: this shadows the `logging` imported from transformers above

logger = logging.getLogger(__name__)


def get_max_length(model):
    """Read the model's maximum context length from its config, with a fallback."""
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            logger.info(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 32_000
        logger.info(f"Using default max length: {max_length}")
    return max_length
model_id = "togethercomputer/LLaMA-2-7B-32K"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    # load_in_8bit=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=False,
    max_length=get_max_length(model),
)
Dataset: databricks/databricks-dolly-15k
Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 15011
})
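The SFTTrainer call below passes formatting_func=format_instruction, which is not shown in the thread. A minimal sketch of loading the dataset and a prompt template (the helper name matches the trainer call, but the template itself is an assumption, not the original poster's code):
```python
from datasets import load_dataset

# Load the Dolly instruction dataset from the Hugging Face Hub.
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def format_instruction(sample):
    """Turn one Dolly record into a single prompt string (used with packing=True)."""
    context = f"\n### Context:\n{sample['context']}" if sample["context"] else ""
    return (
        f"### Instruction:\n{sample['instruction']}{context}\n"
        f"### Response:\n{sample['response']}"
    )
```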
TRAIN
from transformers import TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# LoRA config based on the QLoRA paper
peft_config = LoraConfig(
    lora_alpha=16,
    r=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

# prepare the quantized model for training and wrap it with the LoRA adapters
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

args = TrainingArguments(
    output_dir="llama-7-int4-dolly",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # 6 if use_flash_attention else 3,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    # optim="paged_adamw_32bit",
    # torch_compile=True,  # optimizations
    # optim="adamw_torch_fused",  # improved optimizer
    optim="adamw_bnb_8bit",  # other options: 'adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_anyprecision', 'sgd', 'adagrad'
    logging_steps=4,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    # fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm=False,  # set to True to disable tqdm, since with packing the reported values are incorrect
)
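(Optional sanity check, not in the original post: after get_peft_model you can confirm how much of the model is actually trainable.)
```python
# Show how many parameters the LoRA adapters make trainable vs. the frozen base model.
model.print_trainable_parameters()
```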
from trl import SFTTrainer

max_seq_length = get_max_length(model)  # max sequence length for the model and for packing the dataset

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)

trainer.train()
ERROR:
```Python
/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
RuntimeError Traceback (most recent call last)
in <cell line: 2>()
1 # train
----> 2 trainer.train() # there will not be a progress bar since tqdm is disabled
28 frames
/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax)
40 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
41 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 42 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
43 q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
44 )
RuntimeError: FlashAttention only support fp16 and bf16 data type
```
@NickyNicky
Hi, since you are doing QLoRA, you might need to set trust_remote_code=False to use HF's LLaMA implementation; flash attention only works with float16 or bfloat16.
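In code, that suggestion would look roughly like this (a sketch of the change, not a verified fix):
```python
# Reload with trust_remote_code=False so the standard HF LLaMA implementation
# (without the custom flash-attention code path) is used for QLoRA training.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=False,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)
```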
Thanks for such a quick response.
So I can't train with QLoRA and flash-attention together? Or is it just not optimized yet? What is missing, or what am I doing wrong?
How can I implement float16 or bfloat16? Would it be OK like this?
for name, module in model.named_modules():
    if "norm" in name:
        module = module.to(torch.float16)
Thanks. :)
@NickyNicky -- It might take some time to get QLoRA + flash-attention to work (mainly engineering and optimizations). But we are working together with our friends in the open-source community on it -- stay tuned! (It might still take a while before the release.)
Ce
I tried running without QLoRA and got:
File ~/anaconda3/lib/python3.10/site-packages/flash_attn/layers/rotary.py:62, in ApplyRotaryEmb.forward(ctx, x, cos, sin, interleaved, inplace)
59 else:
60 o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
61 else (out_ro[..., ::2], out_ro[..., 1::2]))
---> 62 rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
63 rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
64 if not inplace and rotary_dim < headdim:
65 out[..., rotary_dim:].copy_(x[..., rotary_dim:])
RuntimeError: Expected x1.dtype() == cos.dtype() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Got the same error as @chanderbalaji. (Following.)
@NickyNicky Hey, I saw that you released togethercomputer-LLaMA-2-7B-32K-open-Orca-v1. Is your problem solved? As long as you don't use flash-attention, you can use QLoRA, right?
credits to:
- https://www.philschmid.de/instruction-tune-llama-2
- https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/instruction-tune-llama-2-int4.ipynb
The models togethercomputer-LLaMA-2-7B-32K-open-Orca-v1 and togethercomputer-LLaMA-2-7B-32K-open-Orca-v2 were trained with QLoRA, PEFT and flash-attention for about 4 hours (v1) and 5 hours (v2) on 1 A100 GPU (Google Colab).
I really wanted to train it longer, but it was out of budget.
Values used for training:
- per_device_train_batch_size=14
- trust_remote_code=False
After training and merging the weights, you can enable flash attention.
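A rough sketch of that merge step (assuming the LoRA adapter was saved to the output_dir used above; the paths are placeholders, not the exact commands from the thread):
```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model in half precision (no quantization) and attach the trained adapter.
base_model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/LLaMA-2-7B-32K",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "llama-7-int4-dolly")  # path to the saved LoRA adapter
model = model.merge_and_unload()  # fold the LoRA weights into the base model
model.save_pretrained("llama-2-7b-32k-merged")  # the merged model can then be loaded with flash attention
```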
Wow, thank you~
And may I ask which Orca dataset you used? Was there any re-filtering by token size in the dataset?
orca dataset:
- 1M-GPT4-Augmented.parquet
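(That parquet file presumably comes from the OpenOrca dataset; a loading sketch, with the Hub repo id being my assumption:)
```python
from datasets import load_dataset

# Assumed source: the OpenOrca dataset repo on the Hugging Face Hub.
orca = load_dataset(
    "Open-Orca/OpenOrca",
    data_files="1M-GPT4-Augmented.parquet",
    split="train",
)
```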
Just curious...
The BitsAndBytes config you added -- is it valid? I can't find any references for double quant with 8-bit, only with 4-bit.
More just wondering is all.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_use_double_quant=True,
    bnb_8bit_quant_type="nf4",
    bnb_8bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=['decoder', 'lm_head', 'wo'],
)
Thanks!
@Mediocreatmybest I had the same doubt, did you resolve this?
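(For reference: in transformers, double quantization and the NF4 quant type are documented as 4-bit options, so they would not apply to an 8-bit load. A typical QLoRA-style 4-bit config looks roughly like this; a generic example, not the original poster's setup:)
```python
import torch
from transformers import BitsAndBytesConfig

# Standard QLoRA-style 4-bit quantization: NF4 with double quantization.
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```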