|
import copy |
|
import json |
|
import os |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import is_hqq_available, is_quanto_available, logging
|
|
|
|
|
if is_quanto_available(): |
|
from quanto import QBitsTensor, qint2, qint4 |
|
|
|
if is_hqq_available(): |
|
from hqq.core.quantize import Quantizer as HQQQuantizer |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@dataclass |
|
class Cache: |
|
""" |
|
Base, abstract class for all caches. The actual data structure is specific to each subclass. |
|
""" |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of |
|
cache to be created. |
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
|
""" |
|
raise NotImplementedError("Make sure to implement `update` in a subclass.") |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|
|
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states, if there is any.""" |
|
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") |
|
|
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
|
"""Given the sequence length of the new inputs, returns the usable length of the cache.""" |
|
|
|
|
|
|
|
max_length = self.get_max_length() |
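        # Cache without size limit -> the whole cache is usable
        # Cache with size limit -> if the cached length plus the new input length exceeds the maximum cache length,
        #   part of the cache will be evicted, so not all of it is usable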
|
previous_seq_length = self.get_seq_length(layer_idx) |
|
if max_length is not None and previous_seq_length + new_seq_length > max_length: |
|
return max_length - new_seq_length |
|
return previous_seq_length |
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor): |
|
"""Reorders the cache for beam search, given the selected beam indices.""" |
|
for layer_idx in range(len(self.key_cache)): |
|
device = self.key_cache[layer_idx].device |
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
|
device = self.value_cache[layer_idx].device |
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
|
|
|
@property |
|
def seen_tokens(self): |
|
logger.warning_once( |
|
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " |
|
"model input instead." |
|
) |
|
if hasattr(self, "_seen_tokens"): |
|
return self._seen_tokens |
|
else: |
|
return None |
|
|
|
|
|
@dataclass |
|
class CacheConfig: |
|
""" |
|
Base class for cache configs |
|
""" |
|
|
|
cache_implementation: None |
|
|
|
@classmethod |
|
def from_dict(cls, config_dict, **kwargs): |
|
""" |
|
Constructs a CacheConfig instance from a dictionary of parameters. |
|
Args: |
|
config_dict (Dict[str, Any]): Dictionary containing configuration parameters. |
|
**kwargs: Additional keyword arguments to override dictionary values. |
|
Returns: |
|
CacheConfig: Instance of CacheConfig constructed from the dictionary. |
|
""" |
|
config = cls(**config_dict) |
|
to_remove = [] |
|
for key, value in kwargs.items(): |
|
if hasattr(config, key): |
|
setattr(config, key, value) |
|
to_remove.append(key) |
|
for key in to_remove: |
|
kwargs.pop(key, None) |
|
return config |
|
|
|
|
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]): |
|
""" |
|
Save this instance to a JSON file. |
|
|
|
Args: |
|
json_file_path (`str` or `os.PathLike`): |
|
Path to the JSON file in which this configuration instance's parameters will be saved. |
|
|
""" |
|
with open(json_file_path, "w", encoding="utf-8") as writer: |
|
config_dict = self.to_dict() |
|
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" |
|
|
|
writer.write(json_string) |
|
|
|
|
|
def to_dict(self) -> Dict[str, Any]: |
|
""" |
|
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
""" |
|
return copy.deepcopy(self.__dict__) |
|
|
|
|
|
    def __iter__(self):
        """Allows `dict(obj)` for situations where `obj` may be a dict or a `CacheConfig` instance."""
|
for attr, value in copy.deepcopy(self.__dict__).items(): |
|
yield attr, value |
|
|
|
|
|
def __repr__(self): |
|
return f"{self.__class__.__name__} {self.to_json_string()}" |
|
|
|
def to_json_string(self): |
|
""" |
|
Serializes this instance to a JSON formatted string. |
|
Returns: |
|
str: JSON formatted string representing the configuration instance. |
|
""" |
|
return json.dumps(self.__dict__, indent=2) + "\n" |
|
|
|
|
|
def update(self, **kwargs): |
|
""" |
|
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
returning all the unused kwargs. |
|
|
|
Args: |
|
kwargs (`Dict[str, Any]`): |
|
Dictionary of attributes to tentatively update this class. |
|
|
|
Returns: |
|
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. |
|
""" |
|
to_remove = [] |
|
for key, value in kwargs.items(): |
|
if hasattr(self, key): |
|
setattr(self, key, value) |
|
to_remove.append(key) |
|
|
|
|
|
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} |
|
return unused_kwargs |
|
|
|
|
|
@dataclass |
|
class QuantizedCacheConfig(CacheConfig): |
|
""" |
|
Configuration class for quantized cache settings. |
|
|
|
Attributes: |
|
backend (`str`, *optional*, defaults to `"quanto"`): |
|
            Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
|
nbits (`Optional[int]`, *optional*, defaults to 4): |
|
            Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
|
axis_key (`int`, *optional*, defaults to 0): |
|
Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. |
|
axis_value (`int`, *optional*, defaults to 0): |
|
Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. |
|
q_group_size (`Optional[int]`, *optional*, defaults to 64): |
|
Size of the quantization group, should be a divisor of the model's hidden dimension. |
|
Defaults to 64. |
|
residual_length (`Optional[int]`, *optional*, defaults to 128): |
|
            Length of the residual cache which will always be stored in original precision.
|
Defaults to 128. |
|
compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): |
|
            The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
|
device (`str`, *optional*, defaults to `"cpu"`): |
|
            Device on which to perform computations, should be the same as the model's device.
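    Example (a minimal sketch; the chosen values are illustrative defaults for the `quanto` backend):

        ```python
        >>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
        >>> cache_config.validate()  # raises a ValueError if any argument is out of range
        ```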
|
""" |
|
|
|
def __init__( |
|
self, |
|
backend: str = "quanto", |
|
nbits: Optional[int] = 4, |
|
axis_key: Optional[int] = 0, |
|
axis_value: Optional[int] = 0, |
|
q_group_size: Optional[int] = 64, |
|
residual_length: Optional[int] = 128, |
|
compute_dtype: Optional[torch.dtype] = torch.float16, |
|
device: Optional[str] = "cpu", |
|
): |
|
self.backend = backend |
|
self.nbits = nbits |
|
self.axis_key = axis_key |
|
self.axis_value = axis_value |
|
self.q_group_size = q_group_size |
|
self.residual_length = residual_length |
|
self.compute_dtype = compute_dtype |
|
self.device = device |
|
|
|
def validate(self): |
|
"""Validates if the arguments passed are correct""" |
|
|
|
        incorrect_arg_msg = (
            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value} "
|
"but found {found_value}" |
|
) |
|
|
|
|
|
if self.nbits not in [1, 2, 3, 4, 8]: |
|
raise ValueError( |
|
incorrect_arg_msg.format( |
|
key="nbits", |
|
                    correct_value="one of [1, 2, 3, 4, 8]",
|
found_value=self.nbits, |
|
), |
|
) |
|
if self.q_group_size <= 0: |
|
raise ValueError( |
|
incorrect_arg_msg.format( |
|
key="q_group_size", |
|
correct_value="a positive integer", |
|
found_value=self.q_group_size, |
|
), |
|
) |
|
if self.residual_length < 0: |
|
raise ValueError( |
|
incorrect_arg_msg.format( |
|
key="residual_length", |
|
                    correct_value="a non-negative integer",
|
found_value=self.residual_length, |
|
), |
|
) |
|
|
|
if self.axis_key not in [0, 1, -1]: |
|
raise ValueError( |
|
incorrect_arg_msg.format( |
|
key="axis_key", |
|
                    correct_value="`0`, `1` or `-1`",
|
found_value=self.axis_key, |
|
), |
|
) |
|
|
|
if self.axis_value not in [0, 1, -1]: |
|
raise ValueError( |
|
incorrect_arg_msg.format( |
|
key="axis_value", |
|
                    correct_value="`0`, `1` or `-1`",
|
found_value=self.axis_value, |
|
), |
|
) |
|
|
|
|
|
class DynamicCache(Cache): |
|
""" |
|
A cache that grows dynamically as more tokens are generated. This is the default for generative models. |
|
|
|
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is |
|
`[batch_size, num_heads, seq_len, head_dim]`. |
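    Example (a minimal sketch using dummy key/value tensors; the shapes below are illustrative):

        ```python
        >>> import torch

        >>> cache = DynamicCache()
        >>> key_states = torch.randn(1, 2, 5, 8)  # [batch_size, num_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 2, 5, 8)
        >>> keys, values = cache.update(key_states, value_states, layer_idx=0)
        >>> cache.get_seq_length()
        5
        ```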
|
""" |
|
|
|
def __init__(self) -> None: |
|
self.key_cache: List[torch.Tensor] = [] |
|
self.value_cache: List[torch.Tensor] = [] |
|
self._seen_tokens = 0 |
|
|
|
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: |
|
""" |
|
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the |
|
sequence length. |
|
""" |
|
if layer_idx < len(self): |
|
return (self.key_cache[layer_idx], self.value_cache[layer_idx]) |
|
else: |
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") |
|
|
|
def __iter__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over |
|
keys and values |
|
""" |
|
for layer_idx in range(len(self)): |
|
yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) |
|
|
|
def __len__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds |
|
to the number of layers in the model. |
|
""" |
|
return len(self.key_cache) |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. |
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
|
""" |
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
|
|
if len(self.key_cache) <= layer_idx: |
|
self.key_cache.append(key_states) |
|
self.value_cache.append(value_states) |
|
else: |
|
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
|
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
|
|
|
return self.key_cache[layer_idx], self.value_cache[layer_idx] |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|
|
if len(self.key_cache) <= layer_idx: |
|
return 0 |
|
return self.key_cache[layer_idx].shape[-2] |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" |
|
return None |
|
|
|
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
|
legacy_cache = () |
|
for layer_idx in range(len(self)): |
|
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) |
|
return legacy_cache |
|
|
|
@classmethod |
|
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": |
|
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" |
|
cache = cls() |
|
if past_key_values is not None: |
|
for layer_idx in range(len(past_key_values)): |
|
key_states, value_states = past_key_values[layer_idx] |
|
cache.update(key_states, value_states, layer_idx) |
|
return cache |
|
|
|
|
|
class QuantizedCache(DynamicCache): |
|
""" |
|
    A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
    It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache by applying quantization.
|
|
|
    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity for the
    original precision cache. When the length goes beyond this capacity, the original precision cache is quantized, moved into the quantized cache, and then emptied. The
    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
|
|
|
    It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
    Value states in original precision as a list of tensors, one for each layer. The size of each of these tensors
    is at most `[batch_size, num_heads, residual_length, head_dim]`.
|
""" |
|
|
|
def __init__(self, cache_config: QuantizedCacheConfig) -> None: |
|
self._quantized_key_cache: List[torch.Tensor] = [] |
|
self._quantized_value_cache: List[torch.Tensor] = [] |
|
|
|
self.nbits = cache_config.nbits |
|
self.residual_length = cache_config.residual_length |
|
self.q_group_size = cache_config.q_group_size |
|
self.axis_key = cache_config.axis_key |
|
self.axis_value = cache_config.axis_value |
|
self.compute_dtype = cache_config.compute_dtype |
|
self.device = cache_config.device |
|
|
|
super().__init__() |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
if len(self.key_cache) <= layer_idx: |
|
self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) |
|
self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) |
|
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) |
|
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) |
|
keys_to_return, values_to_return = key_states, value_states |
|
else: |
|
dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) |
|
dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) |
|
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] |
|
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] |
|
|
|
keys_to_return = torch.cat(keys_to_return, dim=-2) |
|
values_to_return = torch.cat(values_to_return, dim=-2) |
|
if ( |
|
self.key_cache[layer_idx].dim() == 4 |
|
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length |
|
): |
|
self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key) |
|
self._quantized_value_cache[layer_idx] = self._quantize( |
|
values_to_return.contiguous(), axis=self.axis_value |
|
) |
|
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) |
|
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device) |
|
else: |
|
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
|
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
|
|
|
return keys_to_return, values_to_return |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
if len(self.key_cache) <= layer_idx: |
|
return 0 |
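        # The per-layer length cannot be read directly because part of the cache is quantized, so we rely on
        # `_seen_tokens`, which is only incremented when `layer_idx == 0` in `update`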
|
|
|
|
|
|
|
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 |
|
|
|
def _quantize(self, tensor, axis): |
|
"""Quantizes a key/value using a defined quantization method.""" |
|
raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") |
|
|
|
def _dequantize(self, q_tensor): |
|
"""Dequantizes back the tensor that was quantized by `self._quantize()`""" |
|
raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") |
|
|
|
|
|
class QuantoQuantizedCache(QuantizedCache): |
|
""" |
|
Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. |
|
|
|
Parameters: |
|
        cache_config (`QuantizedCacheConfig`):
|
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. |
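    Example (a minimal sketch with dummy tensors; requires the optional `quanto` package, and the shapes below are illustrative):

        ```python
        >>> import torch

        >>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)
        >>> cache = QuantoQuantizedCache(cache_config)
        >>> key_states = torch.randn(1, 2, 5, 64)  # [batch_size, num_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 2, 5, 64)
        >>> keys, values = cache.update(key_states, value_states, layer_idx=0)
        ```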
|
""" |
|
|
|
def __init__(self, cache_config: CacheConfig) -> None: |
|
super().__init__(cache_config) |
|
if self.nbits not in [2, 4]: |
|
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") |
|
|
|
if self.axis_key not in [0, -1]: |
|
raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}") |
|
|
|
if self.axis_value not in [0, -1]: |
|
raise ValueError( |
|
f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" |
|
) |
|
|
|
self.qtype = qint4 if self.nbits == 4 else qint2 |
|
|
|
def _quantize(self, tensor, axis): |
|
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size) |
|
return qtensor |
|
|
|
def _dequantize(self, qtensor): |
|
return qtensor.dequantize() |
|
|
|
|
|
class HQQQuantizedCache(QuantizedCache): |
|
""" |
|
Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. |
|
|
|
Parameters: |
|
        cache_config (`QuantizedCacheConfig`):
|
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. |
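    Example (a minimal sketch with dummy tensors; requires the optional `hqq` package and, for the settings below, a CUDA device):

        ```python
        >>> import torch

        >>> cache_config = QuantizedCacheConfig(backend="HQQ", nbits=4, axis_key=1, axis_value=1, device="cuda")
        >>> cache = HQQQuantizedCache(cache_config)
        >>> key_states = torch.randn(1, 2, 5, 64, device="cuda")  # [batch_size, num_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 2, 5, 64, device="cuda")
        >>> keys, values = cache.update(key_states, value_states, layer_idx=0)
        ```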
|
""" |
|
|
|
def __init__(self, cache_config: CacheConfig) -> None: |
|
super().__init__(cache_config) |
|
if self.nbits not in [1, 2, 3, 4, 8]: |
|
raise ValueError( |
|
f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" |
|
) |
|
|
|
if self.axis_key not in [0, 1]: |
|
raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}") |
|
|
|
if self.axis_value not in [0, 1]: |
|
raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}") |
|
|
|
self.quantizer = HQQQuantizer |
|
|
|
def _quantize(self, tensor, axis): |
|
qtensor, meta = self.quantizer.quantize( |
|
tensor, |
|
axis=axis, |
|
device=self.device, |
|
compute_dtype=self.compute_dtype, |
|
nbits=self.nbits, |
|
group_size=self.q_group_size, |
|
) |
|
meta["compute_dtype"] = self.compute_dtype |
|
self.quantizer.cuda(qtensor, meta=meta, device=self.device) |
|
return qtensor, meta |
|
|
|
def _dequantize(self, qtensor): |
|
quant_tensor, meta = qtensor |
|
tensor = self.quantizer.dequantize(quant_tensor, meta) |
|
return tensor |
|
|
|
|
|
class SinkCache(Cache): |
|
""" |
|
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past |
|
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. |
|
|
|
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is |
|
`[batch_size, num_heads, seq_len, head_dim]`. |
|
|
|
Parameters: |
|
window_length (`int`): |
|
The length of the context window. |
|
num_sink_tokens (`int`): |
|
The number of sink tokens. See the original paper for more information. |
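    Example (a minimal sketch with dummy tensors and no RoPE re-rotation; inside a model, `sin`/`cos` are passed through `cache_kwargs`):

        ```python
        >>> import torch

        >>> cache = SinkCache(window_length=8, num_sink_tokens=2)
        >>> key_states = torch.randn(1, 2, 5, 8)  # [batch_size, num_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 2, 5, 8)
        >>> keys, values = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={})
        >>> keys, values = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={})
        >>> cache.get_seq_length()  # capped by the window: 2 sink tokens + 1 kept token + 5 new tokens
        8
        ```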
|
""" |
|
|
|
def __init__(self, window_length: int, num_sink_tokens: int) -> None: |
|
self.key_cache: List[torch.Tensor] = [] |
|
self.value_cache: List[torch.Tensor] = [] |
|
self.window_length = window_length |
|
self.num_sink_tokens = num_sink_tokens |
|
self.cos_sin_rerotation_cache = {} |
|
self._cos_cache = None |
|
self._sin_cache = None |
|
self._seen_tokens = 0 |
|
|
|
@staticmethod |
|
def _rotate_half(x): |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
def _apply_key_rotary_pos_emb( |
|
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
|
) -> torch.Tensor: |
|
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) |
|
return rotated_key_states |
|
|
|
def _get_rerotation_cos_sin( |
|
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if key_states.shape[-2] not in self.cos_sin_rerotation_cache: |
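            # Upcast to float32 temporarily for better accuracy when combining the cos/sin values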
|
|
|
cos = cos.to(torch.float32) |
|
sin = sin.to(torch.float32) |
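            # Compute the cos and sin needed to re-rotate the cached keys to their new (shifted) positions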
|
|
|
|
|
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] |
|
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] |
|
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] |
|
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] |
|
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin |
|
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin |
|
|
|
self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( |
|
rerotation_cos.to(key_states.dtype).unsqueeze(0), |
|
rerotation_sin.to(key_states.dtype).unsqueeze(0), |
|
) |
|
return self.cos_sin_rerotation_cache[key_states.shape[-2]] |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|
|
|
|
if len(self.key_cache) <= layer_idx: |
|
return 0 |
|
return self.key_cache[layer_idx].shape[-2] |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states.""" |
|
return self.window_length |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, |
|
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the |
|
rotation as the tokens are shifted. |
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
|
""" |
|
|
|
|
|
sin = cache_kwargs.get("sin") |
|
cos = cache_kwargs.get("cos") |
|
partial_rotation_size = cache_kwargs.get("partial_rotation_size") |
|
using_rope = cos is not None and sin is not None |
|
|
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
|
|
if using_rope and layer_idx == 0: |
|
|
|
|
|
if cos.dim() == 2: |
|
self._cos_cache = cos |
|
self._sin_cache = sin |
|
else: |
|
if self._cos_cache is None: |
|
self._cos_cache = cos[0, ...] |
|
self._sin_cache = sin[0, ...] |
|
elif self._cos_cache.shape[0] < self.window_length: |
|
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) |
|
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) |
|
|
|
|
|
if len(self.key_cache) <= layer_idx: |
|
|
|
self.key_cache.append(key_states) |
|
self.value_cache.append(value_states) |
|
|
|
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: |
|
|
|
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
|
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
|
|
|
else: |
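            # Shifting cache: keep the sink tokens plus the most recent tokens that still fit in the window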
|
|
|
keys_to_keep = self.key_cache[layer_idx][ |
|
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : |
|
] |
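            # On RoPE models, the kept keys were rotated for their old positions and must be re-rotated for the shifted ones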
|
|
|
|
|
if using_rope: |
|
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( |
|
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] |
|
) |
|
if partial_rotation_size is not None: |
|
keys_to_keep, keys_pass = ( |
|
keys_to_keep[..., :partial_rotation_size], |
|
keys_to_keep[..., partial_rotation_size:], |
|
) |
|
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) |
|
if partial_rotation_size is not None: |
|
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) |
|
|
|
|
|
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] |
|
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) |
|
|
|
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] |
|
values_to_keep = self.value_cache[layer_idx][ |
|
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : |
|
] |
|
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) |
|
|
|
return self.key_cache[layer_idx], self.value_cache[layer_idx] |
|
|
|
|
|
class StaticCache(Cache): |
|
""" |
|
Static Cache class to be used with `torch.compile(model)`. |
|
|
|
Parameters: |
|
        config (`PretrainedConfig`):
|
The configuration file defining the shape-related attributes required to initialize the static cache. |
|
max_batch_size (`int`): |
|
The maximum batch size with which the model will be used. |
|
max_cache_len (`int`): |
|
The maximum sequence length with which the model will be used. |
|
device (`torch.device`): |
|
The device on which the cache should be initialized. Should be the same as the layer. |
|
dtype (*optional*, defaults to `torch.float32`): |
|
The default `dtype` to use when initializing the layer. |
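    Example (a minimal sketch; `LlamaConfig` is used here only to build a small config carrying the required attributes):

        ```python
        >>> import torch
        >>> from transformers import LlamaConfig

        >>> config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
        >>> cache = StaticCache(config=config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)
        >>> key_states = torch.randn(1, 4, 5, 16)  # [batch_size, num_key_value_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 4, 5, 16)
        >>> cache_position = torch.arange(5)
        >>> k_out, v_out = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": cache_position})
        >>> k_out.shape
        torch.Size([1, 4, 32, 16])
        ```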
|
""" |
|
|
|
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: |
|
super().__init__() |
|
self.max_batch_size = max_batch_size |
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len |
|
|
|
self.head_dim = ( |
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads |
|
) |
|
|
|
self.dtype = dtype if dtype is not None else torch.float32 |
|
self.num_key_value_heads = ( |
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads |
|
) |
|
|
|
self.key_cache: List[torch.Tensor] = [] |
|
self.value_cache: List[torch.Tensor] = [] |
|
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) |
|
for _ in range(config.num_hidden_layers): |
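            # Note: `torch._dynamo.mark_static_address` tags the cache tensors as fixed data pointers, which avoids
            # copies and cuda graph breaks when the cache is updated under `torch.compile`.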
|
|
|
|
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
|
torch._dynamo.mark_static_address(new_layer_key_cache) |
|
torch._dynamo.mark_static_address(new_layer_value_cache) |
|
self.key_cache.append(new_layer_key_cache) |
|
self.value_cache.append(new_layer_value_cache) |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
It is VERY important to index using a tensor, otherwise you introduce a copy to the device. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input |
|
                to know where to write in the cache.
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
|
""" |
|
cache_position = cache_kwargs.get("cache_position") |
|
k_out = self.key_cache[layer_idx] |
|
v_out = self.value_cache[layer_idx] |
|
|
|
k_out[:, :, cache_position] = key_states |
|
v_out[:, :, cache_position] = value_states |
|
|
|
return k_out, v_out |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states that were seen by the model.""" |
|
|
|
|
|
|
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states.""" |
|
return self.max_cache_len |
|
|
|
def reset(self): |
|
"""Resets the cache values while preserving the objects""" |
|
for layer_idx in range(len(self.key_cache)): |
|
|
|
self.key_cache[layer_idx].zero_() |
|
self.value_cache[layer_idx].zero_() |
|
|
|
|
|
class SlidingWindowCache(Cache): |
|
""" |
|
Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. |
|
    Every time we try to update the cache, we compute `indices` based on `cache_position >= self.sliding_window_size - 1`;
    if true (which means the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
    we do a cyclic shift based on `indices` to replace the oldest states with the new key/value states passed in.
|
|
|
    `to_shift` is only true once we are above `sliding_window_size`. Thus with `sliding_window_size == 64`:
|
|
|
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size |
|
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, |
|
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, |
|
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, |
|
55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) |
|
|
|
    We overwrite the cache using these `indices`, then we always write at `cache_position` (clamped to `sliding_window_size - 1`).
|
|
|
Parameters: |
|
        config (`PretrainedConfig`):
|
The configuration file defining the shape-related attributes required to initialize the static cache. |
|
max_batch_size (`int`): |
|
The maximum batch size with which the model will be used. |
|
max_cache_len (`int`): |
|
The maximum sequence length with which the model will be used. |
|
device (`torch.device`): |
|
The device on which the cache should be initialized. Should be the same as the layer. |
|
dtype (*optional*, defaults to `torch.float32`): |
|
The default `dtype` to use when initializing the layer. |
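    Example (a minimal sketch; `MistralConfig` is used here only to build a small config carrying a `sliding_window` attribute):

        ```python
        >>> import torch
        >>> from transformers import MistralConfig

        >>> config = MistralConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=2, sliding_window=16)
        >>> cache = SlidingWindowCache(config=config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)
        >>> key_states = torch.randn(1, 2, 5, 16)  # [batch_size, num_key_value_heads, seq_len, head_dim]
        >>> value_states = torch.randn(1, 2, 5, 16)
        >>> cache_position = torch.arange(5)
        >>> k_out, v_out = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": cache_position})
        >>> k_out.shape  # bounded by the sliding window size
        torch.Size([1, 2, 16, 16])
        ```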
|
""" |
|
|
|
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: |
|
if not hasattr(config, "sliding_window") or config.sliding_window is None: |
|
raise ValueError( |
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting " |
|
"sliding window attention, please check if there is a `sliding_window` field in the model " |
|
"config and it's not set to None." |
|
) |
|
|
|
super().__init__() |
|
self.max_batch_size = max_batch_size |
|
|
|
|
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len |
|
self.model_sliding_window_size = config.sliding_window |
|
self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size) |
|
|
|
self.head_dim = ( |
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads |
|
) |
|
|
|
self.dtype = dtype if dtype is not None else torch.float32 |
|
self.num_key_value_heads = ( |
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads |
|
) |
|
|
|
cache_shape = ( |
|
config.num_hidden_layers, |
|
max_batch_size, |
|
self.num_key_value_heads, |
|
self.sliding_window_size, |
|
self.head_dim, |
|
) |
|
|
|
self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
|
self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
|
|
|
torch._dynamo.mark_static_address(self.key_cache) |
|
torch._dynamo.mark_static_address(self.value_cache) |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor]: |
|
cache_position = cache_kwargs.get("cache_position") |
|
k_out = self.key_cache[layer_idx] |
|
v_out = self.value_cache[layer_idx] |
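        # Prefill longer than the window: keep only the last `sliding_window_size` states in the cache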
|
|
|
|
|
if cache_position.shape[0] > self.sliding_window_size: |
|
k_out = key_states[:, :, -self.sliding_window_size :, :] |
|
v_out = value_states[:, :, -self.sliding_window_size :, :] |
|
self.key_cache[layer_idx] = k_out |
|
self.value_cache[layer_idx] = v_out |
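            # Return the full states (not the truncated cache) so the whole prompt is attended to in this forward pass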
|
|
|
|
|
return key_states, value_states |
|
|
|
slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0) |
|
cache_position = cache_position.clamp(0, self.sliding_window_size - 1) |
|
to_shift = cache_position >= self.sliding_window_size - 1 |
|
indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size |
|
|
|
k_out = k_out[:, :, indices] |
|
v_out = v_out[:, :, indices] |
|
|
|
k_out[:, :, cache_position] = key_states |
|
v_out[:, :, cache_position] = value_states |
|
|
|
self.key_cache[layer_idx] = k_out |
|
self.value_cache[layer_idx] = v_out |
|
|
|
return k_out, v_out |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
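        # Assume this is only called in the first generation step; subsequent steps rely on `cache_position`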
|
|
|
|
|
return 0 |
|
|
|
def get_max_length(self) -> Optional[int]: |
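        # In principle there is no maximum length: the sliding window keeps the cache size fixed regardless of sequence length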
|
|
|
|
|
return None |
|
|
|
def reset(self): |
|
self.key_cache.zero_() |
|
self.value_cache.zero_() |
|
|