stallone committed on
Commit
3872a50
1 Parent(s): 04ad1ab

Merged mosaicml/mpt-7b@79ec93 into main

Browse files
Files changed (5) hide show
  1. attention.py +49 -34
  2. blocks.py +4 -4
  3. configuration_mpt.py +1 -1
  4. custom_embedding.py +11 -0
  5. modeling_mpt.py +37 -9
attention.py CHANGED
@@ -17,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
17
  return False
18
  return original_is_causal
19
 
20
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
23
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
24
- min_val = torch.finfo(q.dtype).min
 
 
 
 
 
25
  (b, _, s_q, d) = q.shape
26
  s_k = k.size(-1)
27
  if softmax_scale is None:
28
  softmax_scale = 1 / math.sqrt(d)
29
  attn_weight = q.matmul(k) * softmax_scale
30
  if attn_bias is not None:
 
 
 
31
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
32
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
33
  attn_weight = attn_weight + attn_bias
 
34
  if key_padding_mask is not None:
35
  if attn_bias is not None:
36
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
37
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
38
- if is_causal:
39
  s = max(s_q, s_k)
40
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
41
  causal_mask = causal_mask.tril()
@@ -46,11 +55,11 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
46
  attn_weight = torch.softmax(attn_weight, dim=-1)
47
  if dropout_p:
48
  attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
49
- out = attn_weight.matmul(v)
50
  out = rearrange(out, 'b h s d -> b s (h d)')
51
  if needs_weights:
52
- return (out, attn_weight)
53
- return (out, None)
54
 
55
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
56
  for tensor in tensors:
@@ -59,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
59
  if not tensor.is_cuda:
60
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
61
 
62
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
63
  try:
64
  from flash_attn import bert_padding, flash_attn_interface
65
  except:
66
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
67
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
68
  if attn_bias is not None:
69
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
70
  (batch_size, seqlen) = query.shape[:2]
@@ -84,9 +102,9 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
84
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
85
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
86
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
87
- return (output, None)
88
 
89
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
90
  try:
91
  from .flash_attn_triton import flash_attn_func
92
  except:
@@ -100,6 +118,15 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
100
  if not _installed:
101
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
105
  if needs_weights:
@@ -119,7 +146,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
- return (output, None)
123
 
124
  class MultiheadAttention(nn.Module):
125
  """Multi-head self attention.
@@ -128,7 +155,7 @@ class MultiheadAttention(nn.Module):
128
  additive bias.
129
  """
130
 
131
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
132
  super().__init__()
133
  self.attn_impl = attn_impl
134
  self.clip_qkv = clip_qkv
@@ -150,10 +177,11 @@ class MultiheadAttention(nn.Module):
150
  self.attn_fn = flash_attn_fn
151
  elif self.attn_impl == 'triton':
152
  self.attn_fn = triton_flash_attn_fn
153
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
154
  elif self.attn_impl == 'torch':
155
  self.attn_fn = scaled_multihead_dot_product_attention
156
- if torch.cuda.is_available():
157
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
158
  else:
159
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -170,14 +198,7 @@ class MultiheadAttention(nn.Module):
170
  dtype = query.dtype
171
  query = self.q_ln(query).to(dtype)
172
  key = self.k_ln(key).to(dtype)
173
- if past_key_value is not None:
174
- if len(past_key_value) != 0:
175
- key = torch.cat([past_key_value[0], key], dim=1)
176
- value = torch.cat([past_key_value[1], value], dim=1)
177
- past_key_value = (key, value)
178
- if attn_bias is not None:
179
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
180
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
181
  return (self.out_proj(context), attn_weights, past_key_value)
182
 
183
  class MultiQueryAttention(nn.Module):
@@ -187,7 +208,7 @@ class MultiQueryAttention(nn.Module):
187
  additive bias.
