File size: 252 Bytes
60094bd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from torch.nn import Module
from torch.nn.utils import spectral_norm


def apply_spectral_norm(module: Module, use_spectrial_norm: bool = False) -> Module:
    """Optionally apply spectral normalization to ``module``.

    Args:
        module: The module to conditionally normalize.
        use_spectrial_norm: When truthy, ``module`` is passed through
            ``torch.nn.utils.spectral_norm``. When falsy, ``module`` is
            returned untouched. (The misspelled name is kept on purpose so
            existing keyword callers keep working.)

    Returns:
        Whatever ``spectral_norm(module)`` returns when enabled, otherwise
        the very same ``module`` object that was passed in.
    """
    # Guard clause: the common "disabled" path hands the module straight back.
    if not use_spectrial_norm:
        return module
    return spectral_norm(module)