kaizen9's picture
Upload model checkpoints directly from S3
958d6f8 verified
raw
history blame contribute delete
No virus
2.06 kB
from typing import Any, Optional, Union
import torch
from .layers_registry import attention_classes, fcs, ffns, ffns_with_megablocks, ffns_with_norm, norms
from .registry_utils import construct_from_registry
def build_norm(
    name: str,
    normalized_shape: Union[int, list[int], torch.Size],
    eps: Optional[float] = 1e-05,
    device: Optional[str] = None,
):
    """Build the entry registered under ``name`` in the ``norms`` registry.

    Args:
        name: Key to look up in the ``norms`` registry.
        normalized_shape: Forwarded to the registered entry as
            ``normalized_shape``.
        eps: Forwarded to the registered entry as ``eps``.
        device: Forwarded to the registered entry as ``device``.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    norm_kwargs = dict(
        normalized_shape=normalized_shape,
        eps=eps,
        device=device,
    )
    # NOTE(review): torch.nn.Module is handed over as the pre-validation
    # argument; presumably the helper checks the registered entry against it
    # before constructing -- confirm in registry_utils.
    return construct_from_registry(
        name=name,
        registry=norms,
        pre_validation_function=torch.nn.Module,
        kwargs=norm_kwargs,
    )
def build_ffn(
    name: str,
    d_model: int,
    expansion_ratio: float,
    device: Optional[str],
    bias: bool,
    ffn_kwargs: dict[str, Any],
):
    """Build the feed-forward entry registered under ``name``.

    The registry is chosen by membership: ``ffns_with_megablocks`` takes
    precedence over ``ffns_with_norm``, and ``ffns`` is the fallback.

    Args:
        name: Key of the feed-forward entry to construct.
        d_model: Forwarded to the entry as ``d_model``.
        expansion_ratio: Forwarded to the entry as ``expansion_ratio``.
        device: Forwarded to the entry as ``device``.
        bias: Forwarded to the entry as ``bias``.
        ffn_kwargs: Extra keyword arguments. The ``'ffn_type'`` key is
            dropped; every other key is forwarded and overrides the explicit
            arguments above on a name clash.

    Returns:
        The constructed object, tagged with ``_has_norm = True`` when ``name``
        is in ``ffns_with_norm`` and ``_uses_megablocks = True`` when it is in
        ``ffns_with_megablocks``.

    Raises:
        ValueError: Raised by the post-construction validator when it is
            given something that is not a ``torch.nn.Module``.
    """
    has_norm = name in ffns_with_norm
    uses_megablocks = name in ffns_with_megablocks

    # Megablocks membership wins over norm membership, as in the sequential
    # reassignment this replaces.
    if uses_megablocks:
        selected_registry = ffns_with_megablocks
    elif has_norm:
        selected_registry = ffns_with_norm
    else:
        selected_registry = ffns

    module_kwargs = {
        'd_model': d_model,
        'expansion_ratio': expansion_ratio,
        'device': device,
        'bias': bias,
    }
    # Caller-supplied extras are merged last so they override the explicit
    # arguments; 'ffn_type' is filtered out and never forwarded.
    module_kwargs.update(
        (key, value) for key, value in ffn_kwargs.items() if key != 'ffn_type'
    )

    def _ensure_module(candidate: Any):
        if isinstance(candidate, torch.nn.Module):
            return
        raise ValueError(f'Function {name} must return a torch.nn.Module.')

    ffn = construct_from_registry(
        name=name,
        registry=selected_registry,
        post_validation_function=_ensure_module,
        partial_function=False,
        kwargs=module_kwargs,
    )

    if has_norm:
        ffn._has_norm = True
    if uses_megablocks:
        ffn._uses_megablocks = True
    return ffn
def build_attention_layer(
    name: str,
    attn_kwargs: dict[str, Any],
):
    """Build the entry registered under ``name`` in ``attention_classes``.

    Args:
        name: Key to look up in the ``attention_classes`` registry.
        attn_kwargs: Keyword arguments forwarded unchanged to the entry.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    attention_layer = construct_from_registry(
        name=name,
        registry=attention_classes,
        pre_validation_function=torch.nn.Module,
        kwargs=attn_kwargs,
    )
    return attention_layer
def build_fc(
    name: str,
    in_features: int,
    out_features: int,
    fc_kwargs: dict[str, Any],
):
    """Build the entry registered under ``name`` in the ``fcs`` registry.

    Args:
        name: Key to look up in the ``fcs`` registry.
        in_features: Forwarded to the entry as ``in_features``.
        out_features: Forwarded to the entry as ``out_features``.
        fc_kwargs: Extra keyword arguments. The ``'name'`` key is dropped;
            every other key is forwarded and overrides ``in_features`` /
            ``out_features`` on a name clash.

    Returns:
        Whatever ``construct_from_registry`` returns for this entry.
    """
    layer_kwargs = {
        'in_features': in_features,
        'out_features': out_features,
    }
    # 'name' selects the registry entry and is not forwarded to the entry.
    layer_kwargs.update(
        (key, value) for key, value in fc_kwargs.items() if key != 'name'
    )
    return construct_from_registry(
        name=name,
        registry=fcs,
        pre_validation_function=torch.nn.Module,
        kwargs=layer_kwargs,
    )