188
  """
189
 
190
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
191
  super().__init__()
192
  self.attn_impl = attn_impl
193
  self.clip_qkv = clip_qkv
@@ -210,10 +231,11 @@ class MultiQueryAttention(nn.Module):
210
  self.attn_fn = flash_attn_fn
211
  elif self.attn_impl == 'triton':
212
  self.attn_fn = triton_flash_attn_fn
213
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
214
  elif self.attn_impl == 'torch':
215
  self.attn_fn = scaled_multihead_dot_product_attention
216
- if torch.cuda.is_available():
217
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
218
  else:
219
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -230,14 +252,7 @@ class MultiQueryAttention(nn.Module):
230
  dtype = query.dtype
231
  query = self.q_ln(query).to(dtype)
232
  key = self.k_ln(key).to(dtype)
233
- if past_key_value is not None:
234
- if len(past_key_value) != 0:
235
- key = torch.cat([past_key_value[0], key], dim=1)
236
- value = torch.cat([past_key_value[1], value], dim=1)
237
- past_key_value = (key, value)
238
- if attn_bias is not None:
239
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
240
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
241
  return (self.out_proj(context), attn_weights, past_key_value)
242
 
243
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
 
17
  return False
18
  return original_is_causal
19
 
20
+ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
+ kv_n_heads = 1 if multiquery else n_heads
23
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
24
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
25
+ if past_key_value is not None:
26
+ if len(past_key_value) != 0:
27
+ k = torch.cat([past_key_value[0], k], dim=3)
28
+ v = torch.cat([past_key_value[1], v], dim=2)
29
+ past_key_value = (k, v)
30
  (b, _, s_q, d) = q.shape
31
  s_k = k.size(-1)
32
  if softmax_scale is None:
33
  softmax_scale = 1 / math.sqrt(d)
34
  attn_weight = q.matmul(k) * softmax_scale
35
  if attn_bias is not None:
36
+ _s_q = max(0, attn_bias.size(2) - s_q)
37
+ _s_k = max(0, attn_bias.size(3) - s_k)
38
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
39
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
40
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
41
  attn_weight = attn_weight + attn_bias
42
+ min_val = torch.finfo(q.dtype).min
43
  if key_padding_mask is not None:
44
  if attn_bias is not None:
45
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
46
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
47
+ if is_causal and (not q.size(2) == 1):
48
  s = max(s_q, s_k)
49
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
50
  causal_mask = causal_mask.tril()
 
55
  attn_weight = torch.softmax(attn_weight, dim=-1)
56
  if dropout_p:
57
  attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
58
+ out = attn_weight.to(v.dtype).matmul(v)
59
  out = rearrange(out, 'b h s d -> b s (h d)')
60
  if needs_weights:
61
+ return (out, attn_weight, past_key_value)
62
+ return (out, None, past_key_value)
63
 
64
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
65
  for tensor in tensors:
 
68
  if not tensor.is_cuda:
69
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
70
 
71
+ def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
72
  try:
73
  from flash_attn import bert_padding, flash_attn_interface
74
  except:
75
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
76
  check_valid_inputs(query, key, value)
77
+ if past_key_value is not None:
78
+ if len(past_key_value) != 0:
79
+ key = torch.cat([past_key_value[0], key], dim=1)
80
+ value = torch.cat([past_key_value[1], value], dim=1)
81
+ past_key_value = (key, value)
82
+ if attn_bias is not None:
83
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
84
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
85
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
86
  if attn_bias is not None:
87
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
88
  (batch_size, seqlen) = query.shape[:2]
 
102
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
103
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
104
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
105
+ return (output, None, past_key_value)
106
 
107
+ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
108
  try:
109
  from .flash_attn_triton import flash_attn_func
110
  except:
 
118
  if not _installed:
119
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
120
  check_valid_inputs(query, key, value)
121
+ if past_key_value is not None:
122
+ if len(past_key_value) != 0:
123
+ key = torch.cat([past_key_value[0], key], dim=1)
124
+ value = torch.cat([past_key_value[1], value], dim=1)
125
+ past_key_value = (key, value)
126
+ if attn_bias is not None:
127
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
128
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
129
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
  if dropout_p:
131
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
132
  if needs_weights:
 
146
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
147
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
148
  output = attn_output.view(*attn_output.shape[:2], -1)
149
+ return (output, None, past_key_value)
150
 
151
  class MultiheadAttention(nn.Module):
152
  """Multi-head self attention.
 
155
  additive bias.
156
  """
157
 
158
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
159
  super().__init__()
160
  self.attn_impl = attn_impl
161
  self.clip_qkv = clip_qkv
 
177
  self.attn_fn = flash_attn_fn
178
  elif self.attn_impl == 'triton':
179
  self.attn_fn = triton_flash_attn_fn
180
+ if verbose:
181
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
182
  elif self.attn_impl == 'torch':
183
  self.attn_fn = scaled_multihead_dot_product_attention
184
+ if torch.cuda.is_available() and verbose:
185
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
186
  else:
187
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
198
  dtype = query.dtype
199
  query = self.q_ln(query).to(dtype)
200
  key = self.k_ln(key).to(dtype)
201
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
202
  return (self.out_proj(context), attn_weights, past_key_value)
203
 
204
  class MultiQueryAttention(nn.Module):
 
208
  additive bias.
209
  """
