LeroyDyer committed on
Commit 2ecbaff
1 Parent(s): 5178716

Upload modeling_mistral.py

Files changed (1)
  1. modeling_mistral.py +356 -386
modeling_mistral.py CHANGED
@@ -662,267 +662,6 @@ class MistralPreTrainedModel(PreTrainedModel):
662
  module.weight.data[module.padding_idx].zero_()
663
 
664
 
665
- @add_start_docstrings(
666
- "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
667
- MISTRAL_START_DOCSTRING,
668
- )
669
- class MistralModel(MistralPreTrainedModel):
670
- """
671
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
672
-
673
- Args:
674
- config: MistralConfig
675
- """
676
-
677
- def __init__(self, config: MistralConfig):
678
- super().__init__(config)
679
- self.padding_idx = config.pad_token_id
680
- self.vocab_size = config.vocab_size
681
-
682
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
683
- self.layers = nn.ModuleList(
684
- [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
685
- )
686
- self._attn_implementation = config._attn_implementation
687
- self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
688
-
689
- self.gradient_checkpointing = False
690
- # Initialize weights and apply final processing
691
- self.post_init()
692
-
693
- def get_input_embeddings(self):
694
- return self.embed_tokens
695
-
696
- def set_input_embeddings(self, value):
697
- self.embed_tokens = value
698
-
699
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
700
- def forward(
701
- self,
702
- input_ids: torch.LongTensor = None,
703
- attention_mask: Optional[torch.Tensor] = None,
704
- position_ids: Optional[torch.LongTensor] = None,
705
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
706
- inputs_embeds: Optional[torch.FloatTensor] = None,
707
- use_cache: Optional[bool] = None,
708
- output_attentions: Optional[bool] = None,
709
- output_hidden_states: Optional[bool] = None,
710
- return_dict: Optional[bool] = None,
711
- cache_position: Optional[torch.LongTensor] = None,
712
- ) -> Union[Tuple, BaseModelOutputWithPast]:
713
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
714
- output_hidden_states = (
715
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
716
- )
717
- use_cache = use_cache if use_cache is not None else self.config.use_cache
718
-
719
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
720
-
721
- # retrieve input_ids and inputs_embeds
722
- if (input_ids is None) ^ (inputs_embeds is not None):
723
- raise ValueError(
724
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
725
- )
726
-
727
- if self.gradient_checkpointing and self.training and use_cache:
728
- logger.warning_once(
729
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
730
- )
731
- use_cache = False
732
-
733
- if inputs_embeds is None:
734
- inputs_embeds = self.embed_tokens(input_ids)
735
-
736
- return_legacy_cache = False
737
- if use_cache and not isinstance(past_key_values, Cache):
738
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
739
- return_legacy_cache = True
740
- logger.warning_once(
741
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
742
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
743
- )
744
-
745
- if cache_position is None:
746
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
747
- cache_position = torch.arange(
748
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
749
- )
750
-
751
- if position_ids is None:
752
- position_ids = cache_position.unsqueeze(0)
753
-
754
- causal_mask = self._update_causal_mask(
755
- attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
756
- )
757
-
758
- hidden_states = inputs_embeds
759
-
760
- # decoder layers
761
- all_hidden_states = () if output_hidden_states else None
762
- all_self_attns = () if output_attentions else None
763
- next_decoder_cache = None
764
-
765
- for decoder_layer in self.layers:
766
- if output_hidden_states:
767
- all_hidden_states += (hidden_states,)
768
-
769
- if self.gradient_checkpointing and self.training:
770
- layer_outputs = self._gradient_checkpointing_func(
771
- decoder_layer.__call__,
772
- hidden_states,
773
- causal_mask,
774
- position_ids,
775
- past_key_values,
776
- output_attentions,
777
- use_cache,
778
- cache_position,
779
- )
780
- else:
781
- layer_outputs = decoder_layer(
782
- hidden_states,
783
- attention_mask=causal_mask,
784
- position_ids=position_ids,
785
- past_key_value=past_key_values,
786
- output_attentions=output_attentions,
787
- use_cache=use_cache,
788
- cache_position=cache_position,
789
- )
790
-
791
- hidden_states = layer_outputs[0]
792
-
793
- if use_cache:
794
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
795
-
796
- if output_attentions:
797
- all_self_attns += (layer_outputs[1],)
798
-
799
- hidden_states = self.norm(hidden_states)
800
-
801
- # add hidden states from the last decoder layer
802
- if output_hidden_states:
803
- all_hidden_states += (hidden_states,)
804
-
805
- next_cache = next_decoder_cache if use_cache else None
806
- if return_legacy_cache:
807
- next_cache = next_cache.to_legacy_cache()
808
-
809
- if not return_dict:
810
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
811
- return BaseModelOutputWithPast(
812
- last_hidden_state=hidden_states,
813
- past_key_values=next_cache,
814
- hidden_states=all_hidden_states,
815
- attentions=all_self_attns,
816
- )
817
-
818
- def _update_causal_mask(
819
- self,
820
- attention_mask: torch.Tensor,
821
- input_tensor: torch.Tensor,
822
- cache_position: torch.Tensor,
823
- past_key_values: Cache,
824
- use_cache: bool,
825
- output_attentions: bool,
826
- ):
827
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
828
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
829
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
830
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
831
-
832
- if self._attn_implementation == "flash_attention_2":
833
- if attention_mask is not None and use_cache:
834
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
835
- if is_padding_right:
836
- raise ValueError(
837
- "You are attempting to perform batched generation with padding_side='right'"
838
- " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
839
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
840
- )
841
- if attention_mask is not None and 0.0 in attention_mask:
842
- return attention_mask
843
- return None
844
-
845
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
846
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
847
- # to infer the attention mask.
848
-
849
- # cache_position must be valid here no matter which cache we use
850
- past_seen_tokens = cache_position[0] if past_key_values is not None else 0
851
- using_static_cache = isinstance(past_key_values, StaticCache)
852
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
853
-
854
- if (
855
- self.config._attn_implementation == "sdpa"
856
- and not (using_static_cache or using_sliding_window_cache)
857
- and not output_attentions
858
- ):
859
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
860
- attention_mask,
861
- inputs_embeds=input_tensor,
862
- past_key_values_length=past_seen_tokens,
863
- sliding_window=self.config.sliding_window,
864
- is_training=self.training,
865
- ):
866
- return None
867
-
868
- dtype, device = input_tensor.dtype, input_tensor.device
869
- min_dtype = torch.finfo(dtype).min
870
- sequence_length = input_tensor.shape[1]
871
- # SlidingWindowCache
872
- if using_sliding_window_cache:
873
- target_length = max(sequence_length, self.config.sliding_window)
874
- # StaticCache
875
- elif using_static_cache:
876
- target_length = past_key_values.get_max_length()
877
- # DynamicCache or no cache
878
- else:
879
- target_length = (
880
- attention_mask.shape[-1]
881
- if isinstance(attention_mask, torch.Tensor)
882
- else past_seen_tokens + sequence_length + 1
883
- )
884
-
885
- if attention_mask is not None and attention_mask.dim() == 4:
886
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
887
- if attention_mask.max() != 0:
888
- raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
889
- causal_mask = attention_mask
890
- else:
891
- causal_mask = torch.full(
892
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
893
- )
894
- exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
895
- if self.config.sliding_window is not None:
896
- if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
897
- exclude_mask.bitwise_or_(
898
- torch.arange(target_length, device=device)
899
- <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
900
- )
901
- causal_mask *= exclude_mask
902
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
903
- if attention_mask is not None:
904
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
905
- if attention_mask.dim() == 2:
906
- mask_length = attention_mask.shape[-1]
907
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
908
- padding_mask = padding_mask == 0
909
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
910
- padding_mask, min_dtype
911
- )
912
-
913
- if (
914
- self.config._attn_implementation == "sdpa"
915
- and attention_mask is not None
916
- and attention_mask.device.type == "cuda"
917
- and not output_attentions
918
- ):
919
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
920
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
921
- # Details: https://github.com/pytorch/pytorch/issues/110213
922
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
923
-
924
- return causal_mask
925
-
926
 
927
  ############################## LM Heads #################################
928
 
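The block removed above is the stock decoder stack: token embeddings feed a chain of MistralDecoderLayer modules and the result passes through a final RMSNorm. A minimal, self-contained sketch of that control flow follows; the toy layer and LayerNorm stand in for the real MistralDecoderLayer and MistralRMSNorm, so treat it as an illustration of the loop rather than the file's implementation.

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    # Stand-in for MistralDecoderLayer: a single residual block.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + torch.tanh(self.proj(hidden_states))

class ToyDecoderStack(nn.Module):
    # Mirrors the removed forward: embed -> run every decoder layer -> final norm.
    def __init__(self, vocab_size: int = 32000, hidden_size: int = 64, num_layers: int = 2):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([ToyDecoderLayer(hidden_size) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(hidden_size)  # the real model uses MistralRMSNorm here

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.norm(hidden_states)

print(ToyDecoderStack()(torch.randint(0, 32000, (1, 8))).shape)  # torch.Size([1, 8, 64])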
@@ -2288,118 +2027,379 @@ class MixtralSparseMoeBlock(nn.Module):
2288
  # we cast back to the input dtype
2289
  routing_weights = routing_weights.to(hidden_states.dtype)
2290
 
2291
- final_hidden_states = torch.zeros(
2292
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
2293
  )
2294
 
2295
- # One hot encode the selected experts to create an expert mask
2296
- # this will be used to easily index which expert is going to be sollicitated
2297
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
2298
 
2299
- # Loop over all available experts in the model and perform the computation on each expert
2300
- for expert_idx in range(self.num_experts):
2301
- expert_layer = self.experts[expert_idx]
2302
- idx, top_x = torch.where(expert_mask[expert_idx])
2303
 
2304
- # Index the correct hidden states and compute the expert hidden state for
2305
- # the current expert. We need to make sure to multiply the output hidden
2306
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2307
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2308
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
2309
 
2310
- # However `index_add_` only support torch tensors for indexing so we'll use
2311
- # the `top_x` tensor here.
2312
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
2313
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
2314
- return final_hidden_states, router_logits
2315
- class MixtralDecoderLayer(nn.Module):
2316
- def __init__(self, config: MixtralConfig, layer_idx: int):
2317
- super().__init__()
2318
- self.hidden_size = config.hidden_size
2319
 
2320
- self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
2321
- self.mlp = MistralMLP(config)
2322
- self.block_sparse_moe = MixtralSparseMoeBlock(config)
2323
- self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2324
- self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2325
 
2326
- def forward(
2327
- self,
2328
- hidden_states: torch.Tensor,
2329
- attention_mask: Optional[torch.Tensor] = None,
2330
- position_ids: Optional[torch.LongTensor] = None,
2331
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
2332
- output_attentions: Optional[bool] = False,
2333
- output_router_logits: Optional[bool] = False,
2334
- use_cache: Optional[bool] = False,
2335
- cache_position: Optional[torch.LongTensor] = None,
2336
- **kwargs,
2337
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
2338
- """
2339
- Args:
2340
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
2341
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
2342
- `(batch, sequence_length)` where padding elements are indicated by 0.
2343
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
2344
- output_attentions (`bool`, *optional*):
2345
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2346
- returned tensors for more detail.
2347
- output_router_logits (`bool`, *optional*):
2348
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
2349
- should not be returned during inference.
2350
- use_cache (`bool`, *optional*):
2351
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
2352
- (see `past_key_values`).
2353
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
2354
- Indices depicting the position of the input sequence tokens in the sequence.
2355
- kwargs (`dict`, *optional*):
2356
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
2357
- into the model
2358
- """
2359
 
2360
- residual = hidden_states
 
2361
 
2362
- hidden_states = self.input_layernorm(hidden_states)
2363
 
2364
- # Self Attention
2365
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
2366
- hidden_states=hidden_states,
2367
- attention_mask=attention_mask,
2368
- position_ids=position_ids,
2369
- past_key_value=past_key_value,
2370
- output_attentions=output_attentions,
2371
- use_cache=use_cache,
2372
- cache_position=cache_position,
2373
  )
2374
- hidden_states = residual + hidden_states
2375
 
2376
- # Fully Connected
2377
- residual = hidden_states
2378
- hidden_states = self.post_attention_layernorm(hidden_states)
2379
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
2380
- hidden_states = residual + hidden_states
 
2381
 
2382
- # Fully Connected
2383
- residual = hidden_states
2384
- hidden_states = self.post_attention_layernorm(hidden_states)
2385
- hidden_states = self.mlp(hidden_states)
2386
- hidden_states = residual + hidden_states
2387
 
2388
- outputs = (hidden_states,)
2389
 
2390
- if output_attentions:
2391
- outputs += (self_attn_weights,)
 
2392
 
2393
- if use_cache:
2394
- outputs += (present_key_value,)
2395
 
2396
- if output_router_logits:
2397
- outputs += (router_logits,)
2398
 
2399
- return outputs
2400
 
2401
- ################################ closed COMPONENTS ################################
2402
 
 
2403
 
2404
  ############# Causal LM #################
2405
  class MistralForCausalLM(MistralPreTrainedModel):
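The hunk above removes (and a later hunk re-adds) the Mixtral-style sparse MoE dispatch: softmax the router logits, keep the top-k experts per token, renormalise their weights, gather each expert's tokens with torch.where on a one-hot mask, and scatter the weighted outputs back with index_add_. A reduced sketch of that routine, using a hypothetical moe_dispatch helper and toy Linear experts rather than the file's MixtralSparseMoeBlock:

import torch
import torch.nn as nn
import torch.nn.functional as F

def moe_dispatch(hidden_states, gate, experts, top_k=2):
    # Top-k expert routing in the style of the removed MixtralSparseMoeBlock.
    # `gate` is a Linear(hidden_dim, num_experts); `experts` is a list of modules.
    batch, seq_len, hidden_dim = hidden_states.shape
    flat = hidden_states.reshape(-1, hidden_dim)                 # (tokens, hidden_dim)
    router_logits = gate(flat)                                   # (tokens, num_experts)
    weights = F.softmax(router_logits, dim=-1)
    weights, selected = torch.topk(weights, top_k, dim=-1)       # keep the top-k experts per token
    weights = weights / weights.sum(dim=-1, keepdim=True)        # renormalise over the k picks

    final = torch.zeros_like(flat)
    # expert_mask[e, k, t] == 1 when expert e is the k-th choice of token t
    expert_mask = F.one_hot(selected, num_classes=len(experts)).permute(2, 1, 0)
    for e, expert in enumerate(experts):
        idx, top_x = torch.where(expert_mask[e])                 # which tokens hit expert e
        if top_x.numel() == 0:
            continue
        out = expert(flat[top_x]) * weights[top_x, idx, None]    # weight each token's output
        final.index_add_(0, top_x, out)                          # scatter back by token index
    return final.reshape(batch, seq_len, hidden_dim), router_logits

h = torch.randn(2, 4, 8)
gate = nn.Linear(8, 4, bias=False)
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
y, logits = moe_dispatch(h, gate, experts)
print(y.shape, logits.shape)  # torch.Size([2, 4, 8]) torch.Size([8, 4])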
@@ -3421,40 +3421,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
3421
  else:
3422
  cur_talk_loss = talk_loss_list[talk_idx]
3423
  log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
3424
- if self.training:
3425
- self.training_steps += 1
3426
- try:
3427
- # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
3428
- if self.wandb_enabled:
3429
- if self.training_steps % (self.n_tokens_print) == 0 or not self.training:# and "0" in str(loss.device):
3430
- if not self.training:
3431
- new_log_dict = {}
3432
- for key in list(log_dict.keys()):
3433
- new_log_dict["eval_" + key] = log_dict[key]
3434
- log_dict = new_log_dict
3435
- log_dict["training_steps"] = self.training_steps
3436
- log_dict["batch_size"] = batch_size
3437
- log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
3438
- if self.n_ahead > 1:
3439
- log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
3440
- else: # There's no overhead for talk tokens if there's no thinking
3441
- log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
3442
- # remove all nans
3443
- for key in list(log_dict.keys()):
3444
- if log_dict[key] != log_dict[key]:
3445
- del log_dict[key]
3446
- if self.training:
3447
- wandb.log(log_dict)
3448
- if self.training:
3449
- self.log_dict = defaultdict(int)
3450
- else:
3451
- self.eval_log_dict = defaultdict(int)
3452
- except Exception as e:
3453
- pass
3454
 
3455
- if not self.training:
3456
- self.n_ahead_talk = n_ahead_talk_to_restore
3457
- self.n_passes = n_passes_to_restore
3458
  return CausalLMOutputWithPast(
3459
  loss=loss if loss is not None else None,
3460
  logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
 
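The training-loop logging removed in the hunk above is mostly bookkeeping: it counts optimiser steps, derives example and compute steps from the batch size, gradient accumulation, and the number of thought ("ahead") and talk tokens, drops NaN entries, and pushes the result to wandb. The core arithmetic reduces to the sketch below (hypothetical helper, made-up numbers):

def compute_steps(training_steps, batch_size, grad_accum, n_ahead, n_ahead_talk):
    # Examples processed so far, scaled by the extra forward work the thought and
    # talk tokens add; with n_ahead <= 1 there is no thinking overhead to count.
    example_steps = training_steps * batch_size * grad_accum
    if n_ahead > 1:
        return example_steps * (n_ahead + n_ahead_talk - 1)
    return example_steps

# e.g. 100 optimiser steps, batch size 8, gradient accumulation 4, 8 thought + 4 talk tokens
print(compute_steps(100, 8, 4, n_ahead=8, n_ahead_talk=4))  # 35200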
662
  module.weight.data[module.padding_idx].zero_()
663
 
664
 
665
 
666
  ############################## LM Heads #################################
667
 
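Both copies of MistralModel.forward in this diff accept past_key_values either as a Cache object or as the legacy per-layer tuple format, converting on entry with DynamicCache.from_legacy_cache and back on exit with to_legacy_cache. A standalone sketch of that round trip, assuming a transformers release recent enough to export DynamicCache:

import torch
from transformers import DynamicCache

# Legacy format: one (key_states, value_states) pair per layer,
# each of shape (batch, num_kv_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)   # wrap the tuples in a Cache object
print(cache.get_seq_length())                    # 5 tokens already cached

round_tripped = cache.to_legacy_cache()          # convert back for callers expecting tuples
print(len(round_tripped), round_tripped[0][0].shape)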
 
2027
  # we cast back to the input dtype
2028
  routing_weights = routing_weights.to(hidden_states.dtype)
2029
 
2030
+ final_hidden_states = torch.zeros(
2031
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
2032
+ )
2033
+
2034
+ # One hot encode the selected experts to create an expert mask
2035
+ # this will be used to easily index which expert is going to be sollicitated
2036
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
2037
+
2038
+ # Loop over all available experts in the model and perform the computation on each expert
2039
+ for expert_idx in range(self.num_experts):
2040
+ expert_layer = self.experts[expert_idx]
2041
+ idx, top_x = torch.where(expert_mask[expert_idx])
2042
+
2043
+ # Index the correct hidden states and compute the expert hidden state for
2044
+ # the current expert. We need to make sure to multiply the output hidden
2045
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
2046
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
2047
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
2048
+
2049
+ # However `index_add_` only support torch tensors for indexing so we'll use
2050
+ # the `top_x` tensor here.
2051
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
2052
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
2053
+ return final_hidden_states, router_logits
2054
+ class MixtralDecoderLayer(nn.Module):
2055
+ def __init__(self, config: MixtralConfig, layer_idx: int):
2056
+ super().__init__()
2057
+ self.hidden_size = config.hidden_size
2058
+
2059
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
2060
+ self.mlp = MistralMLP(config)
2061
+ self.block_sparse_moe = MixtralSparseMoeBlock(config)
2062
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2063
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2064
+
2065
+ def forward(
2066
+ self,
2067
+ hidden_states: torch.Tensor,
2068
+ attention_mask: Optional[torch.Tensor] = None,
2069
+ position_ids: Optional[torch.LongTensor] = None,
2070
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
2071
+ output_attentions: Optional[bool] = False,
2072
+ output_router_logits: Optional[bool] = False,
2073
+ use_cache: Optional[bool] = False,
2074
+ cache_position: Optional[torch.LongTensor] = None,
2075
+ **kwargs,
2076
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
2077
+ """
2078
+ Args:
2079
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
2080
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
2081
+ `(batch, sequence_length)` where padding elements are indicated by 0.
2082
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
2083
+ output_attentions (`bool`, *optional*):
2084
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
2085
+ returned tensors for more detail.
2086
+ output_router_logits (`bool`, *optional*):
2087
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
2088
+ should not be returned during inference.
2089
+ use_cache (`bool`, *optional*):
2090
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
2091
+ (see `past_key_values`).
2092
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
2093
+ Indices depicting the position of the input sequence tokens in the sequence.
2094
+ kwargs (`dict`, *optional*):
2095
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
2096
+ into the model
2097
+ """
2098
+
2099
+ residual = hidden_states
2100
+
2101
+ hidden_states = self.input_layernorm(hidden_states)
2102
+
2103
+ # Self Attention
2104
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
2105
+ hidden_states=hidden_states,
2106
+ attention_mask=attention_mask,
2107
+ position_ids=position_ids,
2108
+ past_key_value=past_key_value,
2109
+ output_attentions=output_attentions,
2110
+ use_cache=use_cache,
2111
+ cache_position=cache_position,
2112
+ )
2113
+ hidden_states = residual + hidden_states
2114
+
2115
+ # Fully Connected
2116
+ residual = hidden_states
2117
+ hidden_states = self.post_attention_layernorm(hidden_states)
2118
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
2119
+ hidden_states = residual + hidden_states
2120
+
2121
+ # Fully Connected
2122
+ residual = hidden_states
2123
+ hidden_states = self.post_attention_layernorm(hidden_states)
2124
+ hidden_states = self.mlp(hidden_states)
2125
+ hidden_states = residual + hidden_states
2126
+
2127
+ outputs = (hidden_states,)
2128
+
2129
+ if output_attentions:
2130
+ outputs += (self_attn_weights,)
2131
+
2132
+ if use_cache:
2133
+ outputs += (present_key_value,)
2134
+
2135
+ if output_router_logits:
2136
+ outputs += (router_logits,)
2137
+
2138
+ return outputs
2139
+
2140
+ ################################ closed COMPONENTS ################################
2141
+
2142
+
2143
+ @add_start_docstrings(
2144
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
2145
+ MISTRAL_START_DOCSTRING,
2146
+ )
2147
+ class MistralModel(MistralPreTrainedModel):
2148
+ """
2149
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
2150
+
2151
+ Args:
2152
+ config: MistralConfig
2153
+ """
2154
+
2155
+ def __init__(self, config: MistralConfig):
2156
+ super().__init__(config)
2157
+ self.padding_idx = config.pad_token_id
2158
+ self.vocab_size = config.vocab_size
2159
+
2160
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
2161
+ self.layers = nn.ModuleList(
2162
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
2163
+ )
2164
+ self._attn_implementation = config._attn_implementation
2165
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
2166
+
2167
+ self.gradient_checkpointing = False
2168
+ # Initialize weights and apply final processing
2169
+ self.post_init()
2170
+
2171
+ def get_input_embeddings(self):
2172
+ return self.embed_tokens
2173
+
2174
+ def set_input_embeddings(self, value):
2175
+ self.embed_tokens = value
2176
+
2177
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
2178
+ def forward(
2179
+ self,
2180
+ input_ids: torch.LongTensor = None,
2181
+ attention_mask: Optional[torch.Tensor] = None,
2182
+ position_ids: Optional[torch.LongTensor] = None,
2183
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
2184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2185
+ use_cache: Optional[bool] = None,
2186
+ output_attentions: Optional[bool] = None,
2187
+ output_hidden_states: Optional[bool] = None,
2188
+ return_dict: Optional[bool] = None,
2189
+ cache_position: Optional[torch.LongTensor] = None,
2190
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
2191
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2192
+ output_hidden_states = (
2193
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2194
+ )
2195
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
2196
+
2197
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2198
+
2199
+ # retrieve input_ids and inputs_embeds
2200
+ if (input_ids is None) ^ (inputs_embeds is not None):
2201
+ raise ValueError(
2202
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
2203
+ )
2204
+
2205
+ if self.gradient_checkpointing and self.training and use_cache:
2206
+ logger.warning_once(
2207
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
2208
+ )
2209
+ use_cache = False
2210
+
2211
+ if inputs_embeds is None:
2212
+ inputs_embeds = self.embed_tokens(input_ids)
2213
+
2214
+ return_legacy_cache = False
2215
+ if use_cache and not isinstance(past_key_values, Cache):
2216
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
2217
+ return_legacy_cache = True
2218
+ logger.warning_once(
2219
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
2220
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
2221
+ )
2222
+
2223
+ if cache_position is None:
2224
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
2225
+ cache_position = torch.arange(
2226
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
2227
+ )
2228
+
2229
+ if position_ids is None:
2230
+ position_ids = cache_position.unsqueeze(0)
2231
+
2232
+ causal_mask = self._update_causal_mask(
2233
+ attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
2234
  )
2235
 
2236
+ hidden_states = inputs_embeds
2237
 
2238
+ # decoder layers
2239
+ all_hidden_states = () if output_hidden_states else None
2240
+ all_self_attns = () if output_attentions else None
2241
+ next_decoder_cache = None
2242
 
2243
+ for decoder_layer in self.layers:
2244
+ if output_hidden_states:
2245
+ all_hidden_states += (hidden_states,)
2246
 
2247
+ if self.gradient_checkpointing and self.training:
2248
+ layer_outputs = self._gradient_checkpointing_func(
2249
+ decoder_layer.__call__,
2250
+ hidden_states,
2251
+ causal_mask,
2252
+ position_ids,
2253
+ past_key_values,
2254
+ output_attentions,
2255
+ use_cache,
2256
+ cache_position,
2257
+ )
2258
+ else:
2259
+ layer_outputs = decoder_layer(
2260
+ hidden_states,
2261
+ attention_mask=causal_mask,
2262
+ position_ids=position_ids,
2263
+ past_key_value=past_key_values,
2264
+ output_attentions=output_attentions,
2265
+ use_cache=use_cache,
2266
+ cache_position=cache_position,
2267
+ )
2268
 
2269
+ hidden_states = layer_outputs[0]
2270
 
2271
+ if use_cache:
2272
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
2273
 
2274
+ if output_attentions:
2275
+ all_self_attns += (layer_outputs[1],)
2276
 
2277
+ hidden_states = self.norm(hidden_states)
2278
 
2279
+ # add hidden states from the last decoder layer
2280
+ if output_hidden_states:
2281
+ all_hidden_states += (hidden_states,)
2282
+
2283
+ next_cache = next_decoder_cache if use_cache else None
2284
+ if return_legacy_cache:
2285
+ next_cache = next_cache.to_legacy_cache()
2286
+
2287
+ if not return_dict:
2288
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2289
+ return BaseModelOutputWithPast(
2290
+ last_hidden_state=hidden_states,
2291
+ past_key_values=next_cache,
2292
+ hidden_states=all_hidden_states,
2293
+ attentions=all_self_attns,
2294
  )
 
2295
 
2296
+ def _update_causal_mask(
2297
+ self,
2298
+ attention_mask: torch.Tensor,
2299
+ input_tensor: torch.Tensor,
2300
+ cache_position: torch.Tensor,
2301
+ past_key_values: Cache,
2302
+ use_cache: bool,
2303
+ output_attentions: bool,
2304
+ ):
2305
 
2306
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
2307
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
2308
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
2309
 
2310
+ if self._attn_implementation == "flash_attention_2":
2311
+ if attention_mask is not None and use_cache:
2312
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
2313
+ if is_padding_right:
2314
+ raise ValueError(
2315
+ "You are attempting to perform batched generation with padding_side='right'"
2316
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
2317
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
2318
+ )
2319
+ if attention_mask is not None and 0.0 in attention_mask:
2320
+ return attention_mask
2321
+ return None
2322
 
2323
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
2324
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
2325
+ # to infer the attention mask.
2326
 
2327
+ # cache_position must be valid here no matter which cache we use
2328
+ past_seen_tokens = cache_position[0] if past_key_values is not None else 0
2329
+ using_static_cache = isinstance(past_key_values, StaticCache)
2330
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
2331
 
2332
+ if (
2333
+ self.config._attn_implementation == "sdpa"
2334
+ and not (using_static_cache or using_sliding_window_cache)
2335
+ and not output_attentions
2336
+ ):
2337
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
2338
+ attention_mask,
2339
+ inputs_embeds=input_tensor,
2340
+ past_key_values_length=past_seen_tokens,
2341
+ sliding_window=self.config.sliding_window,
2342
+ is_training=self.training,
2343
+ ):
2344
+ return None
2345
 
2346
+ dtype, device = input_tensor.dtype, input_tensor.device
2347
+ min_dtype = torch.finfo(dtype).min
2348
+ sequence_length = input_tensor.shape[1]
2349
+ # SlidingWindowCache
2350
+ if using_sliding_window_cache:
2351
+ target_length = max(sequence_length, self.config.sliding_window)
2352
+ # StaticCache
2353
+ elif using_static_cache:
2354
+ target_length = past_key_values.get_max_length()
2355
+ # DynamicCache or no cache
2356
+ else:
2357
+ target_length = (
2358
+ attention_mask.shape[-1]
2359
+ if isinstance(attention_mask, torch.Tensor)
2360
+ else past_seen_tokens + sequence_length + 1
2361
+ )
2362
 
2363
+ if attention_mask is not None and attention_mask.dim() == 4:
2364
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
2365
+ if attention_mask.max() != 0:
2366
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
2367
+ causal_mask = attention_mask
2368
+ else:
2369
+ causal_mask = torch.full(
2370
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
2371
+ )
2372
+ exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
2373
+ if self.config.sliding_window is not None:
2374
+ if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
2375
+ exclude_mask.bitwise_or_(
2376
+ torch.arange(target_length, device=device)
2377
+ <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
2378
+ )
2379
+ causal_mask *= exclude_mask
2380
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
2381
+ if attention_mask is not None:
2382
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2383
+ if attention_mask.dim() == 2:
2384
+ mask_length = attention_mask.shape[-1]
2385
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
2386
+ padding_mask = padding_mask == 0
2387
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2388
+ padding_mask, min_dtype
2389
+ )
2390
+
2391
+ if (
2392
+ self.config._attn_implementation == "sdpa"
2393
+ and attention_mask is not None
2394
+ and attention_mask.device.type == "cuda"
2395
+ and not output_attentions
2396
+ ):
2397
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2398
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2399
+ # Details: https://github.com/pytorch/pytorch/issues/110213
2400
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
2401
 
2402
+ return causal_mask
2403
 
2404
  ############# Causal LM #################
2405
  class MistralForCausalLM(MistralPreTrainedModel):
 
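The _update_causal_mask re-added just above keeps a key position visible only when it is not in the future and not more than sliding_window tokens behind the query; everything else is filled with the dtype's minimum so it vanishes in the softmax. A reduced sketch of the same construction for a fresh prompt with no cache (hypothetical helper name):

import torch

def sliding_window_causal_mask(seq_len, window, dtype=torch.float32):
    # 0.0 where a query may attend, torch.finfo(dtype).min where it may not.
    min_dtype = torch.finfo(dtype).min
    cache_position = torch.arange(seq_len)                     # query positions, no past tokens
    mask = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
    exclude = torch.arange(seq_len) > cache_position.reshape(-1, 1)               # future keys
    exclude |= torch.arange(seq_len) <= (cache_position.reshape(-1, 1) - window)  # beyond the window
    return mask * exclude      # excluded slots keep min_dtype, allowed slots become 0.0

print(sliding_window_causal_mask(5, window=3))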
3421
  else:
3422
  cur_talk_loss = talk_loss_list[talk_idx]
3423
  log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
3424
 
3425
+
3426
+ self.n_ahead_talk = n_ahead_talk_to_restore
3427
+ self.n_passes = n_passes_to_restore
3428
  return CausalLMOutputWithPast(
3429
  loss=loss if loss is not None else None,
3430
  logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
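For completeness, a repository that ships a custom modeling_mistral.py like this one is normally loaded with trust_remote_code=True so the uploaded file is executed instead of the built-in Mistral classes. A usage sketch; the repository id below is a placeholder, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/your-mistral-variant"   # placeholder: substitute the actual model repo

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # executes the repo's modeling_mistral.py
    torch_dtype="auto",
)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))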