ybelkada committed on
Commit 15e9135
1 Parent(s): 26072b4

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py (+2 -2)
modeling_chatglm.py CHANGED
@@ -157,7 +157,7 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-# @torch.jit.script
+@torch.jit.script
 def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
@@ -236,7 +236,7 @@ class CoreAttention(torch.nn.Module):
         # Raw attention scores
 
         # [b, np, sq, sk]
-        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(0))
+        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
 
         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
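Note on this change: the first hunk uncomments @torch.jit.script, so apply_rotary_pos_emb is compiled with TorchScript. In the second hunk, with query_layer laid out as [sq, b, np, hn] (per the in-code comment), the corrected indexing makes output_size equal to (b, np, sq, sk), matching the "# [b, np, sq, sk]" comment above it, whereas the old indexing produced (sq, np, b, sk). The following is a minimal sketch of that shape bookkeeping, with made-up dimension values used purely for illustration:

    import torch

    # hypothetical sizes, chosen only to illustrate the layout
    sq, sk, b, np_, hn = 5, 7, 2, 4, 16

    query_layer = torch.randn(sq, b, np_, hn)  # [sq, b, np, hn]
    key_layer = torch.randn(sk, b, np_, hn)    # [sk, b, np, hn]

    # the fixed line: picks out (b, np, sq, sk)
    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
    assert output_size == (b, np_, sq, sk)

    # [sq, b, np, hn] -> [sq, b * np, hn], as in the surrounding code
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    assert query_layer.shape == (sq, b * np_, hn)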