Files changed (1)
  1. modeling_phi3_v.py +8 -1
modeling_phi3_v.py CHANGED
@@ -441,7 +441,7 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-        seq_len = torch.max(position_ids) + 1
+        seq_len = seq_len or torch.max(position_ids) + 1
         if seq_len > self.original_max_position_embeddings:
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
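Before this change, the `seq_len` argument was unconditionally overwritten, so a caller-supplied length was ignored. A minimal stand-alone sketch of the fallback idiom (`resolve_seq_len` is a hypothetical helper, not part of the model):

```python
import torch

# Hypothetical helper mirroring the fixed line: an explicit seq_len wins,
# otherwise the length is derived from the largest position id.
def resolve_seq_len(position_ids, seq_len=None):
    return seq_len or int(torch.max(position_ids)) + 1

position_ids = torch.tensor([[4096]])               # only the newest token's position
print(resolve_seq_len(position_ids))                # 4097, derived fallback
print(resolve_seq_len(position_ids, seq_len=8192))  # 8192, caller override respected
```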
@@ -1647,6 +1647,13 @@ class Phi3VForCausalLM(Phi3VPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
     ):
+        # The first time the input length crosses the long/short factor switching point, force the cache to be recomputed.
+        # This is slower at that single token position, but better than the current failure.
+        if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
+            past_length = past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
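For context on why the cache must be dropped at the switching point: the su-scaled rotary embedding selects its rescaling factors from the current total length, so keys cached while the length was at or below `original_max_position_embeddings` were rotated with the short factors and no longer match queries rotated with the long factors once the boundary is crossed. A minimal sketch under simplified assumptions (the factor values and dimensions here are made up; the real ones come from `config.rope_scaling`):

```python
import torch

# Simplified stand-in for the long/short factor switch; values are made up.
ORIGINAL_MAX = 4096          # plays the role of original_max_position_embeddings
DIM = 8
short_factor = torch.ones(DIM // 2)            # assumed: no rescaling below the boundary
long_factor = torch.full((DIM // 2,), 4.0)     # assumed: 4x rescaling above it

def inv_freq(seq_len):
    # Mirrors the `if seq_len > self.original_max_position_embeddings` branch:
    # the factor set depends on the *current* total sequence length.
    ext_factors = long_factor if seq_len > ORIGINAL_MAX else short_factor
    return 1.0 / (ext_factors * 10000.0 ** (torch.arange(0, DIM, 2).float() / DIM))

# Keys cached while the length was exactly at the boundary used short factors...
cached = inv_freq(ORIGINAL_MAX)
# ...but one token later the long factors apply, so the rotations disagree:
print(torch.allclose(cached, inv_freq(ORIGINAL_MAX + 1)))  # False -> stale cache
```

Recomputing the cache once at the boundary re-encodes every cached position with the long factors, which is why the guard only fires when `past_length` is still at or below the threshold.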