Update modeling_Llamoe.py
Browse files- modeling_Llamoe.py +2 -1
modeling_Llamoe.py
CHANGED
@@ -53,7 +53,8 @@ if is_torch_fx_available():
|
|
53 |
|
54 |
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
55 |
|
56 |
-
|
|
|
57 |
|
58 |
def load_balancing_loss_func(
|
59 |
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
|
|
|
53 |
|
54 |
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
55 |
|
56 |
+
def approx_gelu(x):
    """Apply the tanh-based GELU approximation element-wise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: Value handed to ``torch.tanh`` after the cubic correction term is
            added; presumably a floating-point ``torch.Tensor`` (confirm
            against callers).

    Returns:
        The result of the expression above, computed with ``torch`` ops.
    """
    # The scalar factor sqrt(2 / pi) is computed with the stdlib ``math``
    # module, so only the tensor-dependent part goes through torch.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
|
58 |
|
59 |
def load_balancing_loss_func(
|
60 |
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
|