how to fine tune peft qlora and SFTTrainer?

#2
by NickyNicky - opened

Load model QLora

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
    GenerationConfig,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer


bnb_config= BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_8bit_use_double_quant=True,
        bnb_8bit_quant_type="nf4",
        bnb_8bit_compute_dtype=torch.bfloat16,
        llm_int8_skip_modules= ['decoder', 'lm_head',  'wo'],
    )

import logging
logger = logging.getLogger(__name__)

def get_max_length():
  max_length = None
  for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
    max_length = getattr(model.config, length_setting, None)
    if max_length:
      logger.info(f"Found max lenth: {max_length}")
      break
  if not max_length:
    max_length = 32_000
    logger.info(f"Using default max length: {max_length}")
  max_length



model_id= "togethercomputer/LLaMA-2-7B-32K"


model = AutoModelForCausalLM.from_pretrained(model_id, 
                                             trust_remote_code=True, 
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             quantization_config=bnb_config,
                                             low_cpu_mem_usage=True,
                                              # load_in_8bit=True,
                                             )

tokenizer = AutoTokenizer.from_pretrained(model_id,
                                          use_fast = False,
                                          max_length=get_max_length(),
                                          )

databricks/databricks-dolly-15k



Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 15011
})

TRAIN

from transformers import TrainingArguments


from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# LoRA config based on QLoRA paper
peft_config = LoraConfig(
        lora_alpha=16,
        r=32,
        lora_dropout=0.1,
        
        bias="none",
        task_type="CAUSAL_LM", 
)


# prepare model for training
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

args = TrainingArguments(
    output_dir="llama-7-int4-dolly",
    num_train_epochs=3,
    per_device_train_batch_size= 1, #6 if use_flash_attention else 3,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    # optim="paged_adamw_32bit",
    
    # torch_compile=True, # optimizations
    # optim="adamw_torch_fused", # improved optimizer 
    optim="adamw_bnb_8bit", #     #['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad']

    
    logging_steps=4,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    # fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm= False, #True # disable tqdm since with packing values are in correct
)

from trl import SFTTrainer
max_seq_length = get_max_length() # max sequence length for model and packing of the dataset
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length, 
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction, 
    args=args,
)
trainer.train()

ERROR:
```Python
/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

RuntimeError Traceback (most recent call last)
in <cell line: 2>()
1 # train
----> 2 trainer.train() # there will not be a progress bar since tqdm is disabled

28 frames
/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax)
40 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
41 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 42 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
43 q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
44 )

RuntimeError: FlashAttention only support fp16 and bf16 data type


How to train this model
image.png

thanks. :)

Together org

@NickyNicky Hi, since you are doing QLoRA, you might need to set trust_remote_code=False to use HF's llama implementation, flash attention only works for float16 or bfloat16.

Thanks for such a quick response,
So I couldn't train with Qlora, and flash-attention? or is it not optimized? or What is missing or what am I doing wrong?

how can i implement float16 or bfloat16.

would it be ok like that?

for name, module in model.named_modules():
    if "norm" in name:
        module = module.to(torch.float16)

thanks. :)

Together org

@NickyNicky -- This might take some time to get Qlora + flash-attention to work (mainly engineering and optimizations). But we are working together with our friends in the open source community on it -- stay tuned! (but it might take some time before the release)

Ce

I tried working without Qlora and got

File ~/anaconda3/lib/python3.10/site-packages/flash_attn/layers/rotary.py:62, in ApplyRotaryEmb.forward(ctx, x, cos, sin, interleaved, inplace)
     59 else:
     60     o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
     61               else (out_ro[..., ::2], out_ro[..., 1::2]))
---> 62 rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
     63                         rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
     64 if not inplace and rotary_dim < headdim:
     65     out[..., rotary_dim:].copy_(x[..., rotary_dim:])

RuntimeError: Expected x1.dtype() == cos.dtype() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Got the same error as @chanderbalaji . (following)

@NickyNicky Hey, I saw that you released togethercomputer-LLaMA-2-7B-32K-open-Orca-v1, is your problem solved? As long as you don't use flash-attention, you can use QLora, right?

credits to:

the model togethercomputer-LLaMA-2-7B-32K-open-Orca-v1 and togethercomputer-LLaMA-2-7B-32K-open-Orca-v2 train with QLora, peft and flash-attention for a period of 4 hours V1 and 5 hours v2, 1 GPU A100 (Google colab).
I really wanted to train him longer but it's out of budget.

values to train:
per_device_train_batch_size=14
trust_remote_code=False

After training and joining the weights you can enable flash attention.

Wow, thank you~
And may I ask which orca dataset you used? Is there any re-filtering of token sizes in the dataset?

orca dataset:

  • 1M-GPT4-Augmented.parquet

Just curious...
The config from Bitsandbytes you added, is this valid? as I can't find any references for double quant with 8bit only 4bit?
More just wondering is all.

bnb_config= BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_use_double_quant=True,
bnb_8bit_quant_type="nf4",
bnb_8bit_compute_dtype=torch.bfloat16,
llm_int8_skip_modules= ['decoder', 'lm_head', 'wo'],
)

Thanks!

Just curious...
The config from Bitsandbytes you added, is this valid? as I can't find any references for double quant with 8bit only 4bit?
More just wondering is all.

bnb_config= BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_use_double_quant=True,
bnb_8bit_quant_type="nf4",
bnb_8bit_compute_dtype=torch.bfloat16,
llm_int8_skip_modules= ['decoder', 'lm_head', 'wo'],
)

Thanks!

@Mediocreatmybest I had the same doubt, did you resolve this?

Sign up or log in to comment