cchaun committed on
Commit 0d6426a
1 Parent(s): f25ccdb

add project files

.gitattributes CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ venv
+ __pycache__
+ flagged
app.py ADDED
@@ -0,0 +1,104 @@
+ # -*- coding: UTF-8 -*-
+ import gradio as gr
+ import torch, torchaudio
+ from timeit import default_timer as timer
+ from torchaudio.transforms import Resample
+ from models.model import HarmonicCNN
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ SAMPLE_RATE = 16000
+ AUDIO_LEN = 2.90
+
+ model = HarmonicCNN()
+ # map_location keeps the checkpoint loadable on CPU-only machines as well
+ S = torch.load('models/best_model.pth', map_location=device)
+ model.load_state_dict(S)
+
+ LABELS = [
+     "alternative",
+     "ambient",
+     "atmospheric",
+     "chillout",
+     "classical",
+     "dance",
+     "downtempo",
+     "easylistening",
+     "electronic",
+     "experimental",
+     "folk",
+     "funk",
+     "hiphop",
+     "house",
+     "indie",
+     "instrumentalpop",
+     "jazz",
+     "lounge",
+     "metal",
+     "newage",
+     "orchestral",
+     "pop",
+     "popfolk",
+     "poprock",
+     "reggae",
+     "rock",
+     "soundtrack",
+     "techno",
+     "trance",
+     "triphop",
+     "world",
+     "acousticguitar",
+     "bass",
+     "computer",
+     "drummachine",
+     "drums",
+     "electricguitar",
+     "electricpiano",
+     "guitar",
+     "keyboard",
+     "piano",
+     "strings",
+     "synthesizer",
+     "violin",
+     "voice",
+     "emotional",
+     "energetic",
+     "film",
+     "happy",
+     "relaxing"
+ ]
+
+ example_list = [
+     "samples/guitar_acoustic.wav",
+     "samples/guitar_electric.wav",
+     "samples/piano.wav",
+     "samples/violin.wav",
+     "samples/flute.wav"
+ ]
+
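+ # Note: LABELS is assumed to follow the tag order used when training best_model.pth;
+ # predict() below maps model output index i to LABELS[i].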
+ def predict(audio_path):
+     start_time = timer()
+     wav, sample_rate = torchaudio.load(audio_path)
+     # resample to the 16 kHz rate the model expects
+     if sample_rate > SAMPLE_RATE:
+         resampler = Resample(sample_rate, SAMPLE_RATE)
+         wav = resampler(wav)
+     # downmix multi-channel audio to mono, then add a batch dimension
+     if wav.shape[0] >= 2:
+         wav = torch.mean(wav, dim=0)
+     wav = wav.unsqueeze(0)
+     model.eval()
+     with torch.inference_mode():
+         pred_probs = model(wav)
+     pred_labels_and_probs = {LABELS[i]: float(pred_probs[0][i]) for i in range(len(LABELS))}
+     pred_time = round(timer() - start_time, 5)
+     return pred_labels_and_probs, pred_time
+
+
+ title = "Music Tagging"
+
+ demo = gr.Interface(fn=predict,
+                     inputs=gr.Audio(type="filepath"),
+                     outputs=[gr.Label(num_top_classes=10, label="Predictions"),
+                              gr.Number(label="Prediction time (s)")],
+                     examples=example_list,
+                     title=title)
+
+ demo.launch(debug=False)
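+ # Quick local check (assuming the Git LFS files, e.g. models/best_model.pth and the
+ # samples/*.wav clips, have been pulled): run `python app.py`, open the printed local
+ # URL, and try one of the bundled example clips.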
models/attention_modules.py ADDED
@@ -0,0 +1,263 @@
1
+ # coding: utf-8
2
+ # Code adapted from https://github.com/huggingface/pytorch-pretrained-BERT
3
+
4
+ import math
5
+ import copy
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+
10
+ # Gelu
11
+ def gelu(x):
12
+ """Implementation of the gelu activation function.
13
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
14
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
15
+ Also see https://arxiv.org/abs/1606.08415
16
+ """
17
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18
+
19
+ # LayerNorm
20
+ try:
21
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
22
+ except ImportError:
23
+ #print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
24
+ class BertLayerNorm(nn.Module):
25
+ def __init__(self, hidden_size, eps=1e-12):
26
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
27
+ """
28
+ super(BertLayerNorm, self).__init__()
29
+ self.weight = nn.Parameter(torch.ones(hidden_size))
30
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
31
+ self.variance_epsilon = eps
32
+
33
+ def forward(self, x):
34
+ u = x.mean(-1, keepdim=True)
35
+ s = (x - u).pow(2).mean(-1, keepdim=True)
36
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
37
+ return self.weight * x + self.bias
38
+
39
+
40
+ class BertConfig(object):
41
+ def __init__(self,
42
+ vocab_size,
43
+ hidden_size=768,
44
+ num_hidden_layers=12,
45
+ num_attention_heads=12,
46
+ intermediate_size=3072,
47
+ hidden_act="gelu",
48
+ hidden_dropout_prob=0.1,
49
+ max_position_embeddings=512,
50
+ attention_probs_dropout_prob=0.1,
51
+ type_vocab_size=2):
52
+ self.vocab_size = vocab_size
53
+ self.hidden_size = hidden_size
54
+ self.num_hidden_layers = num_hidden_layers
55
+ self.num_attention_heads = num_attention_heads
56
+ self.hidden_act = hidden_act
57
+ self.intermediate_size = intermediate_size
58
+ self.hidden_dropout_prob = hidden_dropout_prob
59
+ self.max_position_embeddings = max_position_embeddings
60
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
61
+ self.type_vocab_size = type_vocab_size
62
+
63
+
64
+ class BertSelfAttention(nn.Module):
65
+ def __init__(self, config):
66
+ super(BertSelfAttention, self).__init__()
67
+ if config.hidden_size % config.num_attention_heads != 0:
68
+ raise ValueError(
69
+ "The hidden size (%d) is not a multiple of the number of attention "
70
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
71
+ self.num_attention_heads = config.num_attention_heads
72
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
73
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
74
+
75
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
76
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
77
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
78
+
79
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
80
+
81
+ def transpose_for_scores(self, x):
82
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
83
+ x = x.view(*new_x_shape)
84
+ return x.permute(0, 2, 1, 3)
85
+
86
+ def forward(self, hidden_states, attention_mask):
87
+ mixed_query_layer = self.query(hidden_states)
88
+ mixed_key_layer = self.key(hidden_states)
89
+ mixed_value_layer = self.value(hidden_states)
90
+
91
+ query_layer = self.transpose_for_scores(mixed_query_layer)
92
+ key_layer = self.transpose_for_scores(mixed_key_layer)
93
+ value_layer = self.transpose_for_scores(mixed_value_layer)
94
+
95
+ # Take the dot product between "query" and "key" to get the raw attention scores.
96
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
97
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
98
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
99
+ if attention_mask is not None:
100
+ attention_scores = attention_scores + attention_mask
101
+
102
+ # Normalize the attention scores to probabilities.
103
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
104
+
105
+ # This is actually dropping out entire tokens to attend to, which might
106
+ # seem a bit unusual, but is taken from the original Transformer paper.
107
+ attention_probs = self.dropout(attention_probs)
108
+
109
+ context_layer = torch.matmul(attention_probs, value_layer)
110
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
111
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
112
+ context_layer = context_layer.view(*new_context_layer_shape)
113
+ return context_layer
114
+
115
+
116
+ class BertSelfOutput(nn.Module):
117
+ def __init__(self, config):
118
+ super(BertSelfOutput, self).__init__()
119
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
120
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
121
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
122
+
123
+ def forward(self, hidden_states, input_tensor):
124
+ hidden_states = self.dense(hidden_states)
125
+ hidden_states = self.dropout(hidden_states)
126
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
127
+ return hidden_states
128
+
129
+
130
+ class BertAttention(nn.Module):
131
+ def __init__(self, config):
132
+ super(BertAttention, self).__init__()
133
+ self.self = BertSelfAttention(config)
134
+ self.output = BertSelfOutput(config)
135
+
136
+ def forward(self, input_tensor, attention_mask):
137
+ self_output = self.self(input_tensor, attention_mask)
138
+ attention_output = self.output(self_output, input_tensor)
139
+ return attention_output
140
+
141
+
142
+ class BertIntermediate(nn.Module):
143
+ def __init__(self, config):
144
+ super(BertIntermediate, self).__init__()
145
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
146
+ self.intermediate_act_fn = gelu
147
+
148
+ def forward(self, hidden_states):
149
+ hidden_states = self.dense(hidden_states)
150
+ hidden_states = self.intermediate_act_fn(hidden_states)
151
+ return hidden_states
152
+
153
+
154
+ class BertOutput(nn.Module):
155
+ def __init__(self, config):
156
+ super(BertOutput, self).__init__()
157
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
158
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
159
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
160
+
161
+ def forward(self, hidden_states, input_tensor):
162
+ hidden_states = self.dense(hidden_states)
163
+ hidden_states = self.dropout(hidden_states)
164
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
165
+ return hidden_states
166
+
167
+
168
+ class BertLayer(nn.Module):
169
+ def __init__(self, config):
170
+ super(BertLayer, self).__init__()
171
+ self.attention = BertAttention(config)
172
+ self.intermediate = BertIntermediate(config)
173
+ self.output = BertOutput(config)
174
+
175
+ def forward(self, hidden_states, attention_mask):
176
+ attention_output = self.attention(hidden_states, attention_mask)
177
+ intermediate_output = self.intermediate(attention_output)
178
+ layer_output = self.output(intermediate_output, attention_output)
179
+ return layer_output
180
+
181
+
182
+ class BertEncoder(nn.Module):
183
+ def __init__(self, config):
184
+ super(BertEncoder, self).__init__()
185
+ layer = BertLayer(config)
186
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
187
+
188
+ def forward(self, hidden_states, attention_mask=None, output_all_encoded_layers=True):
189
+ all_encoder_layers = []
190
+ for layer_module in self.layer:
191
+ hidden_states = layer_module(hidden_states, attention_mask)
192
+ if output_all_encoded_layers:
193
+ all_encoder_layers.append(hidden_states)
194
+ if not output_all_encoded_layers:
195
+ all_encoder_layers.append(hidden_states)
196
+ return all_encoder_layers
197
+
198
+
199
+ class BertEmbeddings(nn.Module):
200
+ """Construct the embeddings from word, position and token_type embeddings.
201
+ """
202
+ def __init__(self, config):
203
+ super(BertEmbeddings, self).__init__()
204
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
205
+
206
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
207
+ # any TensorFlow checkpoint file
208
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
209
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
210
+
211
+ def forward(self, input_ids, token_type_ids=None):
212
+ seq_length = input_ids.size(1)
213
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
214
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids[:, :, 0])
215
+
216
+ position_embeddings = self.position_embeddings(position_ids)
217
+
218
+ embeddings = input_ids + position_embeddings
219
+ #embeddings = input_ids
220
+ embeddings = self.LayerNorm(embeddings)
221
+ embeddings = self.dropout(embeddings)
222
+ return embeddings
223
+
224
+
225
+ class PositionalEncoding(nn.Module):
226
+ def __init__(self, config):
227
+ super(PositionalEncoding, self).__init__()
228
+ emb_dim = config.hidden_size
229
+ max_len = config.max_position_embeddings
230
+ self.position_enc = self.position_encoding_init(max_len, emb_dim)
231
+
232
+ @staticmethod
233
+ def position_encoding_init(n_position, emb_dim):
234
+ ''' Init the sinusoid position encoding table '''
235
+
236
+ # keep dim 0 for padding token position encoding zero vector
237
+ position_enc = np.array([
238
+ [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
239
+ if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)])
240
+
241
+ position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim
242
+ position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim
243
+ return torch.from_numpy(position_enc).type(torch.FloatTensor)
244
+
245
+ def forward(self, word_seq):
246
+ position_encoding = self.position_enc.unsqueeze(0).expand_as(word_seq)
247
+ position_encoding = position_encoding.to(word_seq.device)
248
+ word_pos_encoded = word_seq + position_encoding
249
+ return word_pos_encoded
250
+
251
+ class BertPooler(nn.Module):
252
+ def __init__(self, config):
253
+ super(BertPooler, self).__init__()
254
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
255
+ self.activation = nn.Tanh()
256
+
257
+ def forward(self, hidden_states):
258
+ # We "pool" the model by simply taking the hidden state corresponding
259
+ # to the first token.
260
+ first_token_tensor = hidden_states[:, 0]
261
+ pooled_output = self.dense(first_token_tensor)
262
+ pooled_output = self.activation(pooled_output)
263
+ return pooled_output
models/best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0920da2535e92791f5123a59216a3daa0b7c7e9a21873827551a597ba11648a7
+ size 14563900
models/model.py ADDED
@@ -0,0 +1,622 @@
1
+ # coding: utf-8
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.autograd import Variable
7
+ import torchaudio
8
+
9
+ from models.modules import Conv_1d, ResSE_1d, Conv_2d, Res_2d, Conv_V, Conv_H, HarmonicSTFT, Res_2d_mp
10
+ from models.attention_modules import BertConfig, BertEncoder, BertEmbeddings, BertPooler, PositionalEncoding
11
+
12
+
13
+ class FCN(nn.Module):
14
+ '''
15
+ Choi et al. 2016
16
+ Automatic tagging using deep convolutional neural networks.
17
+ Fully convolutional network.
18
+ '''
19
+ def __init__(self,
20
+ sample_rate=16000,
21
+ n_fft=512,
22
+ f_min=0.0,
23
+ f_max=8000.0,
24
+ n_mels=96,
25
+ n_class=50):
26
+ super(FCN, self).__init__()
27
+
28
+ # Spectrogram
29
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
30
+ n_fft=n_fft,
31
+ f_min=f_min,
32
+ f_max=f_max,
33
+ n_mels=n_mels)
34
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
35
+ self.spec_bn = nn.BatchNorm2d(1)
36
+
37
+ # FCN
38
+ self.layer1 = Conv_2d(1, 64, pooling=(2,4))
39
+ self.layer2 = Conv_2d(64, 128, pooling=(2,4))
40
+ self.layer3 = Conv_2d(128, 128, pooling=(2,4))
41
+ self.layer4 = Conv_2d(128, 128, pooling=(3,5))
42
+ self.layer5 = Conv_2d(128, 64, pooling=(4,4))
43
+
44
+ # Dense
45
+ self.dense = nn.Linear(64, n_class)
46
+ self.dropout = nn.Dropout(0.5)
47
+
48
+ def forward(self, x):
49
+ # Spectrogram
50
+ x = self.spec(x)
51
+ x = self.to_db(x)
52
+ x = x.unsqueeze(1)
53
+ x = self.spec_bn(x)
54
+
55
+ # FCN
56
+ x = self.layer1(x)
57
+ x = self.layer2(x)
58
+ x = self.layer3(x)
59
+ x = self.layer4(x)
60
+ x = self.layer5(x)
61
+
62
+ # Dense
63
+ x = x.view(x.size(0), -1)
64
+ x = self.dropout(x)
65
+ x = self.dense(x)
66
+ x = nn.Sigmoid()(x)
67
+
68
+ return x
69
+
70
+
71
+ class Musicnn(nn.Module):
72
+ '''
73
+ Pons et al. 2017
74
+ End-to-end learning for music audio tagging at scale.
75
+ This is the updated implementation of the original paper; it follows the Musicnn code.
76
+ https://github.com/jordipons/musicnn
77
+ '''
78
+ def __init__(self,
79
+ sample_rate=16000,
80
+ n_fft=512,
81
+ f_min=0.0,
82
+ f_max=8000.0,
83
+ n_mels=96,
84
+ n_class=50,
85
+ dataset='mtat'):
86
+ super(Musicnn, self).__init__()
87
+
88
+ # Spectrogram
89
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
90
+ n_fft=n_fft,
91
+ f_min=f_min,
92
+ f_max=f_max,
93
+ n_mels=n_mels)
94
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
95
+ self.spec_bn = nn.BatchNorm2d(1)
96
+
97
+ # Pons front-end
98
+ m1 = Conv_V(1, 204, (int(0.7*96), 7))
99
+ m2 = Conv_V(1, 204, (int(0.4*96), 7))
100
+ m3 = Conv_H(1, 51, 129)
101
+ m4 = Conv_H(1, 51, 65)
102
+ m5 = Conv_H(1, 51, 33)
103
+ self.layers = nn.ModuleList([m1, m2, m3, m4, m5])
104
+
105
+ # Pons back-end
106
+ backend_channel= 512 if dataset=='msd' else 64
107
+ self.layer1 = Conv_1d(561, backend_channel, 7, 1, 1)
108
+ self.layer2 = Conv_1d(backend_channel, backend_channel, 7, 1, 1)
109
+ self.layer3 = Conv_1d(backend_channel, backend_channel, 7, 1, 1)
110
+
111
+ # Dense
112
+ dense_channel = 500 if dataset=='msd' else 200
113
+ self.dense1 = nn.Linear((561+(backend_channel*3))*2, dense_channel)
114
+ self.bn = nn.BatchNorm1d(dense_channel)
115
+ self.relu = nn.ReLU()
116
+ self.dropout = nn.Dropout(0.5)
117
+ self.dense2 = nn.Linear(dense_channel, n_class)
118
+
119
+ def forward(self, x):
120
+ # Spectrogram
121
+ x = self.spec(x)
122
+ x = self.to_db(x)
123
+ x = x.unsqueeze(1)
124
+ x = self.spec_bn(x)
125
+
126
+ # Pons front-end
127
+ out = []
128
+ for layer in self.layers:
129
+ out.append(layer(x))
130
+ out = torch.cat(out, dim=1)
131
+
132
+ # Pons back-end
133
+ length = out.size(2)
134
+ res1 = self.layer1(out)
135
+ res2 = self.layer2(res1) + res1
136
+ res3 = self.layer3(res2) + res2
137
+ out = torch.cat([out, res1, res2, res3], 1)
138
+
139
+ mp = nn.MaxPool1d(length)(out)
140
+ avgp = nn.AvgPool1d(length)(out)
141
+
142
+ out = torch.cat([mp, avgp], dim=1)
143
+ out = out.squeeze(2)
144
+
145
+ out = self.relu(self.bn(self.dense1(out)))
146
+ out = self.dropout(out)
147
+ out = self.dense2(out)
148
+ out = nn.Sigmoid()(out)
149
+
150
+ return out
151
+
152
+
153
+ class CRNN(nn.Module):
154
+ '''
155
+ Choi et al. 2017
156
+ Convolutional recurrent neural networks for music classification.
157
+ Feature extraction with CNN + temporal summary with RNN
158
+ '''
159
+ def __init__(self,
160
+ sample_rate=16000,
161
+ n_fft=512,
162
+ f_min=0.0,
163
+ f_max=8000.0,
164
+ n_mels=96,
165
+ n_class=50):
166
+ super(CRNN, self).__init__()
167
+
168
+ # Spectrogram
169
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
170
+ n_fft=n_fft,
171
+ f_min=f_min,
172
+ f_max=f_max,
173
+ n_mels=n_mels)
174
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
175
+ self.spec_bn = nn.BatchNorm2d(1)
176
+
177
+ # CNN
178
+ self.layer1 = Conv_2d(1, 64, pooling=(2,2))
179
+ self.layer2 = Conv_2d(64, 128, pooling=(3,3))
180
+ self.layer3 = Conv_2d(128, 128, pooling=(4,4))
181
+ self.layer4 = Conv_2d(128, 128, pooling=(4,4))
182
+
183
+ # RNN
184
+ self.layer5 = nn.GRU(128, 32, 2, batch_first=True)
185
+
186
+ # Dense
187
+ self.dropout = nn.Dropout(0.5)
188
+ self.dense = nn.Linear(32, 50)
189
+
190
+ def forward(self, x):
191
+ # Spectrogram
192
+ x = self.spec(x)
193
+ x = self.to_db(x)
194
+ x = x.unsqueeze(1)
195
+ x = self.spec_bn(x)
196
+
197
+ # CNN
198
+ x = self.layer1(x)
199
+ x = self.layer2(x)
200
+ x = self.layer3(x)
201
+ x = self.layer4(x)
202
+
203
+ # RNN
204
+ x = x.squeeze(2)
205
+ x = x.permute(0, 2, 1)
206
+ x, _ = self.layer5(x)
207
+ x = x[:, -1, :]
208
+
209
+ # Dense
210
+ x = self.dropout(x)
211
+ x = self.dense(x)
212
+ x = nn.Sigmoid()(x)
213
+
214
+ return x
215
+
216
+
217
+ class SampleCNN(nn.Module):
218
+ '''
219
+ Lee et al. 2017
220
+ Sample-level deep convolutional neural networks for music auto-tagging using raw waveforms.
221
+ Sample-level CNN.
222
+ '''
223
+ def __init__(self,
224
+ n_class=50):
225
+ super(SampleCNN, self).__init__()
226
+ self.layer1 = Conv_1d(1, 128, shape=3, stride=3, pooling=1)
227
+ self.layer2 = Conv_1d(128, 128, shape=3, stride=1, pooling=3)
228
+ self.layer3 = Conv_1d(128, 128, shape=3, stride=1, pooling=3)
229
+ self.layer4 = Conv_1d(128, 256, shape=3, stride=1, pooling=3)
230
+ self.layer5 = Conv_1d(256, 256, shape=3, stride=1, pooling=3)
231
+ self.layer6 = Conv_1d(256, 256, shape=3, stride=1, pooling=3)
232
+ self.layer7 = Conv_1d(256, 256, shape=3, stride=1, pooling=3)
233
+ self.layer8 = Conv_1d(256, 256, shape=3, stride=1, pooling=3)
234
+ self.layer9 = Conv_1d(256, 256, shape=3, stride=1, pooling=3)
235
+ self.layer10 = Conv_1d(256, 512, shape=3, stride=1, pooling=3)
236
+ self.layer11 = Conv_1d(512, 512, shape=1, stride=1, pooling=1)
237
+ self.dropout = nn.Dropout(0.5)
238
+ self.dense = nn.Linear(512, n_class)
239
+
240
+ def forward(self, x):
241
+ x = x.unsqueeze(1)
242
+ x = self.layer1(x)
243
+ x = self.layer2(x)
244
+ x = self.layer3(x)
245
+ x = self.layer4(x)
246
+ x = self.layer5(x)
247
+ x = self.layer6(x)
248
+ x = self.layer7(x)
249
+ x = self.layer8(x)
250
+ x = self.layer9(x)
251
+ x = self.layer10(x)
252
+ x = self.layer11(x)
253
+ x = x.squeeze(-1)
254
+ x = self.dropout(x)
255
+ x = self.dense(x)
256
+ x = nn.Sigmoid()(x)
257
+ return x
258
+
259
+
260
+ class SampleCNNSE(nn.Module):
261
+ '''
262
+ Kim et al. 2018
263
+ Sample-level CNN architectures for music auto-tagging using raw waveforms.
264
+ Sample-level CNN + residual connections + squeeze & excitation.
265
+ '''
266
+ def __init__(self,
267
+ n_class=50):
268
+ super(SampleCNNSE, self).__init__()
269
+ self.layer1 = ResSE_1d(1, 128, shape=3, stride=3, pooling=1)
270
+ self.layer2 = ResSE_1d(128, 128, shape=3, stride=1, pooling=3)
271
+ self.layer3 = ResSE_1d(128, 128, shape=3, stride=1, pooling=3)
272
+ self.layer4 = ResSE_1d(128, 256, shape=3, stride=1, pooling=3)
273
+ self.layer5 = ResSE_1d(256, 256, shape=3, stride=1, pooling=3)
274
+ self.layer6 = ResSE_1d(256, 256, shape=3, stride=1, pooling=3)
275
+ self.layer7 = ResSE_1d(256, 256, shape=3, stride=1, pooling=3)
276
+ self.layer8 = ResSE_1d(256, 256, shape=3, stride=1, pooling=3)
277
+ self.layer9 = ResSE_1d(256, 256, shape=3, stride=1, pooling=3)
278
+ self.layer10 = ResSE_1d(256, 512, shape=3, stride=1, pooling=3)
279
+ self.layer11 = ResSE_1d(512, 512, shape=1, stride=1, pooling=1)
280
+ self.dropout = nn.Dropout(0.5)
281
+ self.dense1 = nn.Linear(512, 512)
282
+ self.bn = nn.BatchNorm1d(512)
283
+ self.dense2 = nn.Linear(512, n_class)
284
+
285
+ def forward(self, x):
286
+ x = x.unsqueeze(1)
287
+ x = self.layer1(x)
288
+ x = self.layer2(x)
289
+ x = self.layer3(x)
290
+ x = self.layer4(x)
291
+ x = self.layer5(x)
292
+ x = self.layer6(x)
293
+ x = self.layer7(x)
294
+ x = self.layer8(x)
295
+ x = self.layer9(x)
296
+ x = self.layer10(x)
297
+ x = self.layer11(x)
298
+ x = x.squeeze(-1)
299
+ x = nn.ReLU()(self.bn(self.dense1(x)))
300
+ x = self.dropout(x)
301
+ x = self.dense2(x)
302
+ x = nn.Sigmoid()(x)
303
+ return x
304
+
305
+
306
+ class ShortChunkCNN(nn.Module):
307
+ '''
308
+ Short-chunk CNN architecture.
309
+ So-called vgg-ish model with a small receptive field.
310
+ Deeper layers, smaller pooling (2x2).
311
+ '''
312
+ def __init__(self,
313
+ n_channels=128,
314
+ sample_rate=16000,
315
+ n_fft=512,
316
+ f_min=0.0,
317
+ f_max=8000.0,
318
+ n_mels=128,
319
+ n_class=50):
320
+ super(ShortChunkCNN, self).__init__()
321
+
322
+ # Spectrogram
323
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
324
+ n_fft=n_fft,
325
+ f_min=f_min,
326
+ f_max=f_max,
327
+ n_mels=n_mels)
328
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
329
+ self.spec_bn = nn.BatchNorm2d(1)
330
+
331
+ # CNN
332
+ self.layer1 = Conv_2d(1, n_channels, pooling=2)
333
+ self.layer2 = Conv_2d(n_channels, n_channels, pooling=2)
334
+ self.layer3 = Conv_2d(n_channels, n_channels*2, pooling=2)
335
+ self.layer4 = Conv_2d(n_channels*2, n_channels*2, pooling=2)
336
+ self.layer5 = Conv_2d(n_channels*2, n_channels*2, pooling=2)
337
+ self.layer6 = Conv_2d(n_channels*2, n_channels*2, pooling=2)
338
+ self.layer7 = Conv_2d(n_channels*2, n_channels*4, pooling=2)
339
+
340
+ # Dense
341
+ self.dense1 = nn.Linear(n_channels*4, n_channels*4)
342
+ self.bn = nn.BatchNorm1d(n_channels*4)
343
+ self.dense2 = nn.Linear(n_channels*4, n_class)
344
+ self.dropout = nn.Dropout(0.5)
345
+ self.relu = nn.ReLU()
346
+
347
+ def forward(self, x):
348
+ # Spectrogram
349
+ x = self.spec(x)
350
+ x = self.to_db(x)
351
+ x = x.unsqueeze(1)
352
+ x = self.spec_bn(x)
353
+
354
+ # CNN
355
+ x = self.layer1(x)
356
+ x = self.layer2(x)
357
+ x = self.layer3(x)
358
+ x = self.layer4(x)
359
+ x = self.layer5(x)
360
+ x = self.layer6(x)
361
+ x = self.layer7(x)
362
+ x = x.squeeze(2)
363
+
364
+ # Global Max Pooling
365
+ if x.size(-1) != 1:
366
+ x = nn.MaxPool1d(x.size(-1))(x)
367
+ x = x.squeeze(2)
368
+
369
+ # Dense
370
+ x = self.dense1(x)
371
+ x = self.bn(x)
372
+ x = self.relu(x)
373
+ x = self.dropout(x)
374
+ x = self.dense2(x)
375
+ x = nn.Sigmoid()(x)
376
+
377
+ return x
378
+
379
+
380
+ class ShortChunkCNN_Res(nn.Module):
381
+ '''
382
+ Short-chunk CNN architecture with residual connections.
383
+ '''
384
+ def __init__(self,
385
+ n_channels=128,
386
+ sample_rate=16000,
387
+ n_fft=512,
388
+ f_min=0.0,
389
+ f_max=8000.0,
390
+ n_mels=128,
391
+ n_class=50):
392
+ super(ShortChunkCNN_Res, self).__init__()
393
+
394
+ # Spectrogram
395
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
396
+ n_fft=n_fft,
397
+ f_min=f_min,
398
+ f_max=f_max,
399
+ n_mels=n_mels)
400
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
401
+ self.spec_bn = nn.BatchNorm2d(1)
402
+
403
+ # CNN
404
+ self.layer1 = Res_2d(1, n_channels, stride=2)
405
+ self.layer2 = Res_2d(n_channels, n_channels, stride=2)
406
+ self.layer3 = Res_2d(n_channels, n_channels*2, stride=2)
407
+ self.layer4 = Res_2d(n_channels*2, n_channels*2, stride=2)
408
+ self.layer5 = Res_2d(n_channels*2, n_channels*2, stride=2)
409
+ self.layer6 = Res_2d(n_channels*2, n_channels*2, stride=2)
410
+ self.layer7 = Res_2d(n_channels*2, n_channels*4, stride=2)
411
+
412
+ # Dense
413
+ self.dense1 = nn.Linear(n_channels*4, n_channels*4)
414
+ self.bn = nn.BatchNorm1d(n_channels*4)
415
+ self.dense2 = nn.Linear(n_channels*4, n_class)
416
+ self.dropout = nn.Dropout(0.5)
417
+ self.relu = nn.ReLU()
418
+
419
+ def forward(self, x):
420
+ # Spectrogram
421
+ x = self.spec(x)
422
+ x = self.to_db(x)
423
+ x = x.unsqueeze(1)
424
+ x = self.spec_bn(x)
425
+
426
+ # CNN
427
+ x = self.layer1(x)
428
+ x = self.layer2(x)
429
+ x = self.layer3(x)
430
+ x = self.layer4(x)
431
+ x = self.layer5(x)
432
+ x = self.layer6(x)
433
+ x = self.layer7(x)
434
+ x = x.squeeze(2)
435
+
436
+ # Global Max Pooling
437
+ if x.size(-1) != 1:
438
+ x = nn.MaxPool1d(x.size(-1))(x)
439
+ x = x.squeeze(2)
440
+
441
+ # Dense
442
+ x = self.dense1(x)
443
+ x = self.bn(x)
444
+ x = self.relu(x)
445
+ x = self.dropout(x)
446
+ x = self.dense2(x)
447
+ x = nn.Sigmoid()(x)
448
+
449
+ return x
450
+
451
+
452
+ class CNNSA(nn.Module):
453
+ '''
454
+ Won et al. 2019
455
+ Toward interpretable music tagging with self-attention.
456
+ Feature extraction with CNN + temporal summary with Transformer encoder.
457
+ '''
458
+ def __init__(self,
459
+ n_channels=128,
460
+ sample_rate=16000,
461
+ n_fft=512,
462
+ f_min=0.0,
463
+ f_max=8000.0,
464
+ n_mels=128,
465
+ n_class=50):
466
+ super(CNNSA, self).__init__()
467
+
468
+ # Spectrogram
469
+ self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
470
+ n_fft=n_fft,
471
+ f_min=f_min,
472
+ f_max=f_max,
473
+ n_mels=n_mels)
474
+ self.to_db = torchaudio.transforms.AmplitudeToDB()
475
+ self.spec_bn = nn.BatchNorm2d(1)
476
+
477
+ # CNN
478
+ self.layer1 = Res_2d(1, n_channels, stride=2)
479
+ self.layer2 = Res_2d(n_channels, n_channels, stride=2)
480
+ self.layer3 = Res_2d(n_channels, n_channels*2, stride=2)
481
+ self.layer4 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
482
+ self.layer5 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
483
+ self.layer6 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
484
+ self.layer7 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
485
+
486
+ # Transformer encoder
487
+ bert_config = BertConfig(vocab_size=256,
488
+ hidden_size=256,
489
+ num_hidden_layers=2,
490
+ num_attention_heads=8,
491
+ intermediate_size=1024,
492
+ hidden_act="gelu",
493
+ hidden_dropout_prob=0.4,
494
+ max_position_embeddings=700,
495
+ attention_probs_dropout_prob=0.5)
496
+ self.encoder = BertEncoder(bert_config)
497
+ self.pooler = BertPooler(bert_config)
498
+ self.vec_cls = self.get_cls(256)
499
+
500
+ # Dense
501
+ self.dropout = nn.Dropout(0.5)
502
+ self.dense = nn.Linear(256, n_class)
503
+
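+ # A fixed, seeded random [CLS] embedding is prepended to every sequence below;
+ # BertPooler later reads this first position as the clip-level summary.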
504
+ def get_cls(self, channel):
505
+ np.random.seed(0)
506
+ single_cls = torch.Tensor(np.random.random((1, channel)))
507
+ vec_cls = torch.cat([single_cls for _ in range(64)], dim=0)
508
+ vec_cls = vec_cls.unsqueeze(1)
509
+ return vec_cls
510
+
511
+ def append_cls(self, x):
512
+ batch, _, _ = x.size()
513
+ part_vec_cls = self.vec_cls[:batch].clone()
514
+ part_vec_cls = part_vec_cls.to(x.device)
515
+ return torch.cat([part_vec_cls, x], dim=1)
516
+
517
+ def forward(self, x):
518
+ # Spectrogram
519
+ x = self.spec(x)
520
+ x = self.to_db(x)
521
+ x = x.unsqueeze(1)
522
+ x = self.spec_bn(x)
523
+
524
+ # CNN
525
+ x = self.layer1(x)
526
+ x = self.layer2(x)
527
+ x = self.layer3(x)
528
+ x = self.layer4(x)
529
+ x = self.layer5(x)
530
+ x = self.layer6(x)
531
+ x = self.layer7(x)
532
+ x = x.squeeze(2)
533
+
534
+ # Get [CLS] token
535
+ x = x.permute(0, 2, 1)
536
+ x = self.append_cls(x)
537
+
538
+ # Transformer encoder
539
+ x = self.encoder(x)
540
+ x = x[-1]
541
+ x = self.pooler(x)
542
+
543
+ # Dense
544
+ x = self.dropout(x)
545
+ x = self.dense(x)
546
+ x = nn.Sigmoid()(x)
547
+
548
+ return x
549
+
550
+
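+ # HarmonicCNN is the model served by app.py; its default n_class=50 matches the
+ # 50 tags in LABELS there, and it takes raw 16 kHz mono waveforms as input.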
551
+ class HarmonicCNN(nn.Module):
552
+ '''
553
+ Won et al. 2020
554
+ Data-driven harmonic filters for audio representation learning.
555
+ Trainable harmonic band-pass filters, short-chunk CNN.
556
+ '''
557
+ def __init__(self,
558
+ n_channels=128,
559
+ sample_rate=16000,
560
+ n_fft=512,
561
+ f_min=0.0,
562
+ f_max=8000.0,
563
+ n_mels=128,
564
+ n_class=50,
565
+ n_harmonic=6,
566
+ semitone_scale=2,
567
+ learn_bw='only_Q'):
568
+ super(HarmonicCNN, self).__init__()
569
+
570
+ # Harmonic STFT
571
+ self.hstft = HarmonicSTFT(sample_rate=sample_rate,
572
+ n_fft=n_fft,
573
+ n_harmonic=n_harmonic,
574
+ semitone_scale=semitone_scale,
575
+ learn_bw=learn_bw)
576
+ self.hstft_bn = nn.BatchNorm2d(n_harmonic)
577
+
578
+ # CNN
579
+ self.layer1 = Conv_2d(n_harmonic, n_channels, pooling=2)
580
+ self.layer2 = Res_2d_mp(n_channels, n_channels, pooling=2)
581
+ self.layer3 = Res_2d_mp(n_channels, n_channels, pooling=2)
582
+ self.layer4 = Res_2d_mp(n_channels, n_channels, pooling=2)
583
+ self.layer5 = Conv_2d(n_channels, n_channels*2, pooling=2)
584
+ self.layer6 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2,3))
585
+ self.layer7 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2,3))
586
+
587
+ # Dense
588
+ self.dense1 = nn.Linear(n_channels*2, n_channels*2)
589
+ self.bn = nn.BatchNorm1d(n_channels*2)
590
+ self.dense2 = nn.Linear(n_channels*2, n_class)
591
+ self.dropout = nn.Dropout(0.5)
592
+ self.relu = nn.ReLU()
593
+
594
+ def forward(self, x):
595
+ # Spectrogram
596
+ x = self.hstft_bn(self.hstft(x))
597
+
598
+ # CNN
599
+ x = self.layer1(x)
600
+ x = self.layer2(x)
601
+ x = self.layer3(x)
602
+ x = self.layer4(x)
603
+ x = self.layer5(x)
604
+ x = self.layer6(x)
605
+ x = self.layer7(x)
606
+ x = x.squeeze(2)
607
+
608
+ # Global Max Pooling
609
+ if x.size(-1) != 1:
610
+ x = nn.MaxPool1d(x.size(-1))(x)
611
+ x = x.squeeze(2)
612
+
613
+ # Dense
614
+ x = self.dense1(x)
615
+ x = self.bn(x)
616
+ x = self.relu(x)
617
+ x = self.dropout(x)
618
+ x = self.dense2(x)
619
+ x = nn.Sigmoid()(x)
620
+
621
+ return x
622
+
models/modules.py ADDED
@@ -0,0 +1,271 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.nn as nn
5
+ import torchaudio
6
+ import sys
7
+ from torch.autograd import Variable
8
+ import math
9
+ import librosa
10
+
11
+
12
+ class Conv_1d(nn.Module):
13
+ def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
14
+ super(Conv_1d, self).__init__()
15
+ self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
16
+ self.bn = nn.BatchNorm1d(output_channels)
17
+ self.relu = nn.ReLU()
18
+ self.mp = nn.MaxPool1d(pooling)
19
+ def forward(self, x):
20
+ out = self.mp(self.relu(self.bn(self.conv(x))))
21
+ return out
22
+
23
+
24
+ class Conv_2d(nn.Module):
25
+ def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
26
+ super(Conv_2d, self).__init__()
27
+ self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
28
+ self.bn = nn.BatchNorm2d(output_channels)
29
+ self.relu = nn.ReLU()
30
+ self.mp = nn.MaxPool2d(pooling)
31
+ def forward(self, x):
32
+ out = self.mp(self.relu(self.bn(self.conv(x))))
33
+ return out
34
+
35
+
36
+ class Res_2d(nn.Module):
37
+ def __init__(self, input_channels, output_channels, shape=3, stride=2):
38
+ super(Res_2d, self).__init__()
39
+ # convolution
40
+ self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
41
+ self.bn_1 = nn.BatchNorm2d(output_channels)
42
+ self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
43
+ self.bn_2 = nn.BatchNorm2d(output_channels)
44
+
45
+ # residual
46
+ self.diff = False
47
+ if (stride != 1) or (input_channels != output_channels):
48
+ self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
49
+ self.bn_3 = nn.BatchNorm2d(output_channels)
50
+ self.diff = True
51
+ self.relu = nn.ReLU()
52
+
53
+ def forward(self, x):
54
+ # convolution
55
+ out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
56
+
57
+ # residual
58
+ if self.diff:
59
+ x = self.bn_3(self.conv_3(x))
60
+ out = x + out
61
+ out = self.relu(out)
62
+ return out
63
+
64
+
65
+ class Res_2d_mp(nn.Module):
66
+ def __init__(self, input_channels, output_channels, pooling=2):
67
+ super(Res_2d_mp, self).__init__()
68
+ self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1)
69
+ self.bn_1 = nn.BatchNorm2d(output_channels)
70
+ self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1)
71
+ self.bn_2 = nn.BatchNorm2d(output_channels)
72
+ self.relu = nn.ReLU()
73
+ self.mp = nn.MaxPool2d(pooling)
74
+ def forward(self, x):
75
+ out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
76
+ out = x + out
77
+ out = self.mp(self.relu(out))
78
+ return out
79
+
80
+
81
+ class ResSE_1d(nn.Module):
82
+ def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=3):
83
+ super(ResSE_1d, self).__init__()
84
+ # convolution
85
+ self.conv_1 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
86
+ self.bn_1 = nn.BatchNorm1d(output_channels)
87
+ self.conv_2 = nn.Conv1d(output_channels, output_channels, shape, padding=shape//2)
88
+ self.bn_2 = nn.BatchNorm1d(output_channels)
89
+
90
+ # squeeze & excitation
91
+ self.dense1 = nn.Linear(output_channels, output_channels)
92
+ self.dense2 = nn.Linear(output_channels, output_channels)
93
+
94
+ # residual
95
+ self.diff = False
96
+ if (stride != 1) or (input_channels != output_channels):
97
+ self.conv_3 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
98
+ self.bn_3 = nn.BatchNorm1d(output_channels)
99
+ self.diff = True
100
+ self.relu = nn.ReLU()
101
+ self.sigmoid = nn.Sigmoid()
102
+ self.mp = nn.MaxPool1d(pooling)
103
+
104
+ def forward(self, x):
105
+ # convolution
106
+ out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
107
+
108
+ # squeeze & excitation
109
+ se_out = nn.AvgPool1d(out.size(-1))(out)
110
+ se_out = se_out.squeeze(-1)
111
+ se_out = self.relu(self.dense1(se_out))
112
+ se_out = self.sigmoid(self.dense2(se_out))
113
+ se_out = se_out.unsqueeze(-1)
114
+ out = torch.mul(out, se_out)
115
+
116
+ # residual
117
+ if self.diff:
118
+ x = self.bn_3(self.conv_3(x))
119
+ out = x + out
120
+ out = self.mp(self.relu(out))
121
+ return out
122
+
123
+
124
+ class Conv_V(nn.Module):
125
+ # vertical convolution
126
+ def __init__(self, input_channels, output_channels, filter_shape):
127
+ super(Conv_V, self).__init__()
128
+ self.conv = nn.Conv2d(input_channels, output_channels, filter_shape,
129
+ padding=(0, filter_shape[1]//2))
130
+ self.bn = nn.BatchNorm2d(output_channels)
131
+ self.relu = nn.ReLU()
132
+
133
+ def forward(self, x):
134
+ x = self.relu(self.bn(self.conv(x)))
135
+ freq = x.size(2)
136
+ out = nn.MaxPool2d((freq, 1), stride=(freq, 1))(x)
137
+ out = out.squeeze(2)
138
+ return out
139
+
140
+
141
+ class Conv_H(nn.Module):
142
+ # horizontal convolution
143
+ def __init__(self, input_channels, output_channels, filter_length):
144
+ super(Conv_H, self).__init__()
145
+ self.conv = nn.Conv1d(input_channels, output_channels, filter_length,
146
+ padding=filter_length//2)
147
+ self.bn = nn.BatchNorm1d(output_channels)
148
+ self.relu = nn.ReLU()
149
+
150
+ def forward(self, x):
151
+ freq = x.size(2)
152
+ out = nn.AvgPool2d((freq, 1), stride=(freq, 1))(x)
153
+ out = out.squeeze(2)
154
+ out = self.relu(self.bn(self.conv(out)))
155
+ return out
156
+
157
+
158
+ # Modules for harmonic filters
159
+ def hz_to_midi(hz):
160
+ return 12 * (torch.log2(hz) - np.log2(440.0)) + 69
161
+
162
+ def midi_to_hz(midi):
163
+ return 440.0 * (2.0 ** ((midi - 69.0)/12.0))
164
+
165
+ def note_to_midi(note):
166
+ return librosa.core.note_to_midi(note)
167
+
168
+ def hz_to_note(hz):
169
+ return librosa.core.hz_to_note(hz)
170
+
171
+ def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
172
+ # MIDI
173
+ # lowest note
174
+ low_midi = note_to_midi('C1')
175
+
176
+ # highest note
177
+ high_note = hz_to_note(sample_rate / (2 * n_harmonic))
178
+ high_midi = note_to_midi(high_note)
179
+
180
+ # number of scales
181
+ level = (high_midi - low_midi) * semitone_scale
182
+ midi = np.linspace(low_midi, high_midi, level + 1)
183
+ hz = midi_to_hz(midi[:-1])
184
+
185
+ # stack harmonics
186
+ harmonic_hz = []
187
+ for i in range(n_harmonic):
188
+ harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1)))
189
+
190
+ return harmonic_hz, level
191
+
192
+
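+ # The filterbank above spans C1 up to the note nearest sample_rate / (2 * n_harmonic),
+ # with semitone_scale bins per semitone, and stacks those center frequencies scaled by
+ # each harmonic number (1..n_harmonic); HarmonicSTFT below turns them into triangular
+ # band-pass filters applied to the power spectrogram.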
193
+ class HarmonicSTFT(nn.Module):
194
+ def __init__(self,
195
+ sample_rate=16000,
196
+ n_fft=513,
197
+ win_length=None,
198
+ hop_length=None,
199
+ pad=0,
200
+ power=2,
201
+ normalized=False,
202
+ n_harmonic=6,
203
+ semitone_scale=2,
204
+ bw_Q=1.0,
205
+ learn_bw=None):
206
+ super(HarmonicSTFT, self).__init__()
207
+
208
+ # Parameters
209
+ self.sample_rate = sample_rate
210
+ self.n_harmonic = n_harmonic
211
+ self.bw_alpha = 0.1079
212
+ self.bw_beta = 24.7
213
+
214
+ # Spectrogram
215
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
216
+ hop_length=None, pad=0,
217
+ window_fn=torch.hann_window,
218
+ power=power, normalized=normalized, wkwargs=None)
219
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
220
+
221
+ # Initialize the filterbank. Equally spaced in MIDI scale.
222
+ harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale)
223
+
224
+ # Center frequencies to tensor
225
+ self.f0 = torch.tensor(harmonic_hz.astype('float32'))
226
+
227
+ # Bandwidth parameters
228
+ if learn_bw == 'only_Q':
229
+ self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32')))
230
+ elif learn_bw == 'fix':
231
+ self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32'))
232
+
233
+ def get_harmonic_fb(self):
234
+ # bandwidth
235
+ bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
236
+ bw = bw.unsqueeze(0) # (1, n_band)
237
+ f0 = self.f0.unsqueeze(0) # (1, n_band)
238
+ fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1)
239
+
240
+ up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
241
+ down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
242
+ fb = torch.max(self.zero, torch.min(down_slope, up_slope))
243
+ return fb
244
+
245
+ def to_device(self, device, n_bins):
246
+ self.f0 = self.f0.to(device)
247
+ self.bw_Q = self.bw_Q.to(device)
248
+ # fft bins
249
+ self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins)
250
+ self.fft_bins = self.fft_bins.to(device)
251
+ self.zero = torch.zeros(1)
252
+ self.zero = self.zero.to(device)
253
+
254
+ def forward(self, waveform):
255
+ # stft
256
+ spectrogram = self.spec(waveform)
257
+
258
+ # to device
259
+ self.to_device(waveform.device, spectrogram.size(1))
260
+
261
+ # triangle filter
262
+ harmonic_fb = self.get_harmonic_fb()
263
+ harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)
264
+
265
+ # (batch, channel, length) -> (batch, harmonic, f0, length)
266
+ b, c, l = harmonic_spec.size()
267
+ harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)
268
+
269
+ # amplitude to db
270
+ harmonic_spec = self.amplitude_to_db(harmonic_spec)
271
+ return harmonic_spec
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.12.0
+ torchvision==0.13.0
+ torchaudio==0.12.0
+ gradio==3.1.4
+ librosa==0.9.2
samples/flute.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aaa6c5640106826a4db1d7932f9edc3b0fbb0c68cbd4e7d7d544d2fdc28af17
+ size 3528044
samples/guitar_acoustic.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:450adb05b9b91dcc03b1262407b20c801769ccdca841e0f7860e5e3fe1a0a652
+ size 4301040
samples/guitar_electric.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60f854cc407877512a3e68a286cfd26e95dc2f0a4e76ba313fbb3e21ddf2d2f9
+ size 3492764
samples/piano.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01ba9d83ec1404ccad78a6310baba7d51583e42c20a07b7304e215a7edfe2d5e
+ size 4300764
samples/violin.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:690365b52ee8ca9f7b0147247270e375d70be31512c3ae591e52bf55605d3ece
+ size 19105034