qfournier commited on
Commit
3ca1265
1 Parent(s): b38fe51

Add support for CPU

Browse files
Files changed (1) hide show
  1. amplify.py +68 -27
amplify.py CHANGED
@@ -4,7 +4,7 @@
4
 
5
  import torch
6
  from torch import nn
7
-
8
  from xformers.ops import SwiGLU, memory_efficient_attention
9
 
10
  from .rmsnorm import RMSNorm
@@ -13,6 +13,7 @@ from .rotary import precompute_freqs_cis, apply_rotary_emb
13
  from transformers import PreTrainedModel, PretrainedConfig
14
  from transformers.modeling_outputs import MaskedLMOutput
15
 
 
16
  class DotDict(dict):
17
  """Dictionary that supports the dot notation to access attributes (similarly to HuggingFace)."""
18
 
@@ -20,8 +21,10 @@ class DotDict(dict):
20
  __setattr__ = dict.__setitem__
21
  __delattr__ = dict.__delitem__
22
 
 
23
  class AMPLIFYConfig(PretrainedConfig):
24
  model_type = "AMPLIFY"
 
25
  # All config parameters must have a default value.
26
  def __init__(
27
  self,
@@ -45,7 +48,7 @@ class AMPLIFYConfig(PretrainedConfig):
45
  **kwargs,
46
  ):
47
  super().__init__(**kwargs)
48
-
49
  self.hidden_size = hidden_size
50
  self.num_hidden_layers = num_hidden_layers
51
  self.num_attention_heads = num_attention_heads
@@ -63,7 +66,7 @@ class AMPLIFYConfig(PretrainedConfig):
63
  self.att_bias = att_bias
64
  self.pad_token_id = pad_token_id
65
  self.max_length = max_length
66
-
67
 
68
  class EncoderBlock(nn.Module):
69
  """Transformer encoder block."""
@@ -119,8 +122,16 @@ class EncoderBlock(nn.Module):
119
  nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
120
  )
121
 
122
- self.attention_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
123
- self.ffn_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
 
 
 
 
 
 
 
124
 
125
  self.ffn_dropout = nn.Dropout(config.dropout_prob)
126
 
@@ -130,7 +141,9 @@ class EncoderBlock(nn.Module):
130
  x = x + self._ff_block(self.ffn_norm(x))
131
  return x, contact
132
 
133
- def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
 
 
134
  batch_size, seq_len, _ = x.shape
135
  xq, xk, xv = self.q(x), self.k(x), self.v(x)
136
 
@@ -140,22 +153,37 @@ class EncoderBlock(nn.Module):
140
  xv = xv.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
141
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
142
 
143
- attn = memory_efficient_attention(
144
- query=xq,
145
- key=xk,
146
- value=xv,
147
- attn_bias=attention_mask,
148
- p=self.config.dropout_prob if self.training else 0,
149
- )
150
-
151
- _attn = None
152
  if output_attentions:
153
- _attn = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
154
  if attention_mask is not None:
155
- _attn = _attn + attention_mask
156
- _attn = _attn.softmax(-1)
157
-
158
- return self.resid_dropout(self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))), _attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  def _ff_block(self, x: torch.Tensor):
161
  return self.ffn_dropout(self.ffn(x))
@@ -176,9 +204,10 @@ class AMPLIFYPreTrainedModel(PreTrainedModel):
176
  class AMPLIFY(AMPLIFYPreTrainedModel):
177
  """The main model class.
178
 
179
- Args:
180
- config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
181
  """
 
182
  def __init__(self, config: AMPLIFYConfig, **kwargs):
183
  super().__init__(config)
184
 
@@ -187,19 +216,27 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
187
  self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
188
 
189
  if config.layer_norm_after_embedding:
190
- self.layer_norm_1 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
 
 
 
191
 
192
  self.transformer_encoder = nn.ModuleList()
193
  for _ in range(config.num_hidden_layers):
194
  self.transformer_encoder.append(EncoderBlock(config))
195
 
196
  if config.layer_norm_before_last_layer:
197
- self.layer_norm_2 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
 
 
 
 
198
 
199
  self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
200
 
201
  self.freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
202
-
203
  # Initialize weights and apply final processing
204
  self.post_init()
205
 
@@ -209,7 +246,11 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
209
 
210
  # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
211
  if attention_mask is not None and not torch.all(attention_mask == 0):
212
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
 
 
 
 
213
  else:
214
  attention_mask = None
215
 
@@ -234,4 +275,4 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
234
  logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)
235
 
236
  # Return logits or the output of the last hidden layer
237
- return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
 
4
 
5
  import torch
6
  from torch import nn
7
+ from torch.nn.functional import scaled_dot_product_attention
8
  from xformers.ops import SwiGLU, memory_efficient_attention
9
 
10
  from .rmsnorm import RMSNorm
 
13
  from transformers import PreTrainedModel, PretrainedConfig
14
  from transformers.modeling_outputs import MaskedLMOutput
15
 
16
+
17
  class DotDict(dict):
