mwirth-epo
committed on
Commit
•
8993bb8
1
Parent(s):
f80aaa3
update positional_embedding.py
Browse files
Relates to bf16 for query_states and key_states issue.
Apply fix in https://huggingface.co/microsoft/Phi-3-small-8k-instruct/commit/f196467b67c13127747a03c142e09aa6841447b8 also for this model
- positional_embedding.py +3 -3
positional_embedding.py
CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
|
|
269 |
return (
|
270 |
apply_rotary_pos_emb(
|
271 |
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
272 |
-
),
|
273 |
apply_rotary_pos_emb(
|
274 |
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
275 |
-
),
|
276 |
)
|
277 |
|
278 |
@classmethod
|
@@ -285,4 +285,4 @@ class RotaryEmbedding(torch.nn.Module):
|
|
285 |
)
|
286 |
if config.rope_scaling is not None:
|
287 |
kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
|
288 |
-
return cls(**kwargs)
|
|
|
269 |
return (
|
270 |
apply_rotary_pos_emb(
|
271 |
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
272 |
+
).to(q.dtype),
|
273 |
apply_rotary_pos_emb(
|
274 |
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
275 |
+
).to(k.dtype),
|
276 |
)
|
277 |
|
278 |
@classmethod
|
|
|
285 |
)
|
286 |
if config.rope_scaling is not None:
|
287 |
kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
|
288 |
+
return cls(**kwargs)
|