Crystalcareai
commited on
Commit
•
68b3eda
1
Parent(s):
6dc0ddc
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +1 -1
modeling_gemmoe.py
CHANGED
@@ -711,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
711 |
for i in range(self.num_experts):
|
712 |
expert = self.experts[i]
|
713 |
expert_output = expert(hidden_states[flat_topk_idx == i])
|
714 |
-
y[flat_topk_idx == i] = expert_output
|
715 |
|
716 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
717 |
|
|
|
711 |
for i in range(self.num_experts):
|
712 |
expert = self.experts[i]
|
713 |
expert_output = expert(hidden_states[flat_topk_idx == i])
|
714 |
+
y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
|
715 |
|
716 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
717 |
|