210
 
211
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
212
  super().__init__()
213
  self.attn_impl = attn_impl
214
  self.clip_qkv = clip_qkv
 
231
  self.attn_fn = flash_attn_fn
232
  elif self.attn_impl == 'triton':
233
  self.attn_fn = triton_flash_attn_fn
234
+ if verbose:
235
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
236
  elif self.attn_impl == 'torch':
237
  self.attn_fn = scaled_multihead_dot_product_attention
238
+ if torch.cuda.is_available() and verbose:
239
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
240
  else:
241
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
252
  dtype = query.dtype
253
  query = self.q_ln(query).to(dtype)
254
  key = self.k_ln(key).to(dtype)
255
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
 
 
 
 
 
 
 
256
  return (self.out_proj(context), attn_weights, past_key_value)
257
 
258
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
blocks.py CHANGED
@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):
19
 
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
19
 
20
  class MPTBlock(nn.Module):
21
 
22
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
 
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
+ return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED
@@ -2,7 +2,7 @@
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
 
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
custom_embedding.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+ class SharedEmbedding(nn.Embedding):
7
+
8
+ def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9
+ if unembed:
10
+ return F.linear(input, self.weight)
11
+ return super().forward(input)
modeling_mpt.py CHANGED
@@ -12,17 +12,23 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
 
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
 
26
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
@@ -34,18 +40,24 @@ class MPTModel(MPTPreTrainedModel):
34
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
35
  self.alibi = config.attn_config['alibi']
36
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
 
 
 
 
 
37
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
38
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
39
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
40
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
41
  self.embedding_fraction = config.embedding_fraction
42
- self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
43
  if not self.alibi:
44
- self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
45
  self.emb_drop = nn.Dropout(config.emb_pdrop)
46
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
47
  self.norm_f = norm_class(config.d_model, device=config.init_device)
48
  if config.init_device != 'meta':
 
49
  self.apply(self.param_init_fn)
50
  self.is_causal = not self.prefix_lm
51
  self._attn_bias_initialized = False
@@ -95,7 +107,8 @@ class MPTModel(MPTPreTrainedModel):
95
  if attn_bias is None:
96
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
97
  else:
98
- attn_bias = attn_bias[:, :, :, -s_k:]
 
99
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
100
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
101
  min_val = torch.finfo(attn_bias.dtype).min
@@ -137,7 +150,8 @@ class MPTModel(MPTPreTrainedModel):
137
  if not return_dict:
138
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
 
141
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
142
  raise NotImplementedError('MPT does not support training with left padding.')
143
  if self.prefix_lm and prefix_mask is None:
@@ -158,6 +172,8 @@ class MPTModel(MPTPreTrainedModel):
158
  if len(past_key_values) != self.config.n_layers:
159
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
160
  past_position = past_key_values[0][0].size(1)
 
 
161
  if S + past_position > self.config.max_seq_len:
162
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
163
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -171,20 +187,27 @@ class MPTModel(MPTPreTrainedModel):
171
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
172
  assert isinstance(self.emb_drop, nn.Module)
173
  x = self.emb_drop(x_shrunk)
174
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
175
  if use_cache and past_key_values is None:
176
  past_key_values = [() for _ in range(self.config.n_layers)]
177
  all_hidden_states = () if output_hidden_states else None
 
178
  for (b_idx, block) in enumerate(self.blocks):
179
  if output_hidden_states:
180
  assert all_hidden_states is not None
181
  all_hidden_states = all_hidden_states + (x,)
182
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
183
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
 
 
 
186
  x = self.norm_f(x)
187
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
188
 
189
  def param_init_fn(self, module):
190
  init_fn_name = self.config.init_config['name']
@@ -203,6 +226,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
203
  if not config.tie_word_embeddings:
204
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
205
  self.transformer = MPTModel(config)
 
 
 
 
 
206
  self.logit_scale = None
207
  if config.logit_scale is not None:
208
  logit_scale = config.logit_scale
@@ -235,7 +263,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
235
  return_dict = return_dict if return_dict is not None else self.config.return_dict
236
  use_cache = use_cache if use_cache is not None else self.config.use_cache
237
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
238
- logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
239
  if self.logit_scale is not None:
