from typing import Any, Optional, Union

import torch

from .layers_registry import (
    attention_classes,
    fcs,
    ffns,
    ffns_with_megablocks,
    ffns_with_norm,
    norms,
)
from .registry_utils import construct_from_registry


def build_norm(
    name: str,
    normalized_shape: Union[int, list[int], torch.Size],
    eps: Optional[float] = 1e-5,
    device: Optional[str] = None,
):
    kwargs = {
        'normalized_shape': normalized_shape,
        'eps': eps,
        'device': device,
    }

    return construct_from_registry(
        name=name,
        registry=norms,
        pre_validation_function=torch.nn.Module,
        kwargs=kwargs,
    )


def build_ffn(
    name: str,
    d_model: int,
    expansion_ratio: float,
    device: Optional[str],
    bias: bool,
    ffn_kwargs: dict[str, Any],
):
    # Select the registry that matches the FFN variant; the choice also
    # determines which marker attributes are set on the built module below.
    registry_to_use = ffns
    if name in ffns_with_norm:
        registry_to_use = ffns_with_norm

    if name in ffns_with_megablocks:
        registry_to_use = ffns_with_megablocks

    kwargs = {
        'd_model': d_model,
        'expansion_ratio': expansion_ratio,
        'device': device,
        'bias': bias,
        **{k: v for k, v in ffn_kwargs.items() if k != 'ffn_type'},
    }

    def _validation_function(maybe_module: Any):
        if not isinstance(maybe_module, torch.nn.Module):
            raise ValueError(f'Function {name} must return a torch.nn.Module.')

    result = construct_from_registry(
        name=name,
        registry=registry_to_use,
        post_validation_function=_validation_function,
        partial_function=False,
        kwargs=kwargs,
    )

    if name in ffns_with_norm:
        result._has_norm = True

    if name in ffns_with_megablocks:
        result._uses_megablocks = True

    return result


def build_attention_layer(name: str, attn_kwargs: dict[str, Any]):
    return construct_from_registry(
        name=name,
        registry=attention_classes,
        pre_validation_function=torch.nn.Module,
        kwargs=attn_kwargs,
    )


def build_fc(
    name: str,
    in_features: int,
    out_features: int,
    fc_kwargs: dict[str, Any],
):
    kwargs = {
        'in_features': in_features,
        'out_features': out_features,
        **{k: v for k, v in fc_kwargs.items() if k != 'name'},
    }

    return construct_from_registry(
        name=name,
        registry=fcs,
        pre_validation_function=torch.nn.Module,
        kwargs=kwargs,
    )