fix-glu-mlp
#17
opened by michael-guenther
- mlp.py +8 -2
- modeling_bert.py +2 -1
mlp.py CHANGED
@@ -33,6 +33,7 @@ class GLUMLP(nn.Module):
         in_features,
         hidden_features,
         activation,
+        use_flash_attn,
         return_residual=False,
         hidden_dropout_prob=0.1
     ):
@@ -52,14 +53,19 @@ class GLUMLP(nn.Module):
         self.wo = nn.Linear(hidden_features, in_features)
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.return_residual = return_residual
+        self.use_flash_attn = use_flash_attn
         #self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual_connection = hidden_states
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
-        gated = hidden_states[:, : self.hidden_features]
-        non_gated = hidden_states[:, self.hidden_features :]
+        if self.use_flash_attn:
+            gated = hidden_states[:, : self.hidden_features]
+            non_gated = hidden_states[:, self.hidden_features :]
+        else:
+            gated = hidden_states[:, :, : self.hidden_features]
+            non_gated = hidden_states[:, :, self.hidden_features :]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
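For readers of this PR, below is a minimal runnable sketch of GLUMLP as it looks after the change. It is not copied from the repository: the fused gated_layers projection to 2 * hidden_features, the activation lookup, and the final wo projection are assumptions inferred from the surrounding context, and the two removed lines (not fully visible in the rendered diff) appear to have been the old unconditional 2-D slicing. The point of the fix is shape handling: with flash attention the MLP receives packed, unpadded activations of shape (total_tokens, 2 * hidden_features), while without it the activations keep the padded (batch, seq_len, 2 * hidden_features) layout, so the old 2-D slicing would presumably have split the sequence dimension rather than the feature dimension.

import torch
import torch.nn as nn


class GLUMLP(nn.Module):
    # Sketch only: gated_layers, act, and wo are assumptions; the diff above
    # only shows the __init__ arguments and the forward() slicing that changed.
    def __init__(
        self,
        in_features,
        hidden_features,
        activation,
        use_flash_attn,
        return_residual=False,
        hidden_dropout_prob=0.1,
    ):
        super().__init__()
        self.hidden_features = hidden_features
        # one fused projection produces both the gate half and the value half
        self.gated_layers = nn.Linear(in_features, 2 * hidden_features, bias=False)
        self.act = nn.GELU() if activation == "gelu" else nn.ReLU()
        self.wo = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.return_residual = return_residual
        self.use_flash_attn = use_flash_attn

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual_connection = hidden_states
        hidden_states = self.gated_layers(hidden_states)
        if self.use_flash_attn:
            # packed/unpadded tokens: (total_tokens, 2 * hidden_features)
            gated = hidden_states[:, : self.hidden_features]
            non_gated = hidden_states[:, self.hidden_features :]
        else:
            # padded batch: (batch, seq_len, 2 * hidden_features)
            gated = hidden_states[:, :, : self.hidden_features]
            non_gated = hidden_states[:, :, self.hidden_features :]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        if self.return_residual:
            return hidden_states, residual_connection
        return hidden_states


# quick shape check for both paths
mlp_flash = GLUMLP(768, 3072, "gelu", use_flash_attn=True)
mlp_plain = GLUMLP(768, 3072, "gelu", use_flash_attn=False)
print(mlp_flash(torch.randn(10, 768)).shape)    # torch.Size([10, 768])
print(mlp_plain(torch.randn(2, 5, 768)).shape)  # torch.Size([2, 5, 768])

An alternative worth noting: splitting on the last dimension, e.g. gated, non_gated = hidden_states.chunk(2, dim=-1), covers both layouts with a single code path; the explicit branch above keeps the flash-attention path unchanged from the previous behavior.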
modeling_bert.py CHANGED
@@ -114,6 +114,7 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
             GLUMLP,
             hidden_features=inner_dim,
             activation=config.hidden_act,
+            use_flash_attn=config.use_flash_attn,
             hidden_dropout_prob=config.hidden_dropout_prob,
             return_residual=return_residual,
         )
@@ -802,4 +803,4 @@ class BertForMaskedLM(BertPreTrainedModel):
             loss=masked_lm_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
-        )
+        )
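On the modeling side the change just threads config.use_flash_attn through the MLP factory. A hedged sketch of that wiring, assuming create_mlp_cls builds the layer with functools.partial and that inner_dim is derived from config.intermediate_size (neither detail is shown in this diff):

from functools import partial

from mlp import GLUMLP  # assumption: GLUMLP lives in mlp.py next to this file


def create_mlp_cls(config, layer_idx=None, return_residual=False):
    # Sketch: only the keyword arguments below are confirmed by the diff;
    # the partial() factory and inner_dim are assumptions for illustration.
    inner_dim = config.intermediate_size
    return partial(
        GLUMLP,
        hidden_features=inner_dim,
        activation=config.hidden_act,
        use_flash_attn=config.use_flash_attn,  # new: forwarded from the config
        hidden_dropout_prob=config.hidden_dropout_prob,
        return_residual=return_residual,
    )

The returned factory is later called with in_features to build each layer's MLP, so a model configured without flash attention now gets the 3-D slicing path in GLUMLP.forward instead of the packed 2-D one.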