18
  """Dictionary that supports the dot notation to access attributes (similarly to HuggingFace)."""
19
 
 
21
  __setattr__ = dict.__setitem__
22
  __delattr__ = dict.__delitem__
23
 
24
+
25
  class AMPLIFYConfig(PretrainedConfig):
26
  model_type = "AMPLIFY"
27
+
28
  # All config parameters must have a default value.
29
  def __init__(
30
  self,
 
48
  **kwargs,
49
  ):
50
  super().__init__(**kwargs)
51
+
52
  self.hidden_size = hidden_size
53
  self.num_hidden_layers = num_hidden_layers
54
  self.num_attention_heads = num_attention_heads
 
66
  self.att_bias = att_bias
67
  self.pad_token_id = pad_token_id
68
  self.max_length = max_length
69
+
70
 
71
  class EncoderBlock(nn.Module):
72
  """Transformer encoder block."""
 
122
  nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
123
  )
124
 
125
+ self.attention_norm = (
126
+ RMSNorm(config.hidden_size, config.norm_eps)
127
+ if config.rms_norm
128
+ else nn.LayerNorm(config.hidden_size, config.norm_eps)
129
+ )
130
+ self.ffn_norm = (
131
+ RMSNorm(config.hidden_size, config.norm_eps)
132
+ if config.rms_norm
133
+ else nn.LayerNorm(config.hidden_size, config.norm_eps)
134
+ )
135
 
136
  self.ffn_dropout = nn.Dropout(config.dropout_prob)
137
 
 
141
  x = x + self._ff_block(self.ffn_norm(x))
142
  return x, contact
143
 
144
+ def _att_block(
145
+ self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool
146
+ ):
147
  batch_size, seq_len, _ = x.shape
148
  xq, xk, xv = self.q(x), self.k(x), self.v(x)
149
 
 
153
  xv = xv.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
154
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
155
 
156
+ # Compute the attention weight
157
+ attn_weights = None
 
 
 
 
 
 
 
158
  if output_attentions:
159
+ attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
160
  if attention_mask is not None:
161
+ attn_weights = attn_weights + attention_mask
162
+ attn_weights = attn_weights.softmax(-1)
163
+
164
+ # Compute the attention using xformers if the tensors are on GPU
165
+ if x.is_cuda:
166
+ # Input and output are of dimension (B, M, H, K) where B is the batch size, M the sequence length,
167
+ # H the number of heads, and K the embeding size per head
168
+ attn = memory_efficient_attention(
169
+ query=xq,
170
+ key=xk,
171
+ value=xv,
172
+ attn_bias=attention_mask,
173
+ p=self.config.dropout_prob if self.training else 0,
174
+ )
175
+ else:
176
+ # Input and output are of dimension (B, H, M, K)
177
+ attn = scaled_dot_product_attention(
178
+ query=xq.transpose(1, 2),
179
+ key=xk.transpose(1, 2),
180
+ value=xv.transpose(1, 2),
181
+ attn_mask=attention_mask,
182
+ dropout_p=self.config.dropout_prob if self.training else 0,
183
+ ).transpose(1, 2)
184
+
185
+ attn_scores = self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
186
+ return (self.resid_dropout(attn_scores), attn_weights)
187
 
188
  def _ff_block(self, x: torch.Tensor):
189
  return self.ffn_dropout(self.ffn(x))
 
204
  class AMPLIFY(AMPLIFYPreTrainedModel):
205
  """The main model class.
206
 
207
+ Args:
208
+ config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
209
  """
210
+
211
  def __init__(self, config: AMPLIFYConfig, **kwargs):
212
  super().__init__(config)
213
 
 
216
  self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
217
 
218
  if config.layer_norm_after_embedding:
219
+ self.layer_norm_1 = (
220
+ RMSNorm(config.hidden_size, config.norm_eps)
221
+ if config.rms_norm
222
+ else nn.LayerNorm(config.hidden_size, config.norm_eps)
223
+ )
224
 
225
  self.transformer_encoder = nn.ModuleList()
226
  for _ in range(config.num_hidden_layers):
227
  self.transformer_encoder.append(EncoderBlock(config))
228
 
229
  if config.layer_norm_before_last_layer:
230
+ self.layer_norm_2 = (
231
+ RMSNorm(config.hidden_size, config.norm_eps)
232
+ if config.rms_norm
233
+ else nn.LayerNorm(config.hidden_size, config.norm_eps)
234
+ )
235
 
236
  self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
237
 
238
  self.freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
239
+
240
  # Initialize weights and apply final processing
241
  self.post_init()
242
 
 
246
 
247
  # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
248
  if attention_mask is not None and not torch.all(attention_mask == 0):
249
+ attention_mask = (
250
+ attention_mask.unsqueeze(1)
251
+ .unsqueeze(1)
252
+ .repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
253
+ )
254
  else:
255
  attention_mask = None
256
 
 
275
  logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)
276
 
277
  # Return logits or the output of the last hidden layer
278
+ return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)