Fix attention score on mps
Browse files- modeling_chatglm.py +2 -4
modeling_chatglm.py
CHANGED
@@ -280,10 +280,8 @@ def attention_fn(
|
|
280 |
# [sk, b, np, hn] -> [sk, b * np, hn]
|
281 |
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
282 |
|
283 |
-
matmul_result = torch.
|
284 |
-
|
285 |
-
output_size[2],
|
286 |
-
output_size[3],
|
287 |
dtype=query_layer.dtype,
|
288 |
device=query_layer.device,
|
289 |
)
|
|
|
280 |
# [sk, b, np, hn] -> [sk, b * np, hn]
|
281 |
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
282 |
|
283 |
+
matmul_result = torch.zeros(
|
284 |
+
1, 1, 1,
|
|
|
|
|
285 |
dtype=query_layer.dtype,
|
286 |
device=query_layer.device,
|
287 |
)
|