|
from typing import Any, Optional, Union |
|
import torch |
|
from .layers_registry import attention_classes, fcs, ffns, ffns_with_megablocks, ffns_with_norm, norms |
|
from .registry_utils import construct_from_registry |
|
|
|
def build_norm(
    name: str,
    normalized_shape: Union[int, list[int], torch.Size],
    eps: Optional[float] = 1e-05,
    device: Optional[str] = None,
):
    """Build the normalization layer registered under ``name``.

    Args:
        name: Key to look up in the ``norms`` registry.
        normalized_shape: Forwarded to the registered constructor as
            ``normalized_shape``.
        eps: Forwarded to the registered constructor as ``eps``.
        device: Forwarded to the registered constructor as ``device``.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    # NOTE(review): passing torch.nn.Module as the pre-validation hook
    # presumably requires the registered entry to be a Module subclass --
    # confirm against construct_from_registry.
    return construct_from_registry(
        name=name,
        registry=norms,
        pre_validation_function=torch.nn.Module,
        kwargs={
            'normalized_shape': normalized_shape,
            'eps': eps,
            'device': device,
        },
    )
|
|
|
def build_ffn(
    name: str,
    d_model: int,
    expansion_ratio: float,
    device: Optional[str],
    bias: bool,
    ffn_kwargs: dict[str, Any],
):
    """Build the feed-forward block registered under ``name``.

    The registry is chosen by membership of ``name``: ``ffns`` by default,
    ``ffns_with_norm`` if ``name`` is found there, and
    ``ffns_with_megablocks`` if ``name`` is found there (this last check
    wins when both match).

    Args:
        name: Key identifying the feed-forward implementation.
        d_model: Forwarded to the constructor as ``d_model``.
        expansion_ratio: Forwarded to the constructor as ``expansion_ratio``.
        device: Forwarded to the constructor as ``device``.
        bias: Forwarded to the constructor as ``bias``.
        ffn_kwargs: Extra constructor arguments. The ``'ffn_type'`` key is
            dropped; any remaining key that collides with one of the
            explicit arguments above overrides it.

    Returns:
        The object produced by ``construct_from_registry``. It is tagged
        with ``_has_norm = True`` when ``name`` is in ``ffns_with_norm`` and
        with ``_uses_megablocks = True`` when ``name`` is in
        ``ffns_with_megablocks``.

    Raises:
        ValueError: From the post-validation hook when it is handed
            something that is not a ``torch.nn.Module``.
    """
    # Same precedence as two sequential overrides: norm first, megablocks last.
    selected_registry = ffns_with_norm if name in ffns_with_norm else ffns
    if name in ffns_with_megablocks:
        selected_registry = ffns_with_megablocks

    build_kwargs: dict[str, Any] = {
        'd_model': d_model,
        'expansion_ratio': expansion_ratio,
        'device': device,
        'bias': bias,
    }
    # 'ffn_type' is excluded from the forwarded arguments; everything else in
    # ffn_kwargs is passed through and may override the explicit values.
    build_kwargs.update({
        key: value for key, value in ffn_kwargs.items() if key != 'ffn_type'
    })

    def _ensure_module(maybe_module: Any):
        # Guard clause: nothing to do for a genuine Module.
        if isinstance(maybe_module, torch.nn.Module):
            return
        raise ValueError(f'Function {name} must return a torch.nn.Module.')

    # NOTE(review): presumably the hook is invoked on the constructed object
    # after the registered callable runs -- confirm in construct_from_registry.
    module = construct_from_registry(
        name=name,
        registry=selected_registry,
        post_validation_function=_ensure_module,
        partial_function=False,
        kwargs=build_kwargs,
    )

    # Tag the result so other code can detect these variants by attribute.
    if name in ffns_with_norm:
        module._has_norm = True
    if name in ffns_with_megablocks:
        module._uses_megablocks = True

    return module
|
|
|
def build_attention_layer(name: str, attn_kwargs: dict[str, Any]):
    """Build the attention layer registered under ``name``.

    Args:
        name: Key to look up in the ``attention_classes`` registry.
        attn_kwargs: Passed unchanged as the constructor keyword arguments.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    # NOTE(review): torch.nn.Module as the pre-validation hook presumably
    # restricts registered entries to Module subclasses -- confirm.
    attention_layer = construct_from_registry(
        name=name,
        registry=attention_classes,
        pre_validation_function=torch.nn.Module,
        kwargs=attn_kwargs,
    )
    return attention_layer
|
|
|
def build_fc(
    name: str,
    in_features: int,
    out_features: int,
    fc_kwargs: dict[str, Any],
):
    """Build the fully-connected layer registered under ``name``.

    Args:
        name: Key to look up in the ``fcs`` registry.
        in_features: Forwarded to the constructor as ``in_features``.
        out_features: Forwarded to the constructor as ``out_features``.
        fc_kwargs: Extra constructor arguments. The ``'name'`` key is
            dropped; any remaining key that collides with ``in_features``
            or ``out_features`` overrides the explicit value.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    # Drop the 'name' entry so it is not forwarded as a constructor argument.
    passthrough = {
        key: value for key, value in fc_kwargs.items() if key != 'name'
    }

    build_kwargs: dict[str, Any] = {
        'in_features': in_features,
        'out_features': out_features,
    }
    build_kwargs.update(passthrough)

    return construct_from_registry(
        name=name,
        registry=fcs,
        pre_validation_function=torch.nn.Module,
        kwargs=build_kwargs,
    )