damerajee committed
Commit e4ae404
Parent: 7fc5f8a

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +28 -71
modeling_Llamoe.py CHANGED
@@ -162,60 +162,34 @@ ALL_LAYERNORM_LAYERS.append(LlamoeRMSNorm)
 
 
 class LlamoeRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        # For BC we register cos and sin cached
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        t = t / self.scaling_factor
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
-        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
-
-    @property
-    def sin_cached(self):
-        logger.warning_once(
-            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
-            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
-        )
-        return self._sin_cached
-
-    @property
-    def cos_cached(self):
-        logger.warning_once(
-            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
-            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
-        )
-        return self._cos_cached
-
-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-
-
-
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
+        sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -224,32 +198,15 @@ def rotate_half(x):
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    seq_len, dim = q.shape[-2], q.shape[-1]
+    cos = cos[:seq_len].view(1, 1, seq_len, dim)
+    sin = sin[:seq_len].view(1, 1, seq_len, dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
 
 
+
 class LlamoeBlockSparseTop2MLP(nn.Module):
     def __init__(self, config: LlamoeConfig):
         super().__init__()
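For context, a minimal usage sketch of how the rewritten rotary-embedding API appears intended to be called (assuming modeling_Llamoe.py is importable from the working directory; the batch/head/sequence sizes below are illustrative, not taken from the model config). The new forward infers the sequence length from x.size(2) and returns the cached cos/sin tables, instead of taking position_ids as the removed version did; apply_rotary_pos_emb now reshapes those tables to [1, 1, seq_len, head_dim] internally, so its position_ids and unsqueeze_dim parameters remain only for signature compatibility.

import torch

from modeling_Llamoe import LlamoeRotaryEmbedding, apply_rotary_pos_emb

# Illustrative shapes: [batch, num_heads, seq_len, head_dim], the layout forward() expects.
batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

rotary = LlamoeRotaryEmbedding(dim=head_dim, max_position_embeddings=2048, base=10000)

# Returns the cached tables sliced to the current sequence length, each of shape [seq_len, head_dim];
# the cache is rebuilt automatically if seq_len ever exceeds max_seq_len_cached.
cos, sin = rotary(q)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 16, 64]) for both

Numerically, the new cache holds the same angles the removed inv_freq path produced (positions / base**(2i/dim), with the old scaling_factor at its default of 1.0); the change mainly swaps the per-call position_ids computation for a precomputed table sliced by sequence length.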