Update modeling_phi.py
Browse files- modeling_phi.py +3 -7
modeling_phi.py
CHANGED
@@ -294,15 +294,11 @@ class MoE(nn.Module):
|
|
294 |
def __init__(
|
295 |
self,
|
296 |
config: PretrainedConfig,
|
297 |
-
num_experts=2,
|
298 |
-
num_experts_per_tok=2,
|
299 |
-
num_shards=1,
|
300 |
-
**kwargs,
|
301 |
):
|
302 |
super().__init__()
|
303 |
-
self.mlp = nn.ModuleList([MLP(config) for i in range(
|
304 |
-
self.gate = nn.Linear(config.n_embd,
|
305 |
-
self.num_experts_per_tok = num_experts_per_tok
|
306 |
|
307 |
def forward(self, x):
|
308 |
orig_shape = x.shape
|
|
|
294 |
def __init__(
|
295 |
self,
|
296 |
config: PretrainedConfig,
|
|
|
|
|
|
|
|
|
297 |
):
|
298 |
super().__init__()
|
299 |
+
self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
|
300 |
+
self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
|
301 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
302 |
|
303 |
def forward(self, x):
|
304 |
orig_shape = x.shape
|