AoiKazama committed on
Commit 16872b9
1 Parent(s): 096b9b8

Update cache_utils.py

Files changed (1)
  1. cache_utils.py +550 -48

cache_utils.py CHANGED
@@ -1,12 +1,21 @@
1
  from dataclasses import dataclass
2
- from typing import Any, Dict, List, Optional, Tuple
3
 
4
  import torch
5
 
6
  from transformers.configuration_utils import PretrainedConfig
7
- from transformers.utils import logging
8
 
9
 
10
  logger = logging.get_logger(__name__)
11
 
12
 
@@ -44,6 +53,7 @@ class Cache:
44
 
45
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
46
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
 
47
  raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
48
 
49
  def get_max_length(self) -> Optional[int]:
@@ -61,6 +71,14 @@ class Cache:
61
  return max_length - new_seq_length
62
  return previous_seq_length
63
 
64
  @property
65
  def seen_tokens(self):
66
  logger.warning_once(
@@ -73,6 +91,201 @@ class Cache:
73
  return None
74
 
75
 
76
  class DynamicCache(Cache):
77
  """
78
  A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -150,6 +363,7 @@ class DynamicCache(Cache):
150
 
151
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
152
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
 
153
  if len(self.key_cache) <= layer_idx:
154
  return 0
155
  return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +372,6 @@ class DynamicCache(Cache):
158
  """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
159
  return None
160
 
161
- def reorder_cache(self, beam_idx: torch.LongTensor):
162
- """Reorders the cache for beam search, given the selected beam indices."""
163
- for layer_idx in range(len(self.key_cache)):
164
- device = self.key_cache[layer_idx].device
165
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
166
- device = self.value_cache[layer_idx].device
167
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
168
-
169
  def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
170
  """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
171
  legacy_cache = ()
@@ -184,6 +390,168 @@ class DynamicCache(Cache):
184
  return cache
185
 
186
 
187
  class SinkCache(Cache):
188
  """
189
  A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -205,7 +573,9 @@ class SinkCache(Cache):
205
  self.value_cache: List[torch.Tensor] = []
206
  self.window_length = window_length
207
  self.num_sink_tokens = num_sink_tokens
208
- self.cos_sin_cache = {}
209
  self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
210
 
211
  @staticmethod
@@ -223,7 +593,7 @@ class SinkCache(Cache):
223
  def _get_rerotation_cos_sin(
224
  self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
225
  ) -> Tuple[torch.Tensor, torch.Tensor]:
226
- if key_states.shape[-2] not in self.cos_sin_cache:
227
  # Upcast to float32 temporarily for better accuracy
228
  cos = cos.to(torch.float32)
229
  sin = sin.to(torch.float32)
@@ -236,14 +606,15 @@ class SinkCache(Cache):
236
  rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
237
  rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
238
 
239
- self.cos_sin_cache[key_states.shape[-2]] = (
240
  rerotation_cos.to(key_states.dtype).unsqueeze(0),
241
  rerotation_sin.to(key_states.dtype).unsqueeze(0),
242
  )
243
- return self.cos_sin_cache[key_states.shape[-2]]
244
 
245
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
246
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
 
247
  # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
248
  if len(self.key_cache) <= layer_idx:
249
  return 0
@@ -289,6 +660,21 @@ class SinkCache(Cache):
289
  if layer_idx == 0:
290
  self._seen_tokens += key_states.shape[-2]
291
 
292
  # [bsz, num_heads, seq_len, head_dim]
293
  if len(self.key_cache) <= layer_idx:
294
  # Empty cache
@@ -309,7 +695,7 @@ class SinkCache(Cache):
309
  # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
310
  if using_rope:
311
  rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
312
- key_states, cos[: self.window_length], sin[: self.window_length]
313
  )
314
  if partial_rotation_size is not None:
315
  keys_to_keep, keys_pass = (
@@ -332,14 +718,6 @@ class SinkCache(Cache):
332
 
333
  return self.key_cache[layer_idx], self.value_cache[layer_idx]
334
 
335
- def reorder_cache(self, beam_idx: torch.LongTensor):
336
- """Reorders the cache for beam search, given the selected beam indices."""
337
- for layer_idx in range(len(self.key_cache)):
338
- device = self.key_cache[layer_idx].device
339
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
340
- device = self.value_cache[layer_idx].device
341
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
342
-
343
 
344
  class StaticCache(Cache):
345
  """
@@ -347,8 +725,7 @@ class StaticCache(Cache):
347
 
348
  Parameters:
349
  config (`PretrainedConfig):
350
- The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
351
- required to initialize the static cache.
352
  max_batch_size (`int`):
353
  The maximum batch size with which the model will be used.
354
  max_cache_len (`int`):
@@ -373,9 +750,18 @@ class StaticCache(Cache):
373
  config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
374
  )
375
 
376
  cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
377
- self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
378
- self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
379
 
380
  def update(
381
  self,
@@ -394,42 +780,158 @@ class StaticCache(Cache):
394
  value_states (`torch.Tensor`):
395
  The new value states to cache.
396
  layer_idx (`int`):
397
- The index of the layer to cache the states for. Kept for backward compatibility
398
  cache_kwargs (`Dict[str, Any]`, `optional`):
399
- Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
400
- to know how much of the cache it should overwrite.
401
 
402
  Return:
403
  A tuple containing the updated key and value states.
404
  """
405
- new_cache_positions = cache_kwargs.get("cache_position")
406
- k_out = self.key_cache
407
- v_out = self.value_cache
408
 
409
- k_out[:, :, new_cache_positions] = key_states
410
- v_out[:, :, new_cache_positions] = value_states
411
 
412
  return k_out, v_out
413
 
414
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
415
- """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
416
  # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
417
  # limit the check to the first batch member and head dimension.
418
- # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
419
- # https://github.com/pytorch/pytorch/issues/120248 is fixed
420
- return (self.key_cache[0, 0].any(dim=-1)).sum()
421
 
422
  def get_max_length(self) -> Optional[int]:
423
- """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
424
  return self.max_cache_len
425
 
426
- def reorder_cache(self, beam_idx: torch.LongTensor):
427
- """Reorders the cache for beam search, given the selected beam indices."""
428
- device = self.key_cache.device
429
- self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
430
- device = self.value_cache.device
431
- self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
432
 
433
- def to_legacy_cache(self):
434
- """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
435
  return None
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
  from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import torch
8
 
9
  from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import is_hqq_available, is_quanto_available, logging
11
 
12
 
13
+ if is_quanto_available():
14
+ from quanto import QBitsTensor, qint2, qint4
15
+
16
+ if is_hqq_available():
17
+ from hqq.core.quantize import Quantizer as HQQQuantizer
18
+
19
  logger = logging.get_logger(__name__)
20
 
21
 
 
53
 
54
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
55
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
56
+ # TODO: deprecate this function in favor of `cache_position`
57
  raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
58
 
59
  def get_max_length(self) -> Optional[int]:
 
71
  return max_length - new_seq_length
72
  return previous_seq_length
73
 
74
+ def reorder_cache(self, beam_idx: torch.LongTensor):
75
+ """Reorders the cache for beam search, given the selected beam indices."""
76
+ for layer_idx in range(len(self.key_cache)):
77
+ device = self.key_cache[layer_idx].device
78
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
79
+ device = self.value_cache[layer_idx].device
80
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
81
+
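The `reorder_cache` helper added to the base `Cache` class above reorders every layer's key/value tensors along the batch dimension with `index_select`. A standalone illustration of that operation (the tensor shape is made up for the example):

    import torch

    # Hypothetical per-layer key tensor: [batch (= num_beams), num_heads, seq_len, head_dim]
    key = torch.arange(4.0).view(4, 1, 1, 1)
    beam_idx = torch.tensor([2, 2, 0, 1])  # beams selected at this decoding step

    # The same operation the cache applies to every layer's key and value tensors
    reordered = key.index_select(0, beam_idx.to(key.device))
    print(reordered.flatten().tolist())  # [2.0, 2.0, 0.0, 1.0]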
82
  @property
83
  def seen_tokens(self):
84
  logger.warning_once(
 
91
  return None
92
 
93
 
94
+ @dataclass
95
+ class CacheConfig:
96
+ """
97
+ Base class for cache configs
98
+ """
99
+
100
+ cache_implementation: None
101
+
102
+ @classmethod
103
+ def from_dict(cls, config_dict, **kwargs):
104
+ """
105
+ Constructs a CacheConfig instance from a dictionary of parameters.
106
+ Args:
107
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
108
+ **kwargs: Additional keyword arguments to override dictionary values.
109
+ Returns:
110
+ CacheConfig: Instance of CacheConfig constructed from the dictionary.
111
+ """
112
+ config = cls(**config_dict)
113
+ to_remove = []
114
+ for key, value in kwargs.items():
115
+ if hasattr(config, key):
116
+ setattr(config, key, value)
117
+ to_remove.append(key)
118
+ for key in to_remove:
119
+ kwargs.pop(key, None)
120
+ return config
121
+
122
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
123
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
124
+ """
125
+ Save this instance to a JSON file.
126
+
127
+ Args:
128
+ json_file_path (`str` or `os.PathLike`):
129
+ Path to the JSON file in which this configuration instance's parameters will be saved.
130
+ use_diff (`bool`, *optional*, defaults to `True`):
131
+ If set to `True`, only the difference between the config instance and the default
132
+ `QuantizationConfig()` is serialized to JSON file.
133
+ """
134
+ with open(json_file_path, "w", encoding="utf-8") as writer:
135
+ config_dict = self.to_dict()
136
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
137
+
138
+ writer.write(json_string)
139
+
140
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
141
+ def to_dict(self) -> Dict[str, Any]:
142
+ """
143
+ Serializes this instance to a Python dictionary. Returns:
144
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
145
+ """
146
+ return copy.deepcopy(self.__dict__)
147
+
148
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
149
+ def __iter__(self):
150
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
151
+ for attr, value in copy.deepcopy(self.__dict__).items():
152
+ yield attr, value
153
+
154
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
155
+ def __repr__(self):
156
+ return f"{self.__class__.__name__} {self.to_json_string()}"
157
+
158
+ def to_json_string(self):
159
+ """
160
+ Serializes this instance to a JSON formatted string.
161
+ Returns:
162
+ str: JSON formatted string representing the configuration instance.
163
+ """
164
+ return json.dumps(self.__dict__, indent=2) + "\n"
165
+
166
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
167
+ def update(self, **kwargs):
168
+ """
169
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
170
+ returning all the unused kwargs.
171
+
172
+ Args:
173
+ kwargs (`Dict[str, Any]`):
174
+ Dictionary of attributes to tentatively update this class.
175
+
176
+ Returns:
177
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
178
+ """
179
+ to_remove = []
180
+ for key, value in kwargs.items():
181
+ if hasattr(self, key):
182
+ setattr(self, key, value)
183
+ to_remove.append(key)
184
+
185
+ # Remove all the attributes that were updated, without modifying the input dict
186
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
187
+ return unused_kwargs
188
+
189
+
190
+ @dataclass
191
+ class QuantizedCacheConfig(CacheConfig):
192
+ """
193
+ Configuration class for quantized cache settings.
194
+
195
+ Attributes:
196
+ backend (`str`, *optional*, defaults to `"quanto"`):
197
+ Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
198
+ nbits (`Optional[int]`, *optional*, defaults to 4):
199
+ Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
200
+ axis_key (`int`, *optional*, defaults to 0):
201
+ Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
202
+ axis_value (`int`, *optional*, defaults to 0):
203
+ Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
204
+ q_group_size (`Optional[int]`, *optional*, defaults to 64):
205
+ Size of the quantization group, should be a divisor of the model's hidden dimension.
206
+ Defaults to 64.
207
+ residual_length (`Optional[int]`, *optional*, defaults to 128):
208
+ Length of the residual cache which will always be stored in original precision.
209
+ Defaults to 128.
210
+ compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
211
+ The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
212
+ device (`str`, *optional*, defaults to `"cpu"`):
213
+ Device on which to perform computations; should be the same as the model's device.
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ backend: str = "quanto",
219
+ nbits: Optional[int] = 4,
220
+ axis_key: Optional[int] = 0,
221
+ axis_value: Optional[int] = 0,
222
+ q_group_size: Optional[int] = 64,
223
+ residual_length: Optional[int] = 128,
224
+ compute_dtype: Optional[torch.dtype] = torch.float16,
225
+ device: Optional[str] = "cpu",
226
+ ):
227
+ self.backend = backend
228
+ self.nbits = nbits
229
+ self.axis_key = axis_key
230
+ self.axis_value = axis_value
231
+ self.q_group_size = q_group_size
232
+ self.residual_length = residual_length
233
+ self.compute_dtype = compute_dtype
234
+ self.device = device
235
+
236
+ def validate(self):
237
+ """Validates if the arguments passed are correct"""
238
+
239
+ incorrect_arg_msg = (
240
+ "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
241
+ "but found {found_value}"
242
+ )
243
+ # Check that the values are reasonable in general (nbits, axis)
244
+ # Later in QuantizedCache init we check if they are supported for that particular backend
245
+ if self.nbits not in [1, 2, 3, 4, 8]:
246
+ raise ValueError(
247
+ incorrect_arg_msg.format(
248
+ key="nbits",
249
+ correct_value="one of [1, 2, 3, 4, 8]",
250
+ found_value=self.nbits,
251
+ ),
252
+ )
253
+ if self.q_group_size <= 0:
254
+ raise ValueError(
255
+ incorrect_arg_msg.format(
256
+ key="q_group_size",
257
+ correct_value="a positive integer",
258
+ found_value=self.q_group_size,
259
+ ),
260
+ )
261
+ if self.residual_length < 0:
262
+ raise ValueError(
263
+ incorrect_arg_msg.format(
264
+ key="residual_length",
265
+ correct_value="a positive integer",
266
+ found_value=self.residual_length,
267
+ ),
268
+ )
269
+
270
+ if self.axis_key not in [0, 1, -1]:
271
+ raise ValueError(
272
+ incorrect_arg_msg.format(
273
+ key="axis_key",
274
+ correct_value="`1` or `0`, `-1`",
275
+ found_value=self.axis_key,
276
+ ),
277
+ )
278
+
279
+ if self.axis_value not in [0, 1, -1]:
280
+ raise ValueError(
281
+ incorrect_arg_msg.format(
282
+ key="axis_value",
283
+ correct_value="`1` or `0` or `-1`",
284
+ found_value=self.axis_value,
285
+ ),
286
+ )
287
+
288
+
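The two config classes added above are plain containers with dict/JSON helpers. A small, illustrative sketch of the round-trip methods (how `generate` consumes such a config is outside this diff):

    import torch

    cfg = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
    cfg.validate()                        # raises ValueError for unsupported values
    unused = cfg.update(nbits=2, foo=1)   # sets nbits, returns the unconsumed kwargs: {"foo": 1}
    cfg2 = QuantizedCacheConfig.from_dict(cfg.to_dict(), compute_dtype=torch.float32)
    print(cfg2.nbits, cfg2.compute_dtype)  # 2 torch.float32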
289
  class DynamicCache(Cache):
290
  """
291
  A cache that grows dynamically as more tokens are generated. This is the default for generative models.
 
363
 
364
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
365
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
366
+ # TODO: deprecate this function in favor of `cache_position`
367
  if len(self.key_cache) <= layer_idx:
368
  return 0
369
  return self.key_cache[layer_idx].shape[-2]
 
372
  """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
373
  return None
374
 
375
  def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
376
  """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
377
  legacy_cache = ()
 
390
  return cache
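For reference, `DynamicCache` converts to and from the legacy tuple-of-tuples `past_key_values` format via the two methods whose context lines appear above; a short shape-illustrative sketch (assuming the `from_legacy_cache` classmethod that this hunk truncates):

    import torch

    # Legacy format: one (key, value) pair per layer, each [batch, num_heads, seq_len, head_dim]
    legacy = tuple((torch.randn(1, 2, 5, 8), torch.randn(1, 2, 5, 8)) for _ in range(3))
    cache = DynamicCache.from_legacy_cache(legacy)
    print(cache.get_seq_length())        # 5
    print(len(cache.to_legacy_cache()))  # 3, one tuple per layer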
391
 
392
 
393
+ class QuantizedCache(DynamicCache):
394
+ """
395
+ A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
396
+ It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache by applying quantization.
397
+
398
+ The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
399
+ original precision cache. When the length goes beyond maximum capacity, the contents of the original precision cache are quantized and moved into the quantized cache. The
400
+ quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
401
+
402
+ It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
403
+ Value in original precision states as a list of tensors, one for each layer. The size of each tensor
404
+ is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
405
+ """
406
+
407
+ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
408
+ self._quantized_key_cache: List[torch.Tensor] = []
409
+ self._quantized_value_cache: List[torch.Tensor] = []
410
+
411
+ self.nbits = cache_config.nbits
412
+ self.residual_length = cache_config.residual_length
413
+ self.q_group_size = cache_config.q_group_size
414
+ self.axis_key = cache_config.axis_key
415
+ self.axis_value = cache_config.axis_value
416
+ self.compute_dtype = cache_config.compute_dtype
417
+ self.device = cache_config.device
418
+
419
+ super().__init__()
420
+
421
+ def update(
422
+ self,
423
+ key_states: torch.Tensor,
424
+ value_states: torch.Tensor,
425
+ layer_idx: int,
426
+ cache_kwargs: Optional[Dict[str, Any]] = None,
427
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
428
+ # Update the number of seen tokens
429
+ if layer_idx == 0:
430
+ self._seen_tokens += key_states.shape[-2]
431
+
432
+ if len(self.key_cache) <= layer_idx:
433
+ self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
434
+ self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
435
+ self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
436
+ self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
437
+ keys_to_return, values_to_return = key_states, value_states
438
+ else:
439
+ dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
440
+ dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
441
+ keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
442
+ values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
443
+
444
+ keys_to_return = torch.cat(keys_to_return, dim=-2)
445
+ values_to_return = torch.cat(values_to_return, dim=-2)
446
+ if (
447
+ self.key_cache[layer_idx].dim() == 4
448
+ and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
449
+ ):
450
+ self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
451
+ self._quantized_value_cache[layer_idx] = self._quantize(
452
+ values_to_return.contiguous(), axis=self.axis_value
453
+ )
454
+ self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
455
+ self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
456
+ else:
457
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
458
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
459
+
460
+ return keys_to_return, values_to_return
461
+
462
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
463
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
464
+ if len(self.key_cache) <= layer_idx:
465
+ return 0
466
+ # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
467
+ # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
468
+ # this part of code otherwise fails when used to verify attn_weight shape in some models
469
+ return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
470
+
471
+ def _quantize(self, tensor, axis):
472
+ """Quantizes a key/value using a defined quantization method."""
473
+ raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
474
+
475
+ def _dequantize(self, q_tensor):
476
+ """Dequantizes back the tensor that was quantized by `self._quantize()`"""
477
+ raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
478
+
479
+
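A subclass only has to supply `_quantize`/`_dequantize`; the residual handling in `update` above does the rest. A toy subclass for illustration, using naive per-tensor 8-bit affine quantization instead of a real backend:

    import torch

    class ToyQuantizedCache(QuantizedCache):
        """Illustrative only: stores uint8 values plus a per-tensor scale and offset."""

        def _quantize(self, tensor, axis):
            # `axis` is ignored in this toy version; real backends group along it
            t_min, t_max = tensor.min(), tensor.max()
            scale = (t_max - t_min).clamp(min=1e-8) / 255.0
            q = ((tensor - t_min) / scale).round().clamp(0, 255).to(torch.uint8)
            return q, scale, t_min

        def _dequantize(self, q_tensor):
            q, scale, t_min = q_tensor
            return q.to(self.compute_dtype) * scale + t_min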
480
+ class QuantoQuantizedCache(QuantizedCache):
481
+ """
482
+ Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
483
+
484
+ Parameters:
485
+ cache_config (`QuantizedCacheConfig`,):
486
+ A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
487
+ """
488
+
489
+ def __init__(self, cache_config: CacheConfig) -> None:
490
+ super().__init__(cache_config)
491
+ if self.nbits not in [2, 4]:
492
+ raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
493
+
494
+ if self.axis_key not in [0, -1]:
495
+ raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
496
+
497
+ if self.axis_value not in [0, -1]:
498
+ raise ValueError(
499
+ f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
500
+ )
501
+
502
+ self.qtype = qint4 if self.nbits == 4 else qint2
503
+
504
+ def _quantize(self, tensor, axis):
505
+ qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
506
+ return qtensor
507
+
508
+ def _dequantize(self, qtensor):
509
+ return qtensor.dequantize()
510
+
511
+
512
+ class HQQQuantizedCache(QuantizedCache):
513
+ """
514
+ Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
515
+
516
+ Parameters:
517
+ cache_config (`QuantizedCacheConfig`,):
518
+ A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
519
+ """
520
+
521
+ def __init__(self, cache_config: CacheConfig) -> None:
522
+ super().__init__(cache_config)
523
+ if self.nbits not in [1, 2, 3, 4, 8]:
524
+ raise ValueError(
525
+ f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
526
+ )
527
+
528
+ if self.axis_key not in [0, 1]:
529
+ raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
530
+
531
+ if self.axis_value not in [0, 1]:
532
+ raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
533
+
534
+ self.quantizer = HQQQuantizer
535
+
536
+ def _quantize(self, tensor, axis):
537
+ qtensor, meta = self.quantizer.quantize(
538
+ tensor,
539
+ axis=axis,
540
+ device=self.device,
541
+ compute_dtype=self.compute_dtype,
542
+ nbits=self.nbits,
543
+ group_size=self.q_group_size,
544
+ )
545
+ meta["compute_dtype"] = self.compute_dtype
546
+ self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype
547
+ return qtensor, meta
548
+
549
+ def _dequantize(self, qtensor):
550
+ quant_tensor, meta = qtensor
551
+ tensor = self.quantizer.dequantize(quant_tensor, meta)
552
+ return tensor
553
+
554
+
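An end-to-end sketch of driving one of these caches by hand; the shapes are made up, the optional `quanto` package must be installed, and in normal use `generate` would call `update` once per layer and step:

    import torch

    cfg = QuantizedCacheConfig(backend="quanto", nbits=4, residual_length=128)
    cache = QuantoQuantizedCache(cfg)

    key = torch.randn(1, 8, 16, 64)    # [batch, num_heads, seq_len, head_dim]
    value = torch.randn(1, 8, 16, 64)
    k, v = cache.update(key, value, layer_idx=0)
    print(k.shape)  # torch.Size([1, 8, 16, 64]); the first call returns the inputs unchanged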
555
  class SinkCache(Cache):
556
  """
557
  A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
 
573
  self.value_cache: List[torch.Tensor] = []
574
  self.window_length = window_length
575
  self.num_sink_tokens = num_sink_tokens
576
+ self.cos_sin_rerotation_cache = {}
577
+ self._cos_cache = None
578
+ self._sin_cache = None
579
  self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
580
 
581
  @staticmethod
 
593
  def _get_rerotation_cos_sin(
594
  self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
595
  ) -> Tuple[torch.Tensor, torch.Tensor]:
596
+ if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
597
  # Upcast to float32 temporarily for better accuracy
598
  cos = cos.to(torch.float32)
599
  sin = sin.to(torch.float32)
 
606
  rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
607
  rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
608
 
609
+ self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
610
  rerotation_cos.to(key_states.dtype).unsqueeze(0),
611
  rerotation_sin.to(key_states.dtype).unsqueeze(0),
612
  )
613
+ return self.cos_sin_rerotation_cache[key_states.shape[-2]]
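The `rerotation_cos`/`rerotation_sin` expressions above are the angle-difference identities, i.e. they encode a rotation by the position shift between where a key was originally encoded and where it now sits in the window (which operand is the old versus the new position depends on slicing outside this hunk). A quick numeric check:

    import torch

    a, b = torch.tensor(0.9), torch.tensor(0.3)   # two rotation angles
    rerotation_cos = torch.cos(a) * torch.cos(b) + torch.sin(a) * torch.sin(b)
    rerotation_sin = -torch.sin(a) * torch.cos(b) + torch.cos(a) * torch.sin(b)

    assert torch.allclose(rerotation_cos, torch.cos(a - b))
    assert torch.allclose(rerotation_sin, torch.sin(b - a))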
614
 
615
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
616
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
617
+ # TODO: deprecate this function in favor of `cache_position`
618
  # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
619
  if len(self.key_cache) <= layer_idx:
620
  return 0
 
660
  if layer_idx == 0:
661
  self._seen_tokens += key_states.shape[-2]
662
 
663
+ # Update the sin/cos cache, which holds sin/cos values for all possible positions
664
+ if using_rope and layer_idx == 0:
665
+ # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
666
+ # after all RoPE models have a llama-like cache utilization.
667
+ if cos.dim() == 2:
668
+ self._cos_cache = cos
669
+ self._sin_cache = sin
670
+ else:
671
+ if self._cos_cache is None:
672
+ self._cos_cache = cos[0, ...]
673
+ self._sin_cache = sin[0, ...]
674
+ elif self._cos_cache.shape[0] < self.window_length:
675
+ self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
676
+ self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
677
+
678
  # [bsz, num_heads, seq_len, head_dim]
679
  if len(self.key_cache) <= layer_idx:
680
  # Empty cache
 
695
  # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
696
  if using_rope:
697
  rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
698
+ key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
699
  )
700
  if partial_rotation_size is not None:
701
  keys_to_keep, keys_pass = (
 
718
 
719
  return self.key_cache[layer_idx], self.value_cache[layer_idx]
720
 
721
 
722
  class StaticCache(Cache):
723
  """
 
725
 
726
  Parameters:
727
  config (`PretrainedConfig):
728
+ The configuration file defining the shape-related attributes required to initialize the static cache.
 
729
  max_batch_size (`int`):
730
  The maximum batch size with which the model will be used.
731
  max_cache_len (`int`):
 
750
  config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
751
  )
752
 
753
+ self.key_cache: List[torch.Tensor] = []
754
+ self.value_cache: List[torch.Tensor] = []
755
  cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
756
+ for _ in range(config.num_hidden_layers):
757
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
758
+ # breaks when updating the cache.
759
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
760
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
761
+ torch._dynamo.mark_static_address(new_layer_key_cache)
762
+ torch._dynamo.mark_static_address(new_layer_value_cache)
763
+ self.key_cache.append(new_layer_key_cache)
764
+ self.value_cache.append(new_layer_value_cache)
765
 
766
  def update(
767
  self,
 
780
  value_states (`torch.Tensor`):
781
  The new value states to cache.
782
  layer_idx (`int`):
783
+ The index of the layer to cache the states for.
784
  cache_kwargs (`Dict[str, Any]`, `optional`):
785
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
786
+ to know where to write in the cache.
787
 
788
  Return:
789
  A tuple containing the updated key and value states.
790
  """
791
+ cache_position = cache_kwargs.get("cache_position")
792
+ k_out = self.key_cache[layer_idx]
793
+ v_out = self.value_cache[layer_idx]
794
 
795
+ k_out[:, :, cache_position] = key_states
796
+ v_out[:, :, cache_position] = value_states
797
 
798
  return k_out, v_out
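The write above is plain advanced indexing into a preallocated buffer, with `cache_position` naming the absolute slots to fill; a standalone illustration with made-up shapes:

    import torch

    max_cache_len, num_heads, head_dim = 16, 2, 4
    k_buf = torch.zeros(1, num_heads, max_cache_len, head_dim)

    cache_position = torch.tensor([5])              # decoding step writing absolute position 5
    new_key = torch.ones(1, num_heads, 1, head_dim)

    k_buf[:, :, cache_position] = new_key           # same scatter-style write as in `update`
    print(k_buf[0, 0].any(dim=-1).sum().item())     # 1 occupied slot, as `get_seq_length` counts below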
799
 
800
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
801
+ """Returns the sequence length of the cached states that were seen by the model."""
802
  # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
803
  # limit the check to the first batch member and head dimension.
804
+ # TODO: deprecate this function in favor of `cache_position`
805
+ return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
 
806
 
807
  def get_max_length(self) -> Optional[int]:
808
+ """Returns the maximum sequence length of the cached states."""
809
  return self.max_cache_len
810
 
811
+ def reset(self):
812
+ """Resets the cache values while preserving the objects"""
813
+ for layer_idx in range(len(self.key_cache)):
814
+ # In-place ops prevent breaking the static address
815
+ self.key_cache[layer_idx].zero_()
816
+ self.value_cache[layer_idx].zero_()
817
+
818
+
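A minimal sketch of constructing and exercising a `StaticCache` by hand, assuming the same `(config, max_batch_size, max_cache_len, device, dtype)` constructor signature as `SlidingWindowCache` below; the tiny `LlamaConfig` is only a stand-in for any decoder config with the required attributes, and `generate` normally manages all of this:

    import torch
    from transformers import LlamaConfig

    config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
    cache = StaticCache(config, max_batch_size=1, max_cache_len=128, device="cpu", dtype=torch.float32)

    head_dim = config.hidden_size // config.num_attention_heads
    key = torch.randn(1, config.num_key_value_heads, 4, head_dim)
    value = torch.randn(1, config.num_key_value_heads, 4, head_dim)
    cache.update(key, value, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4)})
    print(cache.get_seq_length())  # tensor(4)
    cache.reset()                  # zeroes the buffers in place, keeping their static addresses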
819
+ class SlidingWindowCache(Cache):
820
+ """
821
+ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
822
+ Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
823
+ if true (which means the cache cannot hold all the old key value states and the new states together because of the sliding window constraint),
824
+ we need to do a cycle shift based on `indices` to replace the oldest states with the new key value states passed in.
825
+
826
+ The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:
827
 
828
+ indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
829
+ tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
830
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
831
+ 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
832
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
833
+
834
+ We overwrite the cache using these indices, then we always write at `cache_position` (clamped to `sliding_window_size`).
835
+
836
+ Parameters:
837
+ config (`PretrainedConfig):
838
+ The configuration file defining the shape-related attributes required to initialize the static cache.
839
+ max_batch_size (`int`):
840
+ The maximum batch size with which the model will be used.
841
+ max_cache_len (`int`):
842
+ The maximum sequence length with which the model will be used.
843
+ device (`torch.device`):
844
+ The device on which the cache should be initialized. Should be the same as the layer.
845
+ dtype (*optional*, defaults to `torch.float32`):
846
+ The default `dtype` to use when initializing the layer.
847
+ """
848
+
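To make the rolled `indices` from the docstring above concrete, here is the same computation with a hypothetical `sliding_window_size` of 8; once `to_shift` becomes true, the window is rotated left by one so the freed last slot receives the newest token:

    import torch

    sliding_window_size = 8
    slicing = torch.ones(sliding_window_size, dtype=torch.long).cumsum(0)  # [1, 2, ..., 8]

    cache_position = torch.tensor([9]).clamp(0, sliding_window_size - 1)   # -> [7]
    to_shift = cache_position >= sliding_window_size - 1                   # -> [True]
    indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size
    print(indices.tolist())  # [1, 2, 3, 4, 5, 6, 7, 0]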
849
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
850
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
851
+ raise ValueError(
852
+ "Setting `cache_implementation` to 'sliding_window' requires the model config to support "
853
+ "sliding window attention. Please check that there is a `sliding_window` field in the model "
854
+ "config and that it is not set to None."
855
+ )
856
+
857
+ super().__init__()
858
+ self.max_batch_size = max_batch_size
859
+ # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
860
+ # when we do short-sentence generation
861
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
862
+ self.model_sliding_window_size = config.sliding_window
863
+ self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
864
+ # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
865
+ self.head_dim = (
866
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
867
+ )
868
+
869
+ self.dtype = dtype if dtype is not None else torch.float32
870
+ self.num_key_value_heads = (
871
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
872
+ )
873
+
874
+ cache_shape = (
875
+ config.num_hidden_layers,
876
+ max_batch_size,
877
+ self.num_key_value_heads,
878
+ self.sliding_window_size,
879
+ self.head_dim,
880
+ )
881
+
882
+ self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
883
+ self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
884
+
885
+ torch._dynamo.mark_static_address(self.key_cache)
886
+ torch._dynamo.mark_static_address(self.value_cache)
887
+
888
+ def update(
889
+ self,
890
+ key_states: torch.Tensor,
891
+ value_states: torch.Tensor,
892
+ layer_idx: int,
893
+ cache_kwargs: Optional[Dict[str, Any]] = None,
894
+ ) -> Tuple[torch.Tensor]:
895
+ cache_position = cache_kwargs.get("cache_position")
896
+ k_out = self.key_cache[layer_idx]
897
+ v_out = self.value_cache[layer_idx]
898
+
899
+ # assume this only happens in prefill phase when prompt length > sliding_window_size
900
+ if cache_position.shape[0] > self.sliding_window_size:
901
+ k_out = key_states[:, :, -self.sliding_window_size :, :]
902
+ v_out = value_states[:, :, -self.sliding_window_size :, :]
903
+ self.key_cache[layer_idx] = k_out
904
+ self.value_cache[layer_idx] = v_out
905
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
906
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
907
+ return key_states, value_states
908
+
909
+ slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
910
+ cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
911
+ to_shift = cache_position >= self.sliding_window_size - 1
912
+ indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
913
+
914
+ k_out = k_out[:, :, indices]
915
+ v_out = v_out[:, :, indices]
916
+
917
+ k_out[:, :, cache_position] = key_states
918
+ v_out[:, :, cache_position] = value_states
919
+
920
+ self.key_cache[layer_idx] = k_out
921
+ self.value_cache[layer_idx] = v_out
922
+
923
+ return k_out, v_out
924
+
925
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
926
+ # assume this will be called only in the first generation step
927
+ # `cache_position` will be used in other cases
928
+ return 0
929
+
930
+ def get_max_length(self) -> Optional[int]:
931
+ # in theory there is no limit because the sliding window size is fixed
932
+ # no matter how long the sentence is
933
  return None
934
+
935
+ def reset(self):
936
+ self.key_cache.zero_()
937
+ self.value_cache.zero_()