File size: 1,203 Bytes
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        blocks = []
        convs = [
            (1, 64, (3, 9), 1, (1, 4)),
            (64, 128, (3, 9), (1, 2), (1, 4)),
            (128, 256, (3, 9), (1, 2), (1, 4)),
            (256, 512, (3, 9), (1, 2), (1, 4)),
            (512, 1024, (3, 3), 1, (1, 1)),
            (1024, 1, (3, 3), 1, (1, 1)),
        ]

        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
            convs
        ):
            blocks.append(
                weight_norm(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
                )
            )

            if idx != len(convs) - 1:
                blocks.append(nn.SiLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x[:, None])[:, 0]


if __name__ == "__main__":
    # Smoke test: build the network, report its size in millions of
    # parameters, then push one random input through it and show the result.
    discriminator = Discriminator()
    parameter_count = sum(
        param.numel() for param in discriminator.parameters()
    )
    print(parameter_count / 1_000_000)

    dummy_input = torch.randn(1, 128, 1024)
    scores = discriminator(dummy_input)
    print(scores.shape)
    print(scores)