kaizen9 committed
Commit 7a7db37
Parent: 8c6dd43

Update ffn.py

Files changed (1):
  1. ffn.py +68 -14
ffn.py CHANGED
@@ -1,5 +1,8 @@
-"""GPT Blocks used for the GPT Model."""
-from typing import Any, Optional
+"""MPT Blocks used for the MPT Model."""
+import logging
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
 import torch
 import torch.nn as nn
 from .fc import FC_CLASS_REGISTRY
@@ -7,33 +10,84 @@ try:
     import transformer_engine.pytorch as te
 except:
     te = None
+log = logging.getLogger(__name__)
+_FFN_ACT_FN_DEFAULT = {'name': 'gelu', 'approximate': 'none'}
+
+def resolve_ffn_act_fn(config: Optional[dict]=None) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Resolve the activation function for the feed-forward network.
+    Args:
+        config (Optional[dict]): The configuration dictionary for the activation function.
+            The dict config must specify the 'name' of a torch.nn.functional activation
+            function. All of other key values pairs are bound to the function as a partial.
+    Returns:
+        Callable[[torch.Tensor], torch.Tensor]: The activation function.
+    """
+    if config is None:
+        config = _FFN_ACT_FN_DEFAULT
+    config = deepcopy(config)
+    name = config.pop('name')
+    if not hasattr(torch.nn.functional, name):
+        raise ValueError(f'Unrecognised activation function name ({name}).')
+    act = getattr(torch.nn.functional, name)
+    return partial(act, **config)
+_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
+
+def resolve_ffn_hidden_size(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None) -> int:
+    """Resolve the hidden size of the feed-forward network.
+    Args:
+        d_model (int): The dimension of the input and output of the feed-forward network.
+        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
+        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
+    Returns:
+        int: The hidden size of the feed-forward network.
+    """
+    if ffn_hidden_size is not None:
+        log.info(f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.')
+    else:
+        ffn_hidden_size = int(d_model * expansion_ratio)
+        if ffn_hidden_size != d_model * expansion_ratio:
+            raise ValueError(f'`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r}).')
+    return ffn_hidden_size
 
 class MPTMLP(nn.Module):
 
-    def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
+    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
         super().__init__()
-        fc_kwargs = {}
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
+        self.fc_kwargs: dict[str, Any] = {'bias': bias}
         if fc_type != 'te':
-            fc_kwargs['device'] = device
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
-        self.act = nn.GELU(approximate='none')
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
+            self.fc_kwargs['device'] = device
+        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
+        self.act = act_fn
+        self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
         self.down_proj._is_residual = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.down_proj(self.act(self.up_proj(x)))
-FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP}
+
+class MPTGLU(MPTMLP):
+
+    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
+        super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
+        self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
 if te is not None:
     te.LayerNormMLP._has_norm = True
     FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
 
-def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
+def build_ffn(d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
-    if ffn_type == 'mptmlp':
+    if ffn_type in ['mptmlp', 'mptglu']:
         if len(kwargs) > 0:
-            raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
-        return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
+            raise ValueError(f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}')
+        return FFN_CLASS_REGISTRY[ffn_type](d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, act_fn=resolve_ffn_act_fn(ffn_act_fn), ffn_hidden_size=ffn_hidden_size, device=device, bias=bias)
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
-        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, **kwargs)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
+        if ffn_act_fn is not None:
+            raise ValueError(f'Transformer Engine block does not support custom activation functions.')
+        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs)
     raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
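
For context, the new MPTGLU block computes down_proj(act(gate_proj(x)) * up_proj(x)), i.e. a gated feed-forward layer, and build_ffn now resolves its activation from a config dict via resolve_ffn_act_fn. Below is a minimal, self-contained sketch of that computation using plain torch.nn.Linear in place of FC_CLASS_REGISTRY[fc_type]; the names resolve_act and TinyGLU and the 'silu' choice are illustrative, not part of this commit.

from functools import partial
import torch
import torch.nn as nn

def resolve_act(config: dict):
    # Mirrors the idea of resolve_ffn_act_fn above: look up a torch.nn.functional
    # activation by its 'name' key and bind the remaining keys as keyword arguments.
    config = dict(config)
    name = config.pop('name')
    return partial(getattr(torch.nn.functional, name), **config)

class TinyGLU(nn.Module):
    # Illustrative stand-in for MPTGLU, with nn.Linear instead of FC_CLASS_REGISTRY.
    def __init__(self, d_model: int, ffn_hidden_size: int, act_fn, bias: bool = True):
        super().__init__()
        self.up_proj = nn.Linear(d_model, ffn_hidden_size, bias=bias)
        self.gate_proj = nn.Linear(d_model, ffn_hidden_size, bias=bias)
        self.down_proj = nn.Linear(ffn_hidden_size, d_model, bias=bias)
        self.act = act_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same computation as MPTGLU.forward in the diff above.
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

ffn = TinyGLU(d_model=64, ffn_hidden_size=int(64 * 4), act_fn=resolve_act({'name': 'silu'}))
print(ffn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])

Note also that expansion_ratio may now be a float: resolve_ffn_hidden_size only requires that d_model * expansion_ratio comes out to an integer (for example, d_model=768 with expansion_ratio=3.5 gives ffn_hidden_size=2688).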