mwirth-epo committed
Commit: 8993bb8
Parent: f80aaa3

update positional_embedding.py


Relates to the bf16 issue for query_states and key_states.
Applies the fix from https://huggingface.co/microsoft/Phi-3-small-8k-instruct/commit/f196467b67c13127747a03c142e09aa6841447b8 to this model as well.

Files changed (1)
  positional_embedding.py  +3 -3
positional_embedding.py CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
         return (
             apply_rotary_pos_emb(
                 q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
             apply_rotary_pos_emb(
                 k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(k.dtype),
         )
 
     @classmethod
@@ -285,4 +285,4 @@ class RotaryEmbedding(torch.nn.Module):
         )
         if config.rope_scaling is not None:
             kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
-        return cls(**kwargs)
+        return cls(**kwargs)
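
For context on why the cast in the diff is needed: the cos/sin caches are typically kept in float32, so PyTorch's type promotion silently upcasts bf16 query/key tensors to float32 when the rotary embedding is applied, and downstream attention code that expects query_states and key_states in the original dtype can then fail. The snippet below is a minimal, self-contained sketch of that promotion and of the .to(q.dtype) cast; apply_rotary_pos_emb_sketch is a toy stand-in for illustration, not this repository's apply_rotary_pos_emb.

import torch

def apply_rotary_pos_emb_sketch(x, cos, sin):
    # Toy stand-in for a rotary-embedding helper (hypothetical, illustration only):
    # mixing a bf16 tensor with fp32 cos/sin promotes the result to fp32.
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

# bf16 query tensor with fp32 cos/sin caches: the setup that triggers the issue.
q = torch.randn(1, 8, 4, 64, dtype=torch.bfloat16)
cos = torch.randn(4, 64, dtype=torch.float32)
sin = torch.randn(4, 64, dtype=torch.float32)

out = apply_rotary_pos_emb_sketch(q, cos, sin)
print(out.dtype)              # torch.float32 -- dtype silently promoted
print(out.to(q.dtype).dtype)  # torch.bfloat16 -- the cast restores the input dtype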