RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

#24
by saireddy - opened

Use case: I am trying to fine-tune gemma2 using SFTTrainer. Here is how I am loading the model, along with my bnb configs:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_params = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "use_cache": True,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_params)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

Training arguments:

TRAINING_ARGS = {
    "num_train_epochs": 1,
    "optim": "adamw_torch_fused",
    "logging_steps": 20,
    "save_strategy": "epoch",
    "bf16": True,
    "tf32": True,
}

When I try to use the fine-tuned model to generate predictions with this call:
outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
                         temperature=temperature, pad_token_id=tokenizer.eos_token_id)

I hit the error below. The same script works fine with llama3, mistral, qwen ...
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

Stack trace:

outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1491, in generate
outputs = self.base_model.generate(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1914, in generate
result = self._sample(
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2651, in _sample
outputs = self(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/gemma2/modeling_gemma2.py", line 1068, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/gemma2/modeling_gemma2.py", line 908, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/gemma2/modeling_gemma2.py", line 650, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/gemma2/modeling_gemma2.py", line 252, in forward
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py", line 1071, in update
return update_fn(
File "/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py", line 1046, in _static_update
k_out[:, :, cache_position] = key_states
RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
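
For readers skimming the trace above, a small helper like the following can pull out the innermost frame that belongs to a given package. This is only a sketch: the function names `parse_frames` and `innermost_frame` are made up for illustration, and the sample text is copied from the last frames of the trace in this post.

```python
import re

# Matches lines such as: File "/path/to/file.py", line 123, in func_name
FRAME_RE = re.compile(r'File "(?P<path>[^"]+)", line (?P<line>\d+), in (?P<func>\S+)')


def parse_frames(trace_text):
    """Return (path, line, func) tuples in the order they appear in the trace."""
    return [(m["path"], int(m["line"]), m["func"]) for m in FRAME_RE.finditer(trace_text)]


def innermost_frame(trace_text, package):
    """Return the last frame whose path contains /<package>/, or None."""
    frames = [f for f in parse_frames(trace_text) if f"/{package}/" in f[0]]
    return frames[-1] if frames else None


# The last three frames of the trace posted above.
sample = '''File "/usr/local/lib/python3.10/dist-packages/transformers/models/gemma2/modeling_gemma2.py", line 252, in forward
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py", line 1071, in update
return update_fn(
File "/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py", line 1046, in _static_update
k_out[:, :, cache_position] = key_states
'''

print(innermost_frame(sample, "transformers"))
```

On this trace the helper points at `_static_update` in `cache_utils.py`, i.e. the write of `key_states` into the preallocated cache tensor is where the two dtypes meet.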

Hardware : NVIDIA H100 80GB
accelerate==0.31.0
bitsandbytes==0.43.1
datasets==2.18.0
deepspeed==0.14.4
evaluate==0.4.1
peft==0.11.1
transformers==4.42.3
trl==0.9.4
pytorch image : nvcr.io/nvidia/pytorch:24.05-py3 -- cuda 12.4.1 and torch 2.4
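
For anyone who wants to compare their environment against the versions listed above, here is a minimal sketch that reads installed versions with the standard library. The package list and the helper name `installed_versions` are chosen for illustration only.

```python
from importlib import metadata

# Packages mentioned in this post; adjust as needed.
PACKAGES = ["accelerate", "bitsandbytes", "peft", "transformers", "trl", "torch"]


def installed_versions(names):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```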

@Renu11 any advice on this issue?

Do you know how to fix this bug?

@DeHors i was able to fix this issue using

model.to(torch.bfloat16)

before generating predictions

But when I call model.to(torch.bfloat16) before generating predictions, I get this error:
ValueError: .to is not supported for 4-bit or 8-bit bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct dtype.

@DeHors it worked for me because I was doing full fine-tuning; I assume you are using LoRA or QLoRA via PEFT. I am not sure how to fix it in that case, sorry.
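
The two cases described above (a full-precision model where the cast works, and a bitsandbytes model where `.to` raises) could be handled with a guard along these lines. This is a hedged sketch, not a fix for the quantized case: `cast_if_not_quantized` is a hypothetical helper, and it assumes the model exposes the `is_loaded_in_4bit` / `is_loaded_in_8bit` flags that transformers sets on bitsandbytes-quantized models.

```python
def cast_if_not_quantized(model, dtype):
    """Cast the model to `dtype` unless it looks bitsandbytes-quantized.

    Returns True if the cast was applied, False if it was skipped.
    Assumption: quantized models carry `is_loaded_in_4bit` or
    `is_loaded_in_8bit` set to True; anything else is treated as castable.
    """
    quantized = getattr(model, "is_loaded_in_4bit", False) or getattr(
        model, "is_loaded_in_8bit", False
    )
    if quantized:
        # Calling .to(dtype) here would raise the ValueError quoted above.
        return False
    model.to(dtype)
    return True


# Usage, assuming `model` and `torch` from the original post:
# cast_if_not_quantized(model, torch.bfloat16)
```

When the helper returns False the dtype mismatch is still unresolved; it only avoids the second error.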

+1
getting similar Error:
RuntimeError: Index put requires the source and destination dtypes match, got Half for the destination and Float for the source.

This comment has been hidden

Sorry, I had typed the same answer as above, so I hid my previous comment.

Try to set use_cache to False; it has helped me.
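
A minimal sketch of that suggestion, applied to the `generate` call from the original post. The helper name `build_generate_kwargs` is made up for illustration; `use_cache` is passed per call here, which is one way to disable the KV cache at generation time. Expect slower generation without the cache, and note that others in this thread report it did not help them.

```python
def build_generate_kwargs(input_ids, max_new_tokens, temperature, pad_token_id, use_cache=False):
    """Collect the generate() arguments from the original post, with the cache disabled by default."""
    return {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "pad_token_id": pad_token_id,
        "use_cache": use_cache,  # False skips the cache update that fails in the trace
    }


# Usage, assuming `model`, `tokenizer`, and `input_ids` from the original post:
# outputs = model.generate(**build_generate_kwargs(
#     input_ids, max_new_tokens, temperature, tokenizer.eos_token_id))
```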

Same issue for me here, using gemma2-2b with 4-bit quantization.
Setting use_cache to False didn't solve the error, but calling model.to(torch.half) seems to correct it.
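
Since the reports in this thread differ on which workaround helps, it may be worth checking which dtypes the loaded model actually contains before casting anything. The sketch below is an assumption-laden diagnostic, not a fix: the helper names are hypothetical, and it relies only on the model exposing `named_parameters()` with a `.dtype` on each parameter, as PyTorch modules do. With 4-bit bitsandbytes weights, the quantized parameters typically show up as an integer dtype rather than bfloat16.

```python
from collections import Counter


def dtype_summary(model):
    """Count parameters per dtype, e.g. {'torch.bfloat16': 200, 'torch.float32': 8}."""
    return Counter(str(p.dtype) for _, p in model.named_parameters())


def params_not_in(model, expected_dtype):
    """List the names of parameters whose dtype differs from `expected_dtype`."""
    expected = str(expected_dtype)
    return [name for name, p in model.named_parameters() if str(p.dtype) != expected]


# Usage, assuming `model` and `torch` from the original post:
# print(dtype_summary(model))
# print(params_not_in(model, torch.bfloat16)[:10])
```

If float32 parameters show up in a model that was loaded as bfloat16, that is one plausible source of the Float tensor named in the error message, though this thread does not confirm the root cause.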
