fix longrope scaling
Browse files- modeling_phi3_v.py +8 -1
modeling_phi3_v.py
CHANGED
@@ -441,7 +441,7 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
|
441 |
|
442 |
@torch.no_grad()
|
443 |
def forward(self, x, position_ids, seq_len=None):
|
444 |
-
seq_len = torch.max(position_ids) + 1
|
445 |
if seq_len > self.original_max_position_embeddings:
|
446 |
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
447 |
else:
|
@@ -1647,6 +1647,13 @@ class Phi3VForCausalLM(Phi3VPreTrainedModel):
|
|
1647 |
def prepare_inputs_for_generation(
|
1648 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
|
1649 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1650 |
if past_key_values is not None:
|
1651 |
if isinstance(past_key_values, Cache):
|
1652 |
cache_length = past_key_values.get_seq_length()
|
|
|
441 |
|
442 |
@torch.no_grad()
|
443 |
def forward(self, x, position_ids, seq_len=None):
|
444 |
+
seq_len = seq_len or torch.max(position_ids) + 1
|
445 |
if seq_len > self.original_max_position_embeddings:
|
446 |
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
447 |
else:
|
|
|
1647 |
def prepare_inputs_for_generation(
|
1648 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
|
1649 |
):
|
1650 |
+
# When the first time input length reached long and short factor switching point, enforce re-compute cache
|
1651 |
+
# It will cause downside of slower at this single token position, however, better than current failure.
|
1652 |
+
if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
|
1653 |
+
past_length = past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2]
|
1654 |
+
if past_length <= self.config.original_max_position_embeddings:
|
1655 |
+
past_key_values = None
|
1656 |
+
|
1657 |
if past_key_values is not None:
|
1658 |
if isinstance(past_key_values, Cache):
|
1659 |
cache_length = past_key_values.get_seq_length()
|