damerajee committed
Commit: f63f8c7
Parent: 8f6227e

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +73 -31
modeling_Llamoe.py CHANGED
@@ -167,41 +167,58 @@ class LlamoeRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)


-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
 class LlamoeRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
-
+        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
+        # For BC we register cos and sin cached
+        self.max_seq_len_cached = max_position_embeddings
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
+        t = t / self.scaling_factor
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+    @property
+    def sin_cached(self):
+        logger.warning_once(
+            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
+        )
+        return self._sin_cached

-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+    @property
+    def cos_cached(self):
+        logger.warning_once(
+            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
         )
+        return self._cos_cached
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
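The rewritten rotary embedding no longer slices a precomputed cos/sin cache by `seq_len`; `forward` now takes explicit `position_ids` and computes cos/sin on the fly in float32, keeping `_cos_cached`/`_sin_cached` only for backward compatibility. A minimal usage sketch, assuming the class above is in scope together with `torch`; the batch size, head count, and sequence length are illustrative, not taken from this file:

    import torch

    head_dim, batch, num_heads, seq_len = 64, 2, 8, 10
    rope = LlamoeRotaryEmbedding(dim=head_dim)  # class defined in the hunk above

    q = torch.randn(batch, num_heads, seq_len, head_dim)                 # [bs, heads, seq, head_dim]
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)  # [bs, seq]

    cos, sin = rope(q, position_ids)  # old API was rope(q, seq_len=seq_len)
    print(cos.shape, sin.shape)       # both torch.Size([2, 10, 64])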
@@ -212,8 +229,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -221,9 +238,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -234,8 +250,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
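Because the new RoPE `forward` already returns cos/sin gathered per position (shape `[batch, seq_len, head_dim]`), `apply_rotary_pos_emb` no longer indexes with `position_ids`; it only unsqueezes a heads dimension so the tensors broadcast against q and k. A small shape check, assuming `rotate_half` and `apply_rotary_pos_emb` as defined above are in scope; the sizes are made up:

    import torch

    bs, q_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 10, 64
    q = torch.randn(bs, q_heads, seq_len, head_dim)
    k = torch.randn(bs, kv_heads, seq_len, head_dim)
    cos = torch.randn(bs, seq_len, head_dim)  # as produced by the new rotary forward
    sin = torch.randn(bs, seq_len, head_dim)

    # unsqueeze_dim=1 turns [bs, seq_len, head_dim] into [bs, 1, seq_len, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 10, 64]) torch.Size([2, 2, 10, 64])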
@@ -254,7 +270,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+# Copied from transformers.models.mistral.modeling_mistral.LlamaAttention with Llama->Mixtral
 class LlamoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -873,11 +889,11 @@ LLAMOE_START_DOCSTRING = r"""
 )
 # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
 class LlamoePreTrainedModel(PreTrainedModel):
-    config_class = LlamoeConfig
+    config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamoeDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
@@ -893,6 +909,32 @@ class LlamoePreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
+            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+
+        for layer in self.model.layers:
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
+            layer.self_attn.past_key_value = cache_cls(
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
+            )
+
+    def _reset_cache(self):
+        for layer in self.model.layers:
+            layer.self_attn.past_key_value = None

 LLAMOE_INPUTS_DOCSTRING = r"""
     Args:
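The added `_setup_cache` pre-allocates a per-layer static KV cache plus a boolean causal mask, and `_reset_cache` drops those caches again. A hedged sketch of how a caller might drive these helpers; `model` is a placeholder for a loaded Llamoe causal-LM exposing the methods added in this commit, and in recent `transformers` releases this plumbing is normally triggered via `generate(..., cache_implementation="static")` rather than called directly:

    from transformers import StaticCache  # available in transformers >= 4.38

    max_batch_size, max_cache_len = 1, 512
    # Allocate one StaticCache per decoder layer and an upper-triangular causal mask buffer.
    model._setup_cache(StaticCache, max_batch_size, max_cache_len=max_cache_len)

    # ... run prefill and decode steps; each layer reads/writes layer.self_attn.past_key_value ...

    # Free the per-layer caches once generation is finished.
    model._reset_cache()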
 