TheComputerMan commited on
Commit
f6fe944
1 Parent(s): d09d67d

Upload ResidualBlock.py

Browse files
Files changed (1) hide show
  1. ResidualBlock.py +98 -0
ResidualBlock.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ References:
5
+ - https://github.com/jik876/hifi-gan
6
+ - https://github.com/kan-bayashi/ParallelWaveGAN
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ class Conv1d(torch.nn.Conv1d):
13
+ """
14
+ Conv1d module with customized initialization.
15
+ """
16
+
17
+ def __init__(self, *args, **kwargs):
18
+ super(Conv1d, self).__init__(*args, **kwargs)
19
+
20
+ def reset_parameters(self):
21
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
22
+ if self.bias is not None:
23
+ torch.nn.init.constant_(self.bias, 0.0)
24
+
25
+
26
+ class Conv1d1x1(Conv1d):
27
+ """
28
+ 1x1 Conv1d with customized initialization.
29
+ """
30
+
31
+ def __init__(self, in_channels, out_channels, bias):
32
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias)
33
+
34
+
35
+ class HiFiGANResidualBlock(torch.nn.Module):
36
+ """Residual block module in HiFiGAN."""
37
+
38
+ def __init__(self,
39
+ kernel_size=3,
40
+ channels=512,
41
+ dilations=(1, 3, 5),
42
+ bias=True,
43
+ use_additional_convs=True,
44
+ nonlinear_activation="LeakyReLU",
45
+ nonlinear_activation_params={"negative_slope": 0.1}, ):
46
+ """
47
+ Initialize HiFiGANResidualBlock module.
48
+
49
+ Args:
50
+ kernel_size (int): Kernel size of dilation convolution layer.
51
+ channels (int): Number of channels for convolution layer.
52
+ dilations (List[int]): List of dilation factors.
53
+ use_additional_convs (bool): Whether to use additional convolution layers.
54
+ bias (bool): Whether to add bias parameter in convolution layers.
55
+ nonlinear_activation (str): Activation function module name.
56
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
57
+ """
58
+ super().__init__()
59
+ self.use_additional_convs = use_additional_convs
60
+ self.convs1 = torch.nn.ModuleList()
61
+ if use_additional_convs:
62
+ self.convs2 = torch.nn.ModuleList()
63
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
64
+ for dilation in dilations:
65
+ self.convs1 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
66
+ torch.nn.Conv1d(channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=dilation,
71
+ bias=bias,
72
+ padding=(kernel_size - 1) // 2 * dilation, ), )]
73
+ if use_additional_convs:
74
+ self.convs2 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
75
+ torch.nn.Conv1d(channels,
76
+ channels,
77
+ kernel_size,
78
+ 1,
79
+ dilation=1,
80
+ bias=bias,
81
+ padding=(kernel_size - 1) // 2, ), )]
82
+
83
+ def forward(self, x):
84
+ """
85
+ Calculate forward propagation.
86
+
87
+ Args:
88
+ x (Tensor): Input tensor (B, channels, T).
89
+
90
+ Returns:
91
+ Tensor: Output tensor (B, channels, T).
92
+ """
93
+ for idx in range(len(self.convs1)):
94
+ xt = self.convs1[idx](x)
95
+ if self.use_additional_convs:
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x