EXPOX commited on
Commit
3198d2c
1 Parent(s): 72bb841

Upload mdxnet.py

Browse files
Files changed (1) hide show
  1. lib_v5/mdxnet.py +140 -0
lib_v5/mdxnet.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from pytorch_lightning import LightningModule
6
+ from .modules import TFC_TDF
7
+
8
+ dim_s = 4
9
+
10
+ class AbstractMDXNet(LightningModule):
11
+ __metaclass__ = ABCMeta
12
+
13
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
14
+ super().__init__()
15
+ self.target_name = target_name
16
+ self.lr = lr
17
+ self.optimizer = optimizer
18
+ self.dim_c = dim_c
19
+ self.dim_f = dim_f
20
+ self.dim_t = dim_t
21
+ self.n_fft = n_fft
22
+ self.n_bins = n_fft // 2 + 1
23
+ self.hop_length = hop_length
24
+ self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
25
+ self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
26
+
27
+ def configure_optimizers(self):
28
+ if self.optimizer == 'rmsprop':
29
+ return torch.optim.RMSprop(self.parameters(), self.lr)
30
+
31
+ if self.optimizer == 'adamw':
32
+ return torch.optim.AdamW(self.parameters(), self.lr)
33
+
34
+ class ConvTDFNet(AbstractMDXNet):
35
+ def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
36
+ num_blocks, l, g, k, bn, bias, overlap):
37
+
38
+ super(ConvTDFNet, self).__init__(
39
+ target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
40
+ self.save_hyperparameters()
41
+
42
+ self.num_blocks = num_blocks
43
+ self.l = l
44
+ self.g = g
45
+ self.k = k
46
+ self.bn = bn
47
+ self.bias = bias
48
+
49
+ if optimizer == 'rmsprop':
50
+ norm = nn.BatchNorm2d
51
+
52
+ if optimizer == 'adamw':
53
+ norm = lambda input:nn.GroupNorm(2, input)
54
+
55
+ self.n = num_blocks // 2
56
+ scale = (2, 2)
57
+
58
+ self.first_conv = nn.Sequential(
59
+ nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
60
+ norm(g),
61
+ nn.ReLU(),
62
+ )
63
+
64
+ f = self.dim_f
65
+ c = g
66
+ self.encoding_blocks = nn.ModuleList()
67
+ self.ds = nn.ModuleList()
68
+ for i in range(self.n):
69
+ self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
70
+ self.ds.append(
71
+ nn.Sequential(
72
+ nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
73
+ norm(c + g),
74
+ nn.ReLU()
75
+ )
76
+ )
77
+ f = f // 2
78
+ c += g
79
+
80
+ self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)
81
+
82
+ self.decoding_blocks = nn.ModuleList()
83
+ self.us = nn.ModuleList()
84
+ for i in range(self.n):
85
+ self.us.append(
86
+ nn.Sequential(
87
+ nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
88
+ norm(c - g),
89
+ nn.ReLU()
90
+ )
91
+ )
92
+ f = f * 2
93
+ c -= g
94
+
95
+ self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
96
+
97
+ self.final_conv = nn.Sequential(
98
+ nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
99
+ )
100
+
101
+ def forward(self, x):
102
+
103
+ x = self.first_conv(x)
104
+
105
+ x = x.transpose(-1, -2)
106
+
107
+ ds_outputs = []
108
+ for i in range(self.n):
109
+ x = self.encoding_blocks[i](x)
110
+ ds_outputs.append(x)
111
+ x = self.ds[i](x)
112
+
113
+ x = self.bottleneck_block(x)
114
+
115
+ for i in range(self.n):
116
+ x = self.us[i](x)
117
+ x *= ds_outputs[-i - 1]
118
+ x = self.decoding_blocks[i](x)
119
+
120
+ x = x.transpose(-1, -2)
121
+
122
+ x = self.final_conv(x)
123
+
124
+ return x
125
+
126
+ class Mixer(nn.Module):
127
+ def __init__(self, device, mixer_path):
128
+
129
+ super(Mixer, self).__init__()
130
+
131
+ self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False)
132
+
133
+ self.load_state_dict(
134
+ torch.load(mixer_path, map_location=device)
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2)
139
+ x = self.linear(x)
140
+ return x.transpose(-1,-2).reshape(dim_s,2,-1)