from typing import Optional, Callable, Union from torch.nn import Module from tha3.module.module_factory import ModuleFactory from tha3.nn.init_function import create_init_function from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory from tha3.nn.normalization import NormalizationLayerFactory from tha3.nn.spectral_norm import apply_spectral_norm def wrap_conv_or_linear_module(module: Module, initialization_method: Union[str, Callable[[Module], Module]], use_spectral_norm: bool): if isinstance(initialization_method, str): init = create_init_function(initialization_method) else: init = initialization_method return apply_spectral_norm(init(module), use_spectral_norm) class BlockArgs: def __init__(self, initialization_method: Union[str, Callable[[Module], Module]] = 'he', use_spectral_norm: bool = False, normalization_layer_factory: Optional[NormalizationLayerFactory] = None, nonlinearity_factory: Optional[ModuleFactory] = None): self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory) self.normalization_layer_factory = normalization_layer_factory self.use_spectral_norm = use_spectral_norm self.initialization_method = initialization_method def wrap_module(self, module: Module) -> Module: return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm) def get_init_func(self) -> Callable[[Module], Module]: if isinstance(self.initialization_method, str): return create_init_function(self.initialization_method) else: return self.initialization_method