Spaces:
Runtime error
Runtime error
from typing import Optional | |
from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid | |
from tha3.module.module_factory import ModuleFactory | |
class ReLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.ReLU`` activation modules."""

    def __init__(self, inplace: bool = False):
        """Store the ``inplace`` flag used for every ReLU this factory creates.

        Args:
            inplace: forwarded to ``ReLU(inplace=...)`` on each ``create()`` call.
        """
        self.inplace = inplace

    def create(self) -> Module:
        """Return a newly constructed ``ReLU`` configured with ``self.inplace``."""
        # ReLU's first positional parameter is ``inplace``; pass it by keyword for clarity.
        activation = ReLU(inplace=self.inplace)
        return activation
class LeakyReLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.LeakyReLU`` activation modules."""

    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
        """Store the LeakyReLU configuration used by ``create()``.

        Args:
            inplace: forwarded as ``LeakyReLU(inplace=...)``.
            negative_slope: forwarded as ``LeakyReLU(negative_slope=...)``.
        """
        self.inplace = inplace
        self.negative_slope = negative_slope

    def create(self) -> Module:
        """Return a newly constructed ``LeakyReLU`` using the stored settings."""
        slope = self.negative_slope
        return LeakyReLU(negative_slope=slope, inplace=self.inplace)
class ELUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.ELU`` activation modules."""

    def __init__(self, inplace: bool = False, alpha: float = 1.0):
        """Store the ELU configuration used by ``create()``.

        Args:
            inplace: forwarded as ``ELU(inplace=...)``.
            alpha: forwarded as ``ELU(alpha=...)``.
        """
        self.inplace = inplace
        self.alpha = alpha

    def create(self) -> Module:
        """Return a newly constructed ``ELU`` using the stored settings."""
        scale = self.alpha
        return ELU(alpha=scale, inplace=self.inplace)
class ReLU6Factory(ModuleFactory):
    """Factory that builds ``torch.nn.ReLU6`` activation modules."""

    def __init__(self, inplace: bool = False):
        """Store the ``inplace`` flag forwarded to every ReLU6 this factory creates."""
        self.inplace = inplace

    def create(self) -> Module:
        """Return a newly constructed ``ReLU6`` configured with ``self.inplace``."""
        activation = ReLU6(inplace=self.inplace)
        return activation
class SiLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.SiLU`` activation modules."""

    def __init__(self, inplace: bool = False):
        """Store the ``inplace`` flag forwarded to every SiLU this factory creates."""
        self.inplace = inplace

    def create(self) -> Module:
        """Return a newly constructed ``SiLU`` configured with ``self.inplace``."""
        activation = SiLU(inplace=self.inplace)
        return activation
class HardswishFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Hardswish`` activation modules."""

    def __init__(self, inplace: bool = False):
        """Store the ``inplace`` flag forwarded to every Hardswish this factory creates."""
        self.inplace = inplace

    def create(self) -> Module:
        """Return a newly constructed ``Hardswish`` configured with ``self.inplace``."""
        activation = Hardswish(inplace=self.inplace)
        return activation
class TanhFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Tanh`` activation modules (no configuration)."""

    def create(self) -> Module:
        """Return a newly constructed ``Tanh`` module."""
        activation = Tanh()
        return activation
class SigmoidFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Sigmoid`` activation modules (no configuration)."""

    def create(self) -> Module:
        """Return a newly constructed ``Sigmoid`` module."""
        activation = Sigmoid()
        return activation
def resolve_nonlinearity_factory(nonlinearity_fatory: Optional[ModuleFactory]) -> ModuleFactory:
    """Return the supplied nonlinearity factory, defaulting to a non-inplace ReLU.

    Args:
        nonlinearity_fatory: factory to use, or ``None`` to request the default.
            The parameter name's spelling is kept as-is so keyword callers keep working.

    Returns:
        ``nonlinearity_fatory`` unchanged when it is not ``None``; otherwise a new
        ``ReLUFactory(inplace=False)``.
    """
    if nonlinearity_fatory is not None:
        return nonlinearity_fatory
    # No factory supplied: fall back to a plain, non-inplace ReLU.
    return ReLUFactory(inplace=False)