Update ffn.py
ffn.py CHANGED
@@ -1,5 +1,8 @@
-"""GPT Blocks used for the GPT Model."""
-from typing import Any, Optional
+"""MPT Blocks used for the MPT Model."""
+import logging
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
 import torch
 import torch.nn as nn
 from .fc import FC_CLASS_REGISTRY
@@ -7,33 +10,84 @@ try:
     import transformer_engine.pytorch as te
 except:
     te = None
+log = logging.getLogger(__name__)
+_FFN_ACT_FN_DEFAULT = {'name': 'gelu', 'approximate': 'none'}
+
+def resolve_ffn_act_fn(config: Optional[dict]=None) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Resolve the activation function for the feed-forward network.
+    Args:
+        config (Optional[dict]): The configuration dictionary for the activation function.
+            The dict config must specify the 'name' of a torch.nn.functional activation
+            function. All of other key values pairs are bound to the function as a partial.
+    Returns:
+        Callable[[torch.Tensor], torch.Tensor]: The activation function.
+    """
+    if config is None:
+        config = _FFN_ACT_FN_DEFAULT
+    config = deepcopy(config)
+    name = config.pop('name')
+    if not hasattr(torch.nn.functional, name):
+        raise ValueError(f'Unrecognised activation function name ({name}).')
+    act = getattr(torch.nn.functional, name)
+    return partial(act, **config)
+_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
+
+def resolve_ffn_hidden_size(d_model: int, expansion_ratio: Union[int, float], ffn_hidden_size: Optional[int]=None) -> int:
+    """Resolve the hidden size of the feed-forward network.
+    Args:
+        d_model (int): The dimension of the input and output of the feed-forward network.
+        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
+        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
+    Returns:
+        int: The hidden size of the feed-forward network.
+    """
+    if ffn_hidden_size is not None:
+        log.info(f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.')
+    else:
+        ffn_hidden_size = int(d_model * expansion_ratio)
+        if ffn_hidden_size != d_model * expansion_ratio:
+            raise ValueError(f'`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r}).')
+    return ffn_hidden_size
 
 class MPTMLP(nn.Module):
 
-    def __init__(self, d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None):
+    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
         super().__init__()
-        fc_kwargs = {}
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
+        self.fc_kwargs: dict[str, Any] = {'bias': bias}
         if fc_type != 'te':
-            fc_kwargs['device'] = device
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, expansion_ratio * d_model, **fc_kwargs)
-        self.act = nn.GELU(approximate='none')
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](expansion_ratio * d_model, d_model, **fc_kwargs)
+            self.fc_kwargs['device'] = device
+        self.up_proj = FC_CLASS_REGISTRY[fc_type](d_model, ffn_hidden_size, **self.fc_kwargs)
+        self.act = act_fn
+        self.down_proj = FC_CLASS_REGISTRY[fc_type](ffn_hidden_size, d_model, **self.fc_kwargs)
         self.down_proj._is_residual = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.down_proj(self.act(self.up_proj(x)))
-FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP}
+
+class MPTGLU(MPTMLP):
+
+    def __init__(self, d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, act_fn: Callable[[torch.Tensor], torch.Tensor]=_DEFAULT_ACT_FN, device: Optional[str]=None, bias: bool=True):
+        super().__init__(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, ffn_hidden_size=ffn_hidden_size, act_fn=act_fn, device=device, bias=bias)
+        self.gate_proj = FC_CLASS_REGISTRY[fc_type](d_model, self.up_proj.out_features, **self.fc_kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+FFN_CLASS_REGISTRY = {'mptmlp': MPTMLP, 'mptglu': MPTGLU}
 if te is not None:
     te.LayerNormMLP._has_norm = True
     FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
 
-def build_ffn(d_model: int, expansion_ratio: int, fc_type: str='torch', device: Optional[str]=None, **kwargs: Any) -> nn.Module:
+def build_ffn(d_model: int, expansion_ratio: Union[int, float], fc_type: str='torch', ffn_hidden_size: Optional[int]=None, ffn_act_fn: Optional[dict]=None, device: Optional[str]=None, bias: bool=True, **kwargs: Any) -> nn.Module:
     ffn_type = kwargs.pop('ffn_type')
-    if ffn_type == 'mptmlp':
+    if ffn_type in ['mptmlp', 'mptglu']:
         if len(kwargs) > 0:
-            raise ValueError(f'MPTMLP got an unexpected keyword argument: {kwargs}')
-        return MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, device=device)
+            raise ValueError(f'MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}')
+        return FFN_CLASS_REGISTRY[ffn_type](d_model=d_model, expansion_ratio=expansion_ratio, fc_type=fc_type, act_fn=resolve_ffn_act_fn(ffn_act_fn), ffn_hidden_size=ffn_hidden_size, device=device, bias=bias)
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
-        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, **kwargs)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio, ffn_hidden_size)
+        if ffn_act_fn is not None:
+            raise ValueError(f'Transformer Engine block does not support custom activation functions.')
+        return te.LayerNormMLP(hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs)
     raise ValueError(f'ffn_type={ffn_type!r} not recognized.')
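For reference, here is a minimal usage sketch of the reworked interface. It is not part of the commit: the dimensions, the silu choice, and the assumption that FC_CLASS_REGISTRY['torch'] resolves to nn.Linear (via the companion fc.py) are illustrative; only the build_ffn and resolve_ffn_act_fn signatures are taken from the diff above.

# Illustrative sketch only; values and import context are assumptions, not part of this commit.
# Assumes this ffn.py lives inside the model package (so the relative import of fc.py resolves)
# and that FC_CLASS_REGISTRY['torch'] is nn.Linear, as in the companion fc.py.
import torch
from .ffn import build_ffn, resolve_ffn_act_fn

act = resolve_ffn_act_fn()                 # default config: gelu with approximate='none'

ffn = build_ffn(
    d_model=768,                           # hypothetical model width
    expansion_ratio=4,                     # ignored below because ffn_hidden_size is given
    ffn_hidden_size=2048,                  # explicit hidden size wins; an info log notes this
    ffn_act_fn={'name': 'silu'},           # any torch.nn.functional activation, selected by name
    ffn_type='mptglu',                     # gated variant; popped from **kwargs inside build_ffn
    bias=False,
)
x = torch.randn(2, 16, 768)
y = ffn(x)                                 # same shape as x: (2, 16, 768)

Passing ffn_type='te_ln_mlp' instead would route to Transformer Engine's LayerNormMLP, which per the diff now rejects a custom ffn_act_fn.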