Update modeling_Llamoe.py
modeling_Llamoe.py (CHANGED: +73, -31)
@@ -167,41 +167,58 @@ class LlamoeRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)


-
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
 class LlamoeRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
-
+        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
+        # For BC we register cos and sin cached
+        self.max_seq_len_cached = max_position_embeddings
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
+        t = t / self.scaling_factor
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

-
-
-
-
+    @property
+    def sin_cached(self):
+        logger.warning_once(
+            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
+        )
+        return self._sin_cached

-
-
-
+    @property
+    def cos_cached(self):
+        logger.warning_once(
+            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
         )
+        return self._cos_cached
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
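Note: the reworked LlamoeRotaryEmbedding.forward above derives cos/sin on the fly from position_ids instead of reading a precomputed cache. A minimal standalone sketch of that shape logic (toy dim=8, base=10000, a batch of 2 sequences of length 5; illustrative only, not the Llamoe class itself):

import torch

# Toy inverse frequencies: 1 / base^(2i/dim), one per pair of rotary dimensions
dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))                 # [dim // 2]

# Per-token positions for a batch of 2 sequences of length 5
position_ids = torch.arange(5).unsqueeze(0).expand(2, -1)                          # [bs, seq_len]

# Same batched outer product as the new forward(): freqs[b, s, i] = pos[b, s] * inv_freq[i]
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)   # [bs, dim // 2, 1]
position_ids_expanded = position_ids[:, None, :].float()                           # [bs, 1, seq_len]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                # [bs, seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                                            # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])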
@@ -212,8 +229,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-# Copied from transformers.models.
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -221,9 +238,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -234,8 +250,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
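The unsqueeze_dim note in the docstring above is easiest to see with concrete shapes. A small sketch (toy sizes, with rotate_half restated so the snippet runs on its own) of how [bs, seq_len, head_dim] cos/sin broadcast against [bs, num_heads, seq_len, head_dim] queries and keys once unsqueezed at dim 1:

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input, as in the function above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
cos = torch.randn(bs, seq_len, head_dim)  # stand-ins for the RoPE forward() outputs
sin = torch.randn(bs, seq_len, head_dim)

# unsqueeze_dim=1 inserts the head axis, so [bs, 1, seq_len, head_dim] broadcasts over num_heads
cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # both torch.Size([2, 4, 5, 8])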
@@ -254,7 +270,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-# Copied from transformers.models.mistral.modeling_mistral.
+# Copied from transformers.models.mistral.modeling_mistral.LlamaAttention with Llama->Mixtral
 class LlamoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -873,11 +889,11 @@ LLAMOE_START_DOCSTRING = r"""
 )
 # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
 class LlamoePreTrainedModel(PreTrainedModel):
-    config_class =
+    config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamoeDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
@@ -893,6 +909,32 @@ class LlamoePreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
+            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+
+        for layer in self.model.layers:
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
+            layer.self_attn.past_key_value = cache_cls(
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
+            )
+
+    def _reset_cache(self):
+        for layer in self.model.layers:
+            layer.self_attn.past_key_value = None

 LLAMOE_INPUTS_DOCSTRING = r"""
     Args:
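For reference, the boolean mask registered by _setup_cache above is a plain upper-triangular "masked" indicator sized to the static cache length. A tiny standalone sketch of just that construction (toy max_cache_len=6, independent of the model):

import torch

max_cache_len = 6
# True marks positions a query may NOT attend to (strictly future tokens)
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=True, dtype=torch.bool)
causal_mask = torch.triu(causal_mask, diagonal=1)
print(causal_mask.int())
# Row i is all zeros (attendable) up to and including column i, ones above the diagonal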