kaizen9 commited on
Commit
44f8bb2
1 Parent(s): b58f295

Update blocks.py

Browse files
Files changed (1) hide show
  1. blocks.py +21 -7
blocks.py CHANGED
@@ -5,12 +5,17 @@ import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
6
  from .ffn import FFN_CLASS_REGISTRY, build_ffn
7
  from .norm import NORM_CLASS_REGISTRY
 
 
 
 
 
8
 
9
  class MPTBlock(nn.Module):
10
 
11
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, **kwargs: Any):
12
  if attn_config is None:
13
- attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
14
  if ffn_config is None:
15
  ffn_config = {'ffn_type': 'mptmlp'}
16
  del kwargs
@@ -18,24 +23,33 @@ class MPTBlock(nn.Module):
18
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
19
  assert isinstance(attn_config['attn_type'], str)
20
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
21
- args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
22
  attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
23
  self.norm_1 = norm_class(d_model, device=device)
24
- self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class)
25
  self.norm_2 = None
26
  if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
27
  self.norm_2 = norm_class(d_model, device=device)
28
- self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, **ffn_config)
29
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
30
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
31
 
32
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
33
  a = self.norm_1(x)
34
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions)
35
  x = x + self.resid_attn_dropout(b)
36
  m = x
37
  if self.norm_2 is not None:
38
  m = self.norm_2(x)
 
 
 
 
 
39
  n = self.ffn(m)
 
 
 
40
  x = x + self.resid_ffn_dropout(n)
41
  return (x, attn_weights, past_key_value)
 
5
  from .attention import ATTN_CLASS_REGISTRY
6
  from .ffn import FFN_CLASS_REGISTRY, build_ffn
7
  from .norm import NORM_CLASS_REGISTRY
8
+ try:
9
+ from flash_attn.bert_padding import unpad_input, pad_input
10
+ except:
11
+ (unpad_input, pad_input) = (None, None)
12
+ attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'qk_gn': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, 'rope_theta': 10000, 'rope_impl': 'dail', 'rope_dail_config': {'type': 'original', 'pos_idx_in_fp32': True, 'xpos_scale_base': 512}, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}}
13
 
14
  class MPTBlock(nn.Module):
15
 
16
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, no_bias: bool=False, use_pad_tok_in_ffn: bool=True, **kwargs: Any):
17
  if attn_config is None:
18
+ attn_config = attn_config_defaults
19
  if ffn_config is None:
20
  ffn_config = {'ffn_type': 'mptmlp'}
21
  del kwargs
 
23
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
24
  assert isinstance(attn_config['attn_type'], str)
25
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
26
+ args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', 'rope_dail_config', 'rope_hf_config'}
27
  attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
28
  self.norm_1 = norm_class(d_model, device=device)
29
+ self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
30
  self.norm_2 = None
31
  if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
32
  self.norm_2 = norm_class(d_model, device=device)
33
+ self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, **ffn_config)
34
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
35
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
36
+ self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
37
 
38
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, rotary_emb_w_meta_info: Optional[Dict]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False, alibi_slopes: Optional[torch.Tensor]=None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
39
  a = self.norm_1(x)
40
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
41
  x = x + self.resid_attn_dropout(b)
42
  m = x
43
  if self.norm_2 is not None:
44
  m = self.norm_2(x)
45
+ (batch_size, seq_len) = m.size()[:2]
46
+ indices = None
47
+ if not self.use_pad_tok_in_ffn:
48
+ assert unpad_input is not None
49
+ (m, indices, _, _) = unpad_input(m, attention_mask)
50
  n = self.ffn(m)
51
+ if not self.use_pad_tok_in_ffn:
52
+ assert pad_input is not None
53
+ n = pad_input(n, indices, batch_size, seq_len)
54
  x = x + self.resid_ffn_dropout(n)
55
  return (x, attn_weights, past_key_value)