240
  if self.logit_scale == 0:
241
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -245,7 +273,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
245
  labels = torch.roll(labels, shifts=-1)
246
  labels[:, -1] = -100
247
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
248
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
249
 
250
  def param_init_fn(self, module):
251
  init_fn_name = self.config.init_config['name']
 
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
+ from .custom_embedding import SharedEmbedding
16
  from .norm import NORM_CLASS_REGISTRY
17
  from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
19
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
20
  from .meta_init_context import init_empty_weights
21
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
22
+ try:
23
+ from .flash_attn_triton import flash_attn_func
24
+ except:
25
+ pass
26
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
27
 
28
  class MPTPreTrainedModel(PreTrainedModel):
29
  config_class = MPTConfig
30
  base_model_prefix = 'model'
31
+ _no_split_modules = ['MPTBlock']
32
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
 
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
41
  self.alibi = config.attn_config['alibi']
42
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
43
+ if config.init_device == 'mixed':
44
+ if dist.get_local_rank() == 0:
45
+ config.init_device = 'cpu'
46
+ else:
47
+ config.init_device = 'meta'
48
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
49
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
50
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
51
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
52
  self.embedding_fraction = config.embedding_fraction
53
+ self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
54
  if not self.alibi:
55
+ self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
56
  self.emb_drop = nn.Dropout(config.emb_pdrop)
57
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
58
  self.norm_f = norm_class(config.d_model, device=config.init_device)
59
  if config.init_device != 'meta':
60
+ print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
61
  self.apply(self.param_init_fn)
62
  self.is_causal = not self.prefix_lm
63
  self._attn_bias_initialized = False
 
107
  if attn_bias is None:
108
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
109
  else:
110
+ _s_k = max(0, attn_bias.size(-1) - s_k)
111
+ attn_bias = attn_bias[:, :, :, _s_k:]
112
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
113
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
114
  min_val = torch.finfo(attn_bias.dtype).min
 
150
  if not return_dict:
151
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
152
  if output_attentions:
153
+ if self.attn_impl != 'torch':
154
+ raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
155
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
 
172
  if len(past_key_values) != self.config.n_layers:
173
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
174
  past_position = past_key_values[0][0].size(1)
175
+ if self.attn_impl == 'torch':
176
+ past_position = past_key_values[0][0].size(3)
177
  if S + past_position > self.config.max_seq_len:
178
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
179
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
187
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
188
  assert isinstance(self.emb_drop, nn.Module)
189
  x = self.emb_drop(x_shrunk)
190
+ (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
191
  if use_cache and past_key_values is None:
192
  past_key_values = [() for _ in range(self.config.n_layers)]
193
  all_hidden_states = () if output_hidden_states else None
194
+ all_self_attns = () if output_attentions else None
195
  for (b_idx, block) in enumerate(self.blocks):
196
  if output_hidden_states:
197
  assert all_hidden_states is not None
198
  all_hidden_states = all_hidden_states + (x,)
199
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
200
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
201
  if past_key_values is not None:
202
  past_key_values[b_idx] = past_key_value
203
+ if output_attentions:
204
+ assert all_self_attns is not None
205
+ all_self_attns = all_self_attns + (attn_weights,)
206
  x = self.norm_f(x)
207
+ if output_hidden_states:
208
+ assert all_hidden_states is not None
209
+ all_hidden_states = all_hidden_states + (x,)
210
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
211
 
212
  def param_init_fn(self, module):
213
  init_fn_name = self.config.init_config['name']
 
226
  if not config.tie_word_embeddings:
227
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
228
  self.transformer = MPTModel(config)
229
+ for child in self.transformer.children():
230
+ if isinstance(child, torch.nn.ModuleList):
231
+ continue
232
+ if isinstance(child, torch.nn.Module):
233
+ child._fsdp_wrap = True
234
  self.logit_scale = None
235
  if config.logit_scale is not None:
236
  logit_scale = config.logit_scale
 
263
  return_dict = return_dict if return_dict is not None else self.config.return_dict
264
  use_cache = use_cache if use_cache is not None else self.config.use_cache
265
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
266
+ logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
267
  if self.logit_scale is not None:
268
  if self.logit_scale == 0:
269
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
273
  labels = torch.roll(labels, shifts=-1)
274
  labels[:, -1] = -100
275
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
276
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
277
 
278
  def param_init_fn(self, module):
279
  init_fn_name = self.config.init_config['name']