mlabonne commited on
Commit
ba0a899
1 Parent(s): db615a2

Update modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +3 -7
modeling_phi.py CHANGED
@@ -294,15 +294,11 @@ class MoE(nn.Module):
294
  def __init__(
295
  self,
296
  config: PretrainedConfig,
297
- num_experts=2,
298
- num_experts_per_tok=2,
299
- num_shards=1,
300
- **kwargs,
301
  ):
302
  super().__init__()
303
- self.mlp = nn.ModuleList([MLP(config) for i in range(num_experts)])
304
- self.gate = nn.Linear(config.n_embd, num_experts, bias=False)
305
- self.num_experts_per_tok = num_experts_per_tok
306
 
307
  def forward(self, x):
308
  orig_shape = x.shape
 
294
  def __init__(
295
  self,
296
  config: PretrainedConfig,
 
 
 
 
297
  ):
298
  super().__init__()
299
+ self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
300
+ self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
301
+ self.num_experts_per_tok = config.num_experts_per_tok
302
 
303
  def forward(self, x):
304
  orig_shape = x.shape