import torch

class Conv1d(torch.nn.Conv1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        # Stored before super().__init__() because reset_parameters(), which is
        # called inside it, reads this attribute. Layer hyper-parameters
        # (in_channels, out_channels, ...) must therefore be passed as keywords.
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            pass
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # GLU: the first half of the output channels is the content path,
            # the second half is the sigmoid gate.
            assert self.out_channels % 2 == 0, 'GLU requires an even number of out_channels.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            # Gate: tanh content path followed by a sigmoid gate.
            assert self.out_channels % 2 == 0, 'Gate requires an even number of out_channels.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
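
# Usage sketch (illustration only, not part of the original module): the
# w_init_gain argument selects the initializer that reset_parameters() applies.
# All shapes and hyper-parameters below are arbitrary example values.
def _conv1d_glu_init_example():
    # 'glu' initialization splits out_channels into a content half and a gate
    # half, matching the split that torch.nn.functional.glu performs on dim 1.
    conv = Conv1d(w_init_gain= 'glu', in_channels= 80, out_channels= 256, kernel_size= 3, padding= 1)
    x = torch.randn(4, 80, 100)                       # [Batch, Channels, Time]
    gated = torch.nn.functional.glu(conv(x), dim= 1)  # [4, 128, 100]
    return gated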

class ConvTranspose1d(torch.nn.ConvTranspose1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'GLU requires an even number of out_channels.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'Gate requires an even number of out_channels.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

class Conv2d(torch.nn.Conv2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'GLU requires an even number of out_channels.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'Gate requires an even number of out_channels.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

class ConvTranspose2d(torch.nn.ConvTranspose2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'GLU requires an even number of out_channels.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'Gate requires an even number of out_channels.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

class Linear(torch.nn.Linear):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # Linear layers have out_features rather than out_channels.
            assert self.out_features % 2 == 0, 'GLU requires an even number of out_features.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_features // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_features // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

class Lambda(torch.nn.Module):
    # Wraps an arbitrary callable (e.g. a lambda) as a torch.nn.Module.
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

class Residual(torch.nn.Module):
    # Thin wrapper around a module; forward simply delegates to it.
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

class LayerNorm(torch.nn.Module):
    # Layer normalization over the channel dimension (dim 1) of a
    # channel-first tensor such as [Batch, Channels, Time].
    def __init__(self, num_features: int, eps: float= 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, inputs: torch.Tensor):
        means = inputs.mean(dim= 1, keepdim= True)
        variances = (inputs - means).pow(2.0).mean(dim= 1, keepdim= True)
        x = (inputs - means) * (variances + self.eps).rsqrt()

        shape = [1, -1] + [1] * (x.ndim - 2)
        return x * self.gamma.view(*shape) + self.beta.view(*shape)
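
# Usage sketch (illustration only, not part of the original module): unlike
# torch.nn.LayerNorm, which normalizes over trailing dimensions, this LayerNorm
# accepts channel-first tensors directly. All values below are arbitrary.
def _layer_norm_example():
    norm = LayerNorm(num_features= 256)
    x = torch.randn(4, 256, 100)    # [Batch, Channels, Time]
    y = norm(x)                     # [4, 256, 100], normalized over the 256 channels
    return y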

class LightweightConv1d(torch.nn.Module):
    '''
    Args:
        input_size: # of channels of the input and output
        kernel_size: convolution kernel size
        padding: padding
        num_heads: number of heads used. The weight is of shape
            `(num_heads, 1, kernel_size)`
        weight_softmax: normalize the weight with softmax before the convolution

    Shape:
        Input: BxCxT, i.e. (batch_size, input_size, timesteps)
        Output: BxCxT, i.e. (batch_size, input_size, timesteps)

    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    '''
    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding=0,
        num_heads=1,
        weight_softmax=False,
        bias=False,
        weight_dropout=0.0,
        w_init_gain= 'linear'
        ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.num_heads = num_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = torch.nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
        self.w_init_gain = w_init_gain

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
            )
        self.reset_parameters()

    def reset_parameters(self):
        if self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # The weight's leading dimension is num_heads, which plays the role
            # of out_channels here.
            assert self.num_heads % 2 == 0, 'GLU requires an even number of heads.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.num_heads // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.num_heads // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        """
        input size: B x C x T
        output size: B x C x T
        """
        B, C, T = input.size()
        H = self.num_heads

        weight = self.weight
        if self.weight_softmax:
            weight = weight.softmax(dim=-1)

        weight = self.weight_dropout_module(weight)
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and do not reshape the input instead, which is slow though
        input = input.view(-1, H, T)
        output = torch.nn.functional.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)

        return output
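
# Usage sketch (illustration only, not part of the original module): the
# depthwise "lightweight" convolution shares one kernel across every
# input_size // num_heads channels; padding = kernel_size // 2 keeps the time
# axis unchanged. All values below are arbitrary.
def _lightweight_conv_example():
    conv = LightweightConv1d(input_size= 256, kernel_size= 3, padding= 1, num_heads= 8, weight_softmax= True)
    x = torch.randn(4, 256, 100)    # B x C x T
    y = conv(x)                     # [4, 256, 100]
    return y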

class FairseqDropout(torch.nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        if self.training or self.apply_during_inference:
            return torch.nn.functional.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x
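
# Usage sketch (illustration only, not part of the original module): this
# dropout behaves like torch.nn.Dropout in training mode, but can also be kept
# active at inference time via apply_during_inference.
def _fairseq_dropout_example():
    dropout = FairseqDropout(p= 0.1, module_name= 'example')
    dropout.eval()
    y = dropout(torch.ones(2, 3))   # identity in eval mode
    dropout.apply_during_inference = True
    z = dropout(torch.ones(2, 3))   # dropout still applied despite eval mode
    return y, z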

class LinearAttention(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        calc_channels: int,
        num_heads: int,
        dropout_rate: float= 0.1,
        use_scale: bool= True,
        use_residual: bool= True,
        use_norm: bool= True
        ):
        super().__init__()
        assert calc_channels % num_heads == 0
        self.calc_channels = calc_channels
        self.num_heads = num_heads
        self.use_scale = use_scale
        self.use_residual = use_residual
        self.use_norm = use_norm

        self.prenet = Conv1d(
            in_channels= channels,
            out_channels= calc_channels * 3,
            kernel_size= 1,
            bias=False,
            w_init_gain= 'linear'
            )
        self.projection = Conv1d(
            in_channels= calc_channels,
            out_channels= channels,
            kernel_size= 1,
            w_init_gain= 'linear'
            )
        self.dropout = torch.nn.Dropout(p= dropout_rate)

        if use_scale:
            self.scale = torch.nn.Parameter(torch.zeros(1))

        if use_norm:
            self.norm = LayerNorm(num_features= channels)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        '''
        x: [Batch, Enc_d, Enc_t]
        '''
        residuals = x

        x = self.prenet(x)  # [Batch, Calc_d * 3, Enc_t]
        x = x.view(x.size(0), self.num_heads, x.size(1) // self.num_heads, x.size(2))  # [Batch, Head, Calc_d // Head * 3, Enc_t]
        queries, keys, values = x.chunk(chunks= 3, dim= 2)  # [Batch, Head, Calc_d // Head, Enc_t] * 3
        keys = (keys + 1e-5).softmax(dim= 3)

        contexts = keys @ values.permute(0, 1, 3, 2)  # [Batch, Head, Calc_d // Head, Calc_d // Head]
        contexts = contexts.permute(0, 1, 3, 2) @ queries  # [Batch, Head, Calc_d // Head, Enc_t]
        contexts = contexts.view(contexts.size(0), contexts.size(1) * contexts.size(2), contexts.size(3))  # [Batch, Calc_d, Enc_t]
        contexts = self.projection(contexts)  # [Batch, Enc_d, Enc_t]

        if self.use_scale:
            contexts = self.scale * contexts

        contexts = self.dropout(contexts)

        if self.use_residual:
            contexts = contexts + residuals

        if self.use_norm:
            contexts = self.norm(contexts)

        return contexts
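
# Usage sketch (illustration only, not part of the original module): the
# residual connection and LayerNorm keep the output shape identical to the
# channel-first input. All values below are arbitrary.
def _linear_attention_example():
    attention = LinearAttention(channels= 256, calc_channels= 256, num_heads= 4)
    x = torch.randn(4, 256, 100)    # [Batch, Enc_d, Enc_t]
    y = attention(x)                # [4, 256, 100]
    return y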