Unable to quantize layers one at a time
#8
by abhinavkulkarni
Hi,
I am trying to apply AWQ quantization to this new architecture one layer at a time, and I am running into a problem.
The way it works is as follows:
- Pass sample input through the model and catch the input to the first layer
- Pass the input through each layer successively while determining optimal quantization parameters
- The output of one layer becomes the input to the next one
I have omitted the quantization logic, but the main scaffold is as follows.
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextStreamer
from accelerate import init_empty_weights, infer_auto_device_map
from datasets import load_dataset
import torch.nn as nn
import gc
model_id = "togethercomputer/StripedHyena-Nous-7B"
# model_id = "meta-llama/Llama-2-7b-hf"
# Config
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
# Load model on CPU
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
    model_id, config=config, trust_remote_code=True, **kwargs
)
model.eval()
# Tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Load sample dataset
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
texts = [dataset[i]['text'] for i in range(10)]
samples = [tokenizer.encode(text, max_length=512, truncation=True, padding='max_length') for text in texts]
samples = torch.LongTensor(samples) # Shape = (10, 512)
# Catch the input to the first layer
inps = []
layer_kwargs = {}
class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        inps.append(inp)
        layer_kwargs.update(kwargs)
        raise ValueError  # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers = model.backbone.blocks # For StripedHyena
# layers = model.model.layers # For Llama-2
layers[0] = Catcher(layers[0])
try:
    model(samples.to(next(model.parameters()).device))
except ValueError:  # expected early exit raised by the Catcher
    pass
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
# Now pass the input successively through each layer, collecting the output
# which becomes input for the next layer
for i in range(len(layers)):
    print(i)
    layer = layers[i]
    layer = layer.cuda()
    inps = inps.to(next(layer.parameters()).device)  # in case of multi-GPU
    layer_kwargs = {k: (v.to(inps.device) if isinstance(v, torch.Tensor) else v) for k, v in layer_kwargs.items()}
    # get output as next layer's input
    inps = layer(inps, **layer_kwargs)[0]
    # Clear GPU memory
    torch.cuda.empty_cache()
    layer = layer.cpu()
    gc.collect()
    torch.cuda.empty_cache()
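For reference, inspecting the kwargs captured by the Catcher (a quick check run after the scaffold above) suggests the model itself passes an inference_params object to every block:
# Debug aid, not part of the quantization logic: list what the Catcher captured.
# On StripedHyena this should include an `inference_params` entry holding the
# RecurrentInferenceParams object that shows up in the traceback below.
print({k: type(v).__name__ for k, v in layer_kwargs.items()})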
I get the following error when the input is passed through an AttentionBlock layer:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[1], line 75
73 layer_kwargs = {k:(v.to(inps.device) if isinstance(v, torch.Tensor) else v) for k,v in layer_kwargs.items()}
74 # get output as next layer's input
---> 75 inps = layer(inps, **layer_kwargs)[0]
76 # Clear GPU memory
77 torch.cuda.empty_cache()
File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.cache/huggingface/modules/transformers_modules/togethercomputer/StripedHyena-Nous-7B/42777970d603597dadb768705896533eb9556a07/model.py:71, in AttentionBlock.forward(self, u, inference_params, padding_mask, *args, **kwargs)
64 u = u * padding_mask[..., None]
66 # for attr in ['lengths_per_sample', 'max_seqlen', 'key_value_memory_dict']:
67 # if not hasattr(inference_params, attr):
68 # setattr(inference_params, attr, None)
69 # inference_params.key_value_memory_dict = inference_params.key_value_memory_dict or {}
70 u = (
---> 71 self.inner_mha_cls(
72 self.pre_norm(u),
73 inference_params=inference_params,
74 )
75 + u
76 )
77 if type(padding_mask) == torch.Tensor: # guard against bias
78 u = u * padding_mask[..., None]
File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/flash_attn/modules/mha.py:563, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, **kwargs)
551 assert not self.dwconv
553 kwargs = (
554 {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
555 if self.use_flash_attn
556 else {"key_padding_mask": key_padding_mask, **kwargs}
557 )
558 seqlen_offset = (
559 0
560 if inference_params is None
561 else (
562 inference_params.lengths_per_sample
--> 563 if inference_params.lengths_per_sample is not None
564 else inference_params.seqlen_offset
565 )
566 )
567 rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
568 batch, seqlen = x.shape[:2]
AttributeError: 'RecurrentInferenceParams' object has no attribute 'lengths_per_sample'
Please note the same code works for meta-llama/Llama-2-7b-hf.
All the quantization methods - GPTQ, AWQ, etc. - work layer by layer. Can you please help?
Thanks!
RecurrentInferenceParams handles cache management for the Hyena layers only. Since these layers have a constant cache (no kv-cache), RecurrentInferenceParams does not have a .lengths_per_sample attribute.
Can you try setting use_cache to False before loading the model:
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(
    model_id, config=config, trust_remote_code=True, **kwargs
)
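If reloading the model is not convenient, an alternative sketch (untested, and assuming the cache object is passed to the blocks under the inference_params keyword) is to drop it from the captured kwargs so the attention blocks take their inference_params is None path:
# Untested sketch: remove the RecurrentInferenceParams object from the captured
# kwargs; flash-attn's MHA then falls back to its `inference_params is None` branch.
layer_kwargs.pop("inference_params", None)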