Update modeling_chatglm.py
modeling_chatglm.py  +2 -2
@@ -157,7 +157,7 @@ class RotaryEmbedding(nn.Module):
         )
 
 
-
+@torch.jit.script
 def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
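With this change, apply_rotary_pos_emb carries the @torch.jit.script decorator, so the rotary-position math is compiled to TorchScript instead of being dispatched op by op from Python. The snippet below is a minimal sketch of that pattern, not code from the model: the function name rotate_pairs, the cos/sin inputs, and all shapes are invented for illustration; only the decorator, the type annotations it requires, and the split-and-rotate structure mirror the scripted helper.

import torch

# A minimal sketch, not the model's code: @torch.jit.script compiles a free
# function ahead of time and relies on explicit type annotations like the
# ones apply_rotary_pos_emb already has.
@torch.jit.script
def rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and apply a 2D rotation to each pair,
    # the same kind of elementwise math a rotary-embedding helper performs.
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

# Scripted functions are called like ordinary Python functions.
x = torch.randn(8, 1, 2, 64)     # [sq, b, np, hn]; illustrative sizes
cos = torch.randn(8, 1, 1, 32)
sin = torch.randn(8, 1, 1, 32)
out = rotate_pairs(x, cos, sin)  # same shape as x

One practical consequence of the decorator is that the function is scripted when the module is imported, so TorchScript-incompatible edits to it fail at import time rather than during the first forward pass.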
@@ -236,7 +236,7 @@ class CoreAttention(torch.nn.Module):
         # Raw attention scores
 
         # [b, np, sq, sk]
-        output_size = (query_layer.size(
+        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
 
         # [sq, b, np, hn] -> [sq, b * np, hn]
         query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
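The second hunk's new line builds output_size directly from the query and key tensors: with query_layer and key_layer laid out as [sq, b, np, hn] and [sk, b, np, hn], the tuple (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) is exactly the [b, np, sq, sk] shape named in the comment above it, and the following view folds b and np into one batch dimension. The standalone sketch below, with made-up dimensions and none of the rest of CoreAttention, only traces that shape bookkeeping through a batched matmul.

import torch

# Illustrative sizes only; in the model these come from the sequence lengths,
# batch size, head count (np here is a plain int, not numpy) and head size.
sq, sk, b, np, hn = 5, 7, 2, 4, 16

query_layer = torch.randn(sq, b, np, hn)  # [sq, b, np, hn]
key_layer = torch.randn(sk, b, np, hn)    # [sk, b, np, hn]

# [b, np, sq, sk], computed the same way as the changed line.
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

# [sq, b, np, hn] -> [sq, b * np, hn], and likewise for the keys.
q = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
k = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

# Batched matmul over the folded (b * np) dimension: [b * np, sq, sk].
scores = torch.bmm(q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2))

# Unfold back to the [b, np, sq, sk] layout the comment documents.
scores = scores.view(output_size[0], output_size[1], output_size[2], output_size[3])
assert scores.shape == (b, np, sq, sk)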