Dionyssos commited on
Commit
f7fd0c3
1 Parent(s): 4e4c64c
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
+ class Model1d(nn.Module):
19
+ def __init__(self, unet_type: str = "base", **kwargs):
20
+ super().__init__()
21
+ diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
22
+ self.unet = None
23
+ self.diffusion = None
24
+
25
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
26
+ return self.diffusion(x, **kwargs)
27
+
28
+ def sample(self, *args, **kwargs) -> Tensor:
29
+ return self.diffusion.sample(*args, **kwargs)
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
+ def get_default_model_kwargs():
38
+ return dict(
39
+ channels=128,
40
+ patch_size=16,
41
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
42
+ factors=[4, 4, 4, 2, 2, 2],
43
+ num_blocks=[2, 2, 2, 2, 2, 2],
44
+ attentions=[0, 0, 0, 1, 1, 1, 1],
45
+ attention_heads=8,
46
+ attention_features=64,
47
+ attention_multiplier=2,
48
+ attention_use_rel_pos=False,
49
+ diffusion_type="v",
50
+ diffusion_sigma_distribution=UniformDistribution(),
51
+ )
52
+
53
+
54
+ def get_default_sampling_kwargs():
55
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
+
57
+
58
+ class AudioDiffusionModel(Model1d):
59
+ def __init__(self, **kwargs):
60
+ super().__init__(**{**get_default_model_kwargs(), **kwargs})
61
+
62
+ def sample(self, *args, **kwargs):
63
+ return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
64
+
65
+
66
+ class AudioDiffusionConditional(Model1d):
67
+ def __init__(
68
+ self,
69
+ embedding_features: int,
70
+ embedding_max_length: int,
71
+ embedding_mask_proba: float = 0.1,
72
+ **kwargs,
73
+ ):
74
+ self.embedding_mask_proba = embedding_mask_proba
75
+ default_kwargs = dict(
76
+ **get_default_model_kwargs(),
77
+ unet_type="cfg",
78
+ context_embedding_features=embedding_features,
79
+ context_embedding_max_length=embedding_max_length,
80
+ )
81
+ super().__init__(**{**default_kwargs, **kwargs})
82
+
83
+ def forward(self, *args, **kwargs):
84
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
85
+ return super().forward(*args, **{**default_kwargs, **kwargs})
86
+
87
+ def sample(self, *args, **kwargs):
88
+ default_kwargs = dict(
89
+ **get_default_sampling_kwargs(),
90
+ embedding_scale=5.0,
91
+ )
92
+ return super().sample(*args, **{**default_kwargs, **kwargs})
93
+
94
+
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self, style_dim, channels, eps=1e-5):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.fc = nn.Linear(style_dim, channels*2)
25
+
26
+ def forward(self, x, s):
27
+ x = x.transpose(-1, -2)
28
+ x = x.transpose(1, -1)
29
+
30
+ h = self.fc(s)
31
+ h = h.view(h.size(0), h.size(1), 1)
32
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
33
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
34
+
35
+
36
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
37
+ x = (1 + gamma) * x + beta
38
+ return x.transpose(1, -1).transpose(-1, -2)
39
+
40
+ class StyleTransformer1d(nn.Module):
41
+ def __init__(
42
+ self,
43
+ num_layers: int,
44
+ channels: int,
45
+ num_heads: int,
46
+ head_features: int,
47
+ multiplier: int,
48
+ use_context_time: bool = True,
49
+ use_rel_pos: bool = False,
50
+ context_features_multiplier: int = 1,
51
+ rel_pos_num_buckets: Optional[int] = None,
52
+ rel_pos_max_distance: Optional[int] = None,
53
+ context_features: Optional[int] = None,
54
+ context_embedding_features: Optional[int] = None,
55
+ embedding_max_length: int = 512,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.blocks = nn.ModuleList(
60
+ [
61
+ StyleTransformerBlock(
62
+ features=channels + context_embedding_features,
63
+ head_features=head_features,
64
+ num_heads=num_heads,
65
+ multiplier=multiplier,
66
+ style_dim=context_features,
67
+ use_rel_pos=use_rel_pos,
68
+ rel_pos_num_buckets=rel_pos_num_buckets,
69
+ rel_pos_max_distance=rel_pos_max_distance,
70
+ )
71
+ for i in range(num_layers)
72
+ ]
73
+ )
74
+
75
+ self.to_out = nn.Sequential(
76
+ Rearrange("b t c -> b c t"),
77
+ nn.Conv1d(
78
+ in_channels=channels + context_embedding_features,
79
+ out_channels=channels,
80
+ kernel_size=1,
81
+ ),
82
+ )
83
+
84
+ use_context_features = exists(context_features)
85
+ self.use_context_features = use_context_features
86
+ self.use_context_time = use_context_time
87
+
88
+ if use_context_time or use_context_features:
89
+ context_mapping_features = channels + context_embedding_features
90
+
91
+ self.to_mapping = nn.Sequential(
92
+ nn.Linear(context_mapping_features, context_mapping_features),
93
+ nn.GELU(),
94
+ nn.Linear(context_mapping_features, context_mapping_features),
95
+ nn.GELU(),
96
+ )
97
+
98
+ if use_context_time:
99
+ assert exists(context_mapping_features)
100
+ self.to_time = nn.Sequential(
101
+ TimePositionalEmbedding(
102
+ dim=channels, out_features=context_mapping_features
103
+ ),
104
+ nn.GELU(),
105
+ )
106
+
107
+ if use_context_features:
108
+ assert exists(context_features) and exists(context_mapping_features)
109
+ self.to_features = nn.Sequential(
110
+ nn.Linear(
111
+ in_features=context_features, out_features=context_mapping_features
112
+ ),
113
+ nn.GELU(),
114
+ )
115
+
116
+ self.fixed_embedding = FixedEmbedding(
117
+ max_length=embedding_max_length, features=context_embedding_features
118
+ )
119
+
120
+
121
+ def get_mapping(
122
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
+ ) -> Optional[Tensor]:
124
+ """Combines context time features and features into mapping"""
125
+ items, mapping = [], None
126
+ # Compute time features
127
+ if self.use_context_time:
128
+ assert_message = "use_context_time=True but no time features provided"
129
+ assert exists(time), assert_message
130
+ items += [self.to_time(time)]
131
+ # Compute features
132
+ if self.use_context_features:
133
+ assert_message = "context_features exists but no features provided"
134
+ assert exists(features), assert_message
135
+ items += [self.to_features(features)]
136
+
137
+ # Compute joint mapping
138
+ if self.use_context_time or self.use_context_features:
139
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
+ mapping = self.to_mapping(mapping)
141
+
142
+ return mapping
143
+
144
+ def run(self, x, time, embedding, features):
145
+
146
+ mapping = self.get_mapping(time, features)
147
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
148
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
149
+
150
+ for block in self.blocks:
151
+ x = x + mapping
152
+ x = block(x, features)
153
+
154
+ x = x.mean(axis=1).unsqueeze(1)
155
+ x = self.to_out(x)
156
+ x = x.transpose(-1, -2)
157
+
158
+ return x
159
+
160
+ def forward(self, x: Tensor,
161
+ time: Tensor,
162
+ embedding_mask_proba: float = 0.0,
163
+ embedding: Optional[Tensor] = None,
164
+ features: Optional[Tensor] = None,
165
+ embedding_scale: float = 1.0) -> Tensor:
166
+
167
+ b, device = embedding.shape[0], embedding.device
168
+ fixed_embedding = self.fixed_embedding(embedding)
169
+ if embedding_mask_proba > 0.0:
170
+ # Randomly mask embedding
171
+ batch_mask = rand_bool(
172
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
173
+ )
174
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
175
+
176
+ if embedding_scale != 1.0:
177
+ # Compute both normal and fixed embedding outputs
178
+ out = self.run(x, time, embedding=embedding, features=features)
179
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
180
+ # Scale conditional output using classifier-free guidance
181
+ return out_masked + (out - out_masked) * embedding_scale
182
+ else:
183
+ return self.run(x, time, embedding=embedding, features=features)
184
+
185
+ return x
186
+
187
+
188
+ class StyleTransformerBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ features: int,
192
+ num_heads: int,
193
+ head_features: int,
194
+ style_dim: int,
195
+ multiplier: int,
196
+ use_rel_pos: bool,
197
+ rel_pos_num_buckets: Optional[int] = None,
198
+ rel_pos_max_distance: Optional[int] = None,
199
+ context_features: Optional[int] = None,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.use_cross_attention = exists(context_features) and context_features > 0
204
+
205
+ self.attention = StyleAttention(
206
+ features=features,
207
+ style_dim=style_dim,
208
+ num_heads=num_heads,
209
+ head_features=head_features,
210
+ use_rel_pos=use_rel_pos,
211
+ rel_pos_num_buckets=rel_pos_num_buckets,
212
+ rel_pos_max_distance=rel_pos_max_distance,
213
+ )
214
+
215
+ if self.use_cross_attention:
216
+ self.cross_attention = StyleAttention(
217
+ features=features,
218
+ style_dim=style_dim,
219
+ num_heads=num_heads,
220
+ head_features=head_features,
221
+ context_features=context_features,
222
+ use_rel_pos=use_rel_pos,
223
+ rel_pos_num_buckets=rel_pos_num_buckets,
224
+ rel_pos_max_distance=rel_pos_max_distance,
225
+ )
226
+
227
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
228
+
229
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
230
+ x = self.attention(x, s) + x
231
+ if self.use_cross_attention:
232
+ x = self.cross_attention(x, s, context=context) + x
233
+ x = self.feed_forward(x) + x
234
+ return x
235
+
236
+ class StyleAttention(nn.Module):
237
+ def __init__(
238
+ self,
239
+ features: int,
240
+ *,
241
+ style_dim: int,
242
+ head_features: int,
243
+ num_heads: int,
244
+ context_features: Optional[int] = None,
245
+ use_rel_pos: bool,
246
+ rel_pos_num_buckets: Optional[int] = None,
247
+ rel_pos_max_distance: Optional[int] = None,
248
+ ):
249
+ super().__init__()
250
+ self.context_features = context_features
251
+ mid_features = head_features * num_heads
252
+ context_features = default(context_features, features)
253
+
254
+ self.norm = AdaLayerNorm(style_dim, features)
255
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
256
+ self.to_q = nn.Linear(
257
+ in_features=features, out_features=mid_features, bias=False
258
+ )
259
+ self.to_kv = nn.Linear(
260
+ in_features=context_features, out_features=mid_features * 2, bias=False
261
+ )
262
+ self.attention = AttentionBase(
263
+ features,
264
+ num_heads=num_heads,
265
+ head_features=head_features,
266
+ use_rel_pos=use_rel_pos,
267
+ rel_pos_num_buckets=rel_pos_num_buckets,
268
+ rel_pos_max_distance=rel_pos_max_distance,
269
+ )
270
+
271
+ def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
272
+ assert_message = "You must provide a context when using context_features"
273
+ assert not self.context_features or exists(context), assert_message
274
+ # Use context if provided
275
+ context = default(context, x)
276
+ # Normalize then compute q from input and k,v from context
277
+ x, context = self.norm(x, s), self.norm_context(context, s)
278
+
279
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
280
+ # Compute and return attention
281
+ return self.attention(q, k, v)
282
+
283
+ class Transformer1d(nn.Module):
284
+ def __init__(
285
+ self,
286
+ num_layers: int,
287
+ channels: int,
288
+ num_heads: int,
289
+ head_features: int,
290
+ multiplier: int,
291
+ use_context_time: bool = True,
292
+ use_rel_pos: bool = False,
293
+ context_features_multiplier: int = 1,
294
+ rel_pos_num_buckets: Optional[int] = None,
295
+ rel_pos_max_distance: Optional[int] = None,
296
+ context_features: Optional[int] = None,
297
+ context_embedding_features: Optional[int] = None,
298
+ embedding_max_length: int = 512,
299
+ ):
300
+ super().__init__()
301
+
302
+ self.blocks = nn.ModuleList(
303
+ [
304
+ TransformerBlock(
305
+ features=channels + context_embedding_features,
306
+ head_features=head_features,
307
+ num_heads=num_heads,
308
+ multiplier=multiplier,
309
+ use_rel_pos=use_rel_pos,
310
+ rel_pos_num_buckets=rel_pos_num_buckets,
311
+ rel_pos_max_distance=rel_pos_max_distance,
312
+ )
313
+ for i in range(num_layers)
314
+ ]
315
+ )
316
+
317
+ self.to_out = nn.Sequential(
318
+ Rearrange("b t c -> b c t"),
319
+ nn.Conv1d(
320
+ in_channels=channels + context_embedding_features,
321
+ out_channels=channels,
322
+ kernel_size=1,
323
+ ),
324
+ )
325
+
326
+ use_context_features = exists(context_features)
327
+ self.use_context_features = use_context_features
328
+ self.use_context_time = use_context_time
329
+
330
+ if use_context_time or use_context_features:
331
+ context_mapping_features = channels + context_embedding_features
332
+
333
+ self.to_mapping = nn.Sequential(
334
+ nn.Linear(context_mapping_features, context_mapping_features),
335
+ nn.GELU(),
336
+ nn.Linear(context_mapping_features, context_mapping_features),
337
+ nn.GELU(),
338
+ )
339
+
340
+ if use_context_time:
341
+ assert exists(context_mapping_features)
342
+ self.to_time = nn.Sequential(
343
+ TimePositionalEmbedding(
344
+ dim=channels, out_features=context_mapping_features
345
+ ),
346
+ nn.GELU(),
347
+ )
348
+
349
+ if use_context_features:
350
+ assert exists(context_features) and exists(context_mapping_features)
351
+ self.to_features = nn.Sequential(
352
+ nn.Linear(
353
+ in_features=context_features, out_features=context_mapping_features
354
+ ),
355
+ nn.GELU(),
356
+ )
357
+
358
+ self.fixed_embedding = FixedEmbedding(
359
+ max_length=embedding_max_length, features=context_embedding_features
360
+ )
361
+
362
+
363
+ def get_mapping(
364
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
365
+ ) -> Optional[Tensor]:
366
+ """Combines context time features and features into mapping"""
367
+ items, mapping = [], None
368
+ # Compute time features
369
+ if self.use_context_time:
370
+ assert_message = "use_context_time=True but no time features provided"
371
+ assert exists(time), assert_message
372
+ items += [self.to_time(time)]
373
+ # Compute features
374
+ if self.use_context_features:
375
+ assert_message = "context_features exists but no features provided"
376
+ assert exists(features), assert_message
377
+ items += [self.to_features(features)]
378
+
379
+ # Compute joint mapping
380
+ if self.use_context_time or self.use_context_features:
381
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
382
+ mapping = self.to_mapping(mapping)
383
+
384
+ return mapping
385
+
386
+ def run(self, x, time, embedding, features):
387
+
388
+ mapping = self.get_mapping(time, features)
389
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
390
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
391
+
392
+ for block in self.blocks:
393
+ x = x + mapping
394
+ x = block(x)
395
+
396
+ x = x.mean(axis=1).unsqueeze(1)
397
+ x = self.to_out(x)
398
+ x = x.transpose(-1, -2)
399
+
400
+ return x
401
+
402
+ def forward(self, x: Tensor,
403
+ time: Tensor,
404
+ embedding_mask_proba: float = 0.0,
405
+ embedding: Optional[Tensor] = None,
406
+ features: Optional[Tensor] = None,
407
+ embedding_scale: float = 1.0) -> Tensor:
408
+
409
+ b, device = embedding.shape[0], embedding.device
410
+ fixed_embedding = self.fixed_embedding(embedding)
411
+ if embedding_mask_proba > 0.0:
412
+ # Randomly mask embedding
413
+ batch_mask = rand_bool(
414
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
415
+ )
416
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
417
+
418
+ if embedding_scale != 1.0:
419
+ # Compute both normal and fixed embedding outputs
420
+ out = self.run(x, time, embedding=embedding, features=features)
421
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
422
+ # Scale conditional output using classifier-free guidance
423
+ return out_masked + (out - out_masked) * embedding_scale
424
+ else:
425
+ return self.run(x, time, embedding=embedding, features=features)
426
+
427
+ return x
428
+
429
+
430
+ """
431
+ Attention Components
432
+ """
433
+
434
+
435
+ class RelativePositionBias(nn.Module):
436
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
437
+ super().__init__()
438
+ self.num_buckets = num_buckets
439
+ self.max_distance = max_distance
440
+ self.num_heads = num_heads
441
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
442
+
443
+ @staticmethod
444
+ def _relative_position_bucket(
445
+ relative_position: Tensor, num_buckets: int, max_distance: int
446
+ ):
447
+ num_buckets //= 2
448
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
449
+ n = torch.abs(relative_position)
450
+
451
+ max_exact = num_buckets // 2
452
+ is_small = n < max_exact
453
+
454
+ val_if_large = (
455
+ max_exact
456
+ + (
457
+ torch.log(n.float() / max_exact)
458
+ / log(max_distance / max_exact)
459
+ * (num_buckets - max_exact)
460
+ ).long()
461
+ )
462
+ val_if_large = torch.min(
463
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
464
+ )
465
+
466
+ ret += torch.where(is_small, n, val_if_large)
467
+ return ret
468
+
469
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
470
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
471
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
472
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
473
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
474
+
475
+ relative_position_bucket = self._relative_position_bucket(
476
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
477
+ )
478
+
479
+ bias = self.relative_attention_bias(relative_position_bucket)
480
+ bias = rearrange(bias, "m n h -> 1 h m n")
481
+ return bias
482
+
483
+
484
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
485
+ mid_features = features * multiplier
486
+ return nn.Sequential(
487
+ nn.Linear(in_features=features, out_features=mid_features),
488
+ nn.GELU(),
489
+ nn.Linear(in_features=mid_features, out_features=features),
490
+ )
491
+
492
+
493
+ class AttentionBase(nn.Module):
494
+ def __init__(
495
+ self,
496
+ features: int,
497
+ *,
498
+ head_features: int,
499
+ num_heads: int,
500
+ use_rel_pos: bool,
501
+ out_features: Optional[int] = None,
502
+ rel_pos_num_buckets: Optional[int] = None,
503
+ rel_pos_max_distance: Optional[int] = None,
504
+ ):
505
+ super().__init__()
506
+ self.scale = head_features ** -0.5
507
+ self.num_heads = num_heads
508
+ self.use_rel_pos = use_rel_pos
509
+ mid_features = head_features * num_heads
510
+
511
+ if use_rel_pos:
512
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
513
+ self.rel_pos = RelativePositionBias(
514
+ num_buckets=rel_pos_num_buckets,
515
+ max_distance=rel_pos_max_distance,
516
+ num_heads=num_heads,
517
+ )
518
+ if out_features is None:
519
+ out_features = features
520
+
521
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
522
+
523
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
524
+ # Split heads
525
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
526
+ # Compute similarity matrix
527
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
528
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
529
+ sim = sim * self.scale
530
+ # Get attention matrix with softmax
531
+ attn = sim.softmax(dim=-1)
532
+ # Compute values
533
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
534
+ out = rearrange(out, "b h n d -> b n (h d)")
535
+ return self.to_out(out)
536
+
537
+
538
+ class Attention(nn.Module):
539
+ def __init__(
540
+ self,
541
+ features: int,
542
+ *,
543
+ head_features: int,
544
+ num_heads: int,
545
+ out_features: Optional[int] = None,
546
+ context_features: Optional[int] = None,
547
+ use_rel_pos: bool,
548
+ rel_pos_num_buckets: Optional[int] = None,
549
+ rel_pos_max_distance: Optional[int] = None,
550
+ ):
551
+ super().__init__()
552
+ self.context_features = context_features
553
+ mid_features = head_features * num_heads
554
+ context_features = default(context_features, features)
555
+
556
+ self.norm = nn.LayerNorm(features)
557
+ self.norm_context = nn.LayerNorm(context_features)
558
+ self.to_q = nn.Linear(
559
+ in_features=features, out_features=mid_features, bias=False
560
+ )
561
+ self.to_kv = nn.Linear(
562
+ in_features=context_features, out_features=mid_features * 2, bias=False
563
+ )
564
+
565
+ self.attention = AttentionBase(
566
+ features,
567
+ out_features=out_features,
568
+ num_heads=num_heads,
569
+ head_features=head_features,
570
+ use_rel_pos=use_rel_pos,
571
+ rel_pos_num_buckets=rel_pos_num_buckets,
572
+ rel_pos_max_distance=rel_pos_max_distance,
573
+ )
574
+
575
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
576
+ assert_message = "You must provide a context when using context_features"
577
+ assert not self.context_features or exists(context), assert_message
578
+ # Use context if provided
579
+ context = default(context, x)
580
+ # Normalize then compute q from input and k,v from context
581
+ x, context = self.norm(x), self.norm_context(context)
582
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
583
+ # Compute and return attention
584
+ return self.attention(q, k, v)
585
+
586
+
587
+ """
588
+ Transformer Blocks
589
+ """
590
+
591
+
592
+ class TransformerBlock(nn.Module):
593
+ def __init__(
594
+ self,
595
+ features: int,
596
+ num_heads: int,
597
+ head_features: int,
598
+ multiplier: int,
599
+ use_rel_pos: bool,
600
+ rel_pos_num_buckets: Optional[int] = None,
601
+ rel_pos_max_distance: Optional[int] = None,
602
+ context_features: Optional[int] = None,
603
+ ):
604
+ super().__init__()
605
+
606
+ self.use_cross_attention = exists(context_features) and context_features > 0
607
+
608
+ self.attention = Attention(
609
+ features=features,
610
+ num_heads=num_heads,
611
+ head_features=head_features,
612
+ use_rel_pos=use_rel_pos,
613
+ rel_pos_num_buckets=rel_pos_num_buckets,
614
+ rel_pos_max_distance=rel_pos_max_distance,
615
+ )
616
+
617
+ if self.use_cross_attention:
618
+ self.cross_attention = Attention(
619
+ features=features,
620
+ num_heads=num_heads,
621
+ head_features=head_features,
622
+ context_features=context_features,
623
+ use_rel_pos=use_rel_pos,
624
+ rel_pos_num_buckets=rel_pos_num_buckets,
625
+ rel_pos_max_distance=rel_pos_max_distance,
626
+ )
627
+
628
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
629
+
630
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
631
+ x = self.attention(x) + x
632
+ if self.use_cross_attention:
633
+ x = self.cross_attention(x, context=context) + x
634
+ x = self.feed_forward(x) + x
635
+ return x
636
+
637
+
638
+
639
+ """
640
+ Time Embeddings
641
+ """
642
+
643
+
644
+ class SinusoidalEmbedding(nn.Module):
645
+ def __init__(self, dim: int):
646
+ super().__init__()
647
+ self.dim = dim
648
+
649
+ def forward(self, x: Tensor) -> Tensor:
650
+ device, half_dim = x.device, self.dim // 2
651
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
652
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
653
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
654
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
655
+
656
+
657
+ class LearnedPositionalEmbedding(nn.Module):
658
+ """Used for continuous time"""
659
+
660
+ def __init__(self, dim: int):
661
+ super().__init__()
662
+ assert (dim % 2) == 0
663
+ half_dim = dim // 2
664
+ self.weights = nn.Parameter(torch.randn(half_dim))
665
+
666
+ def forward(self, x: Tensor) -> Tensor:
667
+ x = rearrange(x, "b -> b 1")
668
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
669
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
670
+ fouriered = torch.cat((x, fouriered), dim=-1)
671
+ return fouriered
672
+
673
+
674
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
675
+ return nn.Sequential(
676
+ LearnedPositionalEmbedding(dim),
677
+ nn.Linear(in_features=dim + 1, out_features=out_features),
678
+ )
679
+
680
+ class FixedEmbedding(nn.Module):
681
+ def __init__(self, max_length: int, features: int):
682
+ super().__init__()
683
+ self.max_length = max_length
684
+ self.embedding = nn.Embedding(max_length, features)
685
+
686
+ def forward(self, x: Tensor) -> Tensor:
687
+ batch_size, length, device = *x.shape[0:2], x.device
688
+ assert_message = "Input sequence length must be <= max_length"
689
+ assert length <= self.max_length, assert_message
690
+ position = torch.arange(length, device=device)
691
+ fixed_embedding = self.embedding(position)
692
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
693
+ return fixed_embedding
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
+ class Distribution:
20
+ def __call__(self, num_samples: int, device: torch.device):
21
+ raise NotImplementedError()
22
+
23
+
24
+ class LogNormalDistribution(Distribution):
25
+ def __init__(self, mean: float, std: float):
26
+ self.mean = mean
27
+ self.std = std
28
+
29
+ def __call__(
30
+ self, num_samples: int, device: torch.device = torch.device("cpu")
31
+ ) -> Tensor:
32
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
33
+ return normal.exp()
34
+
35
+
36
+ class UniformDistribution(Distribution):
37
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
+ return torch.rand(num_samples, device=device)
39
+
40
+
41
+ class VKDistribution(Distribution):
42
+ def __init__(
43
+ self,
44
+ min_value: float = 0.0,
45
+ max_value: float = float("inf"),
46
+ sigma_data: float = 1.0,
47
+ ):
48
+ self.min_value = min_value
49
+ self.max_value = max_value
50
+ self.sigma_data = sigma_data
51
+
52
+ def __call__(
53
+ self, num_samples: int, device: torch.device = torch.device("cpu")
54
+ ) -> Tensor:
55
+ sigma_data = self.sigma_data
56
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
+ return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
+ # Pads additional ndims to the right of the tensor
67
+ return x.view(*x.shape, *((1,) * ndim))
68
+
69
+
70
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
+ if dynamic_threshold == 0.0:
72
+ return x.clamp(-1.0, 1.0)
73
+ else:
74
+ # Dynamic thresholding
75
+ # Find dynamic threshold quantile for each batch
76
+ x_flat = rearrange(x, "b ... -> b (...)")
77
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
+ # Clamp to a min of 1.0
79
+ scale.clamp_(min=1.0)
80
+ # Clamp all values and scale
81
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
+ x = x.clamp(-scale, scale) / scale
83
+ return x
84
+
85
+
86
+ def to_batch(
87
+ batch_size: int,
88
+ device: torch.device,
89
+ x: Optional[float] = None,
90
+ xs: Optional[Tensor] = None,
91
+ ) -> Tensor:
92
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
93
+ # If x provided use the same for all batch items
94
+ if exists(x):
95
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
96
+ assert exists(xs)
97
+ return xs
98
+
99
+
100
+ class Diffusion(nn.Module):
101
+
102
+ alias: str = ""
103
+
104
+ """Base diffusion class"""
105
+
106
+ def denoise_fn(
107
+ self,
108
+ x_noisy: Tensor,
109
+ sigmas: Optional[Tensor] = None,
110
+ sigma: Optional[float] = None,
111
+ **kwargs,
112
+ ) -> Tensor:
113
+ raise NotImplementedError("Diffusion class missing denoise_fn")
114
+
115
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
116
+ raise NotImplementedError("Diffusion class missing forward function")
117
+
118
+
119
+ class VDiffusion(Diffusion):
120
+
121
+ alias = "v"
122
+
123
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
124
+ super().__init__()
125
+ self.net = net
126
+ self.sigma_distribution = sigma_distribution
127
+
128
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
129
+ angle = sigmas * pi / 2
130
+ alpha = torch.cos(angle)
131
+ beta = torch.sin(angle)
132
+ return alpha, beta
133
+
134
+ def denoise_fn(
135
+ self,
136
+ x_noisy: Tensor,
137
+ sigmas: Optional[Tensor] = None,
138
+ sigma: Optional[float] = None,
139
+ **kwargs,
140
+ ) -> Tensor:
141
+ batch_size, device = x_noisy.shape[0], x_noisy.device
142
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
143
+ return self.net(x_noisy, sigmas, **kwargs)
144
+
145
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
146
+ batch_size, device = x.shape[0], x.device
147
+
148
+ # Sample amount of noise to add for each batch element
149
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
150
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
151
+
152
+ # Get noise
153
+ noise = default(noise, lambda: torch.randn_like(x))
154
+
155
+ # Combine input and noise weighted by half-circle
156
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
157
+ x_noisy = x * alpha + noise * beta
158
+ x_target = noise * alpha - x * beta
159
+
160
+ # Denoise and return loss
161
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
162
+ return F.mse_loss(x_denoised, x_target)
163
+
164
+
165
+ class KDiffusion(Diffusion):
166
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
167
+
168
+ alias = "k"
169
+
170
+ def __init__(
171
+ self,
172
+ net: nn.Module,
173
+ *,
174
+ sigma_distribution: Distribution,
175
+ sigma_data: float, # data distribution standard deviation
176
+ dynamic_threshold: float = 0.0,
177
+ ):
178
+ super().__init__()
179
+ self.net = net
180
+ self.sigma_data = sigma_data
181
+ self.sigma_distribution = sigma_distribution
182
+ self.dynamic_threshold = dynamic_threshold
183
+
184
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
185
+ sigma_data = self.sigma_data
186
+ c_noise = torch.log(sigmas) * 0.25
187
+ sigmas = rearrange(sigmas, "b -> b 1 1")
188
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
189
+ c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
190
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
191
+ return c_skip, c_out, c_in, c_noise
192
+
193
+ def denoise_fn(
194
+ self,
195
+ x_noisy: Tensor,
196
+ sigmas: Optional[Tensor] = None,
197
+ sigma: Optional[float] = None,
198
+ **kwargs,
199
+ ) -> Tensor:
200
+ batch_size, device = x_noisy.shape[0], x_noisy.device
201
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
202
+
203
+ # Predict network output and add skip connection
204
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
205
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
206
+ x_denoised = c_skip * x_noisy + c_out * x_pred
207
+
208
+ return x_denoised
209
+
210
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
211
+ # Computes weight depending on data distribution
212
+ return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
213
+
214
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
215
+ batch_size, device = x.shape[0], x.device
216
+ from einops import rearrange, reduce
217
+
218
+ # Sample amount of noise to add for each batch element
219
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
220
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
221
+
222
+ # Add noise to input
223
+ noise = default(noise, lambda: torch.randn_like(x))
224
+ x_noisy = x + sigmas_padded * noise
225
+
226
+ # Compute denoised values
227
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
228
+
229
+ # Compute weighted loss
230
+ losses = F.mse_loss(x_denoised, x, reduction="none")
231
+ losses = reduce(losses, "b ... -> b", "mean")
232
+ losses = losses * self.loss_weight(sigmas)
233
+ loss = losses.mean()
234
+ return loss
235
+
236
+
237
+ class VKDiffusion(Diffusion):
238
+
239
+ alias = "vk"
240
+
241
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
242
+ super().__init__()
243
+ self.net = net
244
+ self.sigma_distribution = sigma_distribution
245
+
246
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
247
+ sigma_data = 1.0
248
+ sigmas = rearrange(sigmas, "b -> b 1 1")
249
+ c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
250
+ c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
251
+ c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
252
+ return c_skip, c_out, c_in
253
+
254
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
255
+ return sigmas.atan() / pi * 2
256
+
257
+ def t_to_sigma(self, t: Tensor) -> Tensor:
258
+ return (t * pi / 2).tan()
259
+
260
+ def denoise_fn(
261
+ self,
262
+ x_noisy: Tensor,
263
+ sigmas: Optional[Tensor] = None,
264
+ sigma: Optional[float] = None,
265
+ **kwargs,
266
+ ) -> Tensor:
267
+ batch_size, device = x_noisy.shape[0], x_noisy.device
268
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
269
+
270
+ # Predict network output and add skip connection
271
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
272
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
273
+ x_denoised = c_skip * x_noisy + c_out * x_pred
274
+ return x_denoised
275
+
276
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
277
+ batch_size, device = x.shape[0], x.device
278
+
279
+ # Sample amount of noise to add for each batch element
280
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
281
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
282
+
283
+ # Add noise to input
284
+ noise = default(noise, lambda: torch.randn_like(x))
285
+ x_noisy = x + sigmas_padded * noise
286
+
287
+ # Compute model output
288
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
289
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
290
+
291
+ # Compute v-objective target
292
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
293
+
294
+ # Compute loss
295
+ loss = F.mse_loss(x_pred, v_target)
296
+ return loss
297
+
298
+
299
+ """
300
+ Diffusion Sampling
301
+ """
302
+
303
+ """ Schedules """
304
+
305
+
306
+ class Schedule(nn.Module):
307
+ """Interface used by different sampling schedules"""
308
+
309
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
310
+ raise NotImplementedError()
311
+
312
+
313
+ class LinearSchedule(Schedule):
314
+ def forward(self, num_steps: int, device: Any) -> Tensor:
315
+ sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
316
+ return sigmas
317
+
318
+
319
+ class KarrasSchedule(Schedule):
320
+ """https://arxiv.org/abs/2206.00364 equation 5"""
321
+
322
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
323
+ super().__init__()
324
+ self.sigma_min = sigma_min
325
+ self.sigma_max = sigma_max
326
+ self.rho = rho
327
+
328
+ def forward(self, num_steps: int, device: Any) -> Tensor:
329
+ rho_inv = 1.0 / self.rho
330
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
331
+ sigmas = (
332
+ self.sigma_max ** rho_inv
333
+ + (steps / (num_steps - 1))
334
+ * (self.sigma_min ** rho_inv - self.sigma_max ** rho_inv)
335
+ ) ** self.rho
336
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
337
+ return sigmas
338
+
339
+
340
+ """ Samplers """
341
+
342
+
343
+ class Sampler(nn.Module):
344
+
345
+ diffusion_types: List[Type[Diffusion]] = []
346
+
347
+ def forward(
348
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
349
+ ) -> Tensor:
350
+ raise NotImplementedError()
351
+
352
+ def inpaint(
353
+ self,
354
+ source: Tensor,
355
+ mask: Tensor,
356
+ fn: Callable,
357
+ sigmas: Tensor,
358
+ num_steps: int,
359
+ num_resamples: int,
360
+ ) -> Tensor:
361
+ raise NotImplementedError("Inpainting not available with current sampler")
362
+
363
+
364
+ class VSampler(Sampler):
365
+
366
+ diffusion_types = [VDiffusion]
367
+
368
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
369
+ angle = sigma * pi / 2
370
+ alpha = cos(angle)
371
+ beta = sin(angle)
372
+ return alpha, beta
373
+
374
+ def forward(
375
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
376
+ ) -> Tensor:
377
+ x = sigmas[0] * noise
378
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
379
+
380
+ for i in range(num_steps - 1):
381
+ is_last = i == num_steps - 1
382
+
383
+ x_denoised = fn(x, sigma=sigmas[i])
384
+ x_pred = x * alpha - x_denoised * beta
385
+ x_eps = x * beta + x_denoised * alpha
386
+
387
+ if not is_last:
388
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
389
+ x = x_pred * alpha + x_eps * beta
390
+
391
+ return x_pred
392
+
393
+
394
+ class KarrasSampler(Sampler):
395
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
396
+
397
+ diffusion_types = [KDiffusion, VKDiffusion]
398
+
399
+ def __init__(
400
+ self,
401
+ s_tmin: float = 0,
402
+ s_tmax: float = float("inf"),
403
+ s_churn: float = 0.0,
404
+ s_noise: float = 1.0,
405
+ ):
406
+ super().__init__()
407
+ self.s_tmin = s_tmin
408
+ self.s_tmax = s_tmax
409
+ self.s_noise = s_noise
410
+ self.s_churn = s_churn
411
+
412
+ def step(
413
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
414
+ ) -> Tensor:
415
+ """Algorithm 2 (step)"""
416
+ # Select temporarily increased noise level
417
+ sigma_hat = sigma + gamma * sigma
418
+ # Add noise to move from sigma to sigma_hat
419
+ epsilon = self.s_noise * torch.randn_like(x)
420
+ x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
421
+ # Evaluate ∂x/∂sigma at sigma_hat
422
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
423
+ # Take euler step from sigma_hat to sigma_next
424
+ x_next = x_hat + (sigma_next - sigma_hat) * d
425
+ # Second order correction
426
+ if sigma_next != 0:
427
+ model_out_next = fn(x_next, sigma=sigma_next)
428
+ d_prime = (x_next - model_out_next) / sigma_next
429
+ x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
430
+ return x_next
431
+
432
+ def forward(
433
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
434
+ ) -> Tensor:
435
+ x = sigmas[0] * noise
436
+ # Compute gammas
437
+ gammas = torch.where(
438
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
439
+ min(self.s_churn / num_steps, sqrt(2) - 1),
440
+ 0.0,
441
+ )
442
+ # Denoise to sample
443
+ for i in range(num_steps - 1):
444
+ x = self.step(
445
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
446
+ )
447
+
448
+ return x
449
+
450
+
451
+ class AEulerSampler(Sampler):
452
+
453
+ diffusion_types = [KDiffusion, VKDiffusion]
454
+
455
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
456
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
457
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
458
+ return sigma_up, sigma_down
459
+
460
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
461
+ # Sigma steps
462
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
463
+ # Derivative at sigma (∂x/∂sigma)
464
+ d = (x - fn(x, sigma=sigma)) / sigma
465
+ # Euler method
466
+ x_next = x + d * (sigma_down - sigma)
467
+ # Add randomness
468
+ x_next = x_next + torch.randn_like(x) * sigma_up
469
+ return x_next
470
+
471
+ def forward(
472
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
473
+ ) -> Tensor:
474
+ x = sigmas[0] * noise
475
+ # Denoise to sample
476
+ for i in range(num_steps - 1):
477
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
478
+ return x
479
+
480
+
481
+ class ADPM2Sampler(Sampler):
482
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
483
+
484
+ diffusion_types = [KDiffusion, VKDiffusion]
485
+
486
+ def __init__(self, rho: float = 1.0):
487
+ super().__init__()
488
+ self.rho = rho
489
+
490
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
491
+ r = self.rho
492
+ sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
493
+ sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
494
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
495
+ return sigma_up, sigma_down, sigma_mid
496
+
497
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
498
+ # Sigma steps
499
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
500
+ # Derivative at sigma (∂x/∂sigma)
501
+ d = (x - fn(x, sigma=sigma)) / sigma
502
+ # Denoise to midpoint
503
+ x_mid = x + d * (sigma_mid - sigma)
504
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
505
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
506
+ # Denoise to next
507
+ x = x + d_mid * (sigma_down - sigma)
508
+ # Add randomness
509
+ x_next = x + torch.randn_like(x) * sigma_up
510
+ return x_next
511
+
512
+ def forward(
513
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
514
+ ) -> Tensor:
515
+ x = sigmas[0] * noise
516
+ # Denoise to sample
517
+ for i in range(num_steps - 1):
518
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
519
+ return x
520
+
521
+ def inpaint(
522
+ self,
523
+ source: Tensor,
524
+ mask: Tensor,
525
+ fn: Callable,
526
+ sigmas: Tensor,
527
+ num_steps: int,
528
+ num_resamples: int,
529
+ ) -> Tensor:
530
+ x = sigmas[0] * torch.randn_like(source)
531
+
532
+ for i in range(num_steps - 1):
533
+ # Noise source to current noise level
534
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
535
+ for r in range(num_resamples):
536
+ # Merge noisy source and current then denoise
537
+ x = source_noisy * mask + x * ~mask
538
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
539
+ # Renoise if not last resample step
540
+ if r < num_resamples - 1:
541
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
542
+ x = x + sigma * torch.randn_like(x)
543
+
544
+ return source * mask + x * ~mask
545
+
546
+
547
+ """ Main Classes """
548
+
549
+
550
+ class DiffusionSampler(nn.Module):
551
+ def __init__(
552
+ self,
553
+ diffusion: Diffusion,
554
+ *,
555
+ sampler: Sampler,
556
+ sigma_schedule: Schedule,
557
+ num_steps: Optional[int] = None,
558
+ clamp: bool = True,
559
+ ):
560
+ super().__init__()
561
+ self.denoise_fn = diffusion.denoise_fn
562
+ self.sampler = sampler
563
+ self.sigma_schedule = sigma_schedule
564
+ self.num_steps = num_steps
565
+ self.clamp = clamp
566
+
567
+ # Check sampler is compatible with diffusion type
568
+ sampler_class = sampler.__class__.__name__
569
+ diffusion_class = diffusion.__class__.__name__
570
+ message = f"{sampler_class} incompatible with {diffusion_class}"
571
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
572
+
573
+ def forward(
574
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
575
+ ) -> Tensor:
576
+ device = noise.device
577
+ num_steps = default(num_steps, self.num_steps) # type: ignore
578
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
579
+ # Compute sigmas using schedule
580
+ sigmas = self.sigma_schedule(num_steps, device)
581
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
582
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
583
+ # Sample using sampler
584
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
585
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
586
+ return x
587
+
588
+
589
+ class DiffusionInpainter(nn.Module):
590
+ def __init__(
591
+ self,
592
+ diffusion: Diffusion,
593
+ *,
594
+ num_steps: int,
595
+ num_resamples: int,
596
+ sampler: Sampler,
597
+ sigma_schedule: Schedule,
598
+ ):
599
+ super().__init__()
600
+ self.denoise_fn = diffusion.denoise_fn
601
+ self.num_steps = num_steps
602
+ self.num_resamples = num_resamples
603
+ self.inpaint_fn = sampler.inpaint
604
+ self.sigma_schedule = sigma_schedule
605
+
606
+ @torch.no_grad()
607
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
608
+ x = self.inpaint_fn(
609
+ source=inpaint,
610
+ mask=inpaint_mask,
611
+ fn=self.denoise_fn,
612
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
613
+ num_steps=self.num_steps,
614
+ num_resamples=self.num_resamples,
615
+ )
616
+ return x
617
+
618
+
619
+ def sequential_mask(like: Tensor, start: int) -> Tensor:
620
+ length, device = like.shape[2], like.device
621
+ mask = torch.ones_like(like, dtype=torch.bool)
622
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
623
+ return mask
624
+
625
+
626
+ class SpanBySpanComposer(nn.Module):
627
+ def __init__(
628
+ self,
629
+ inpainter: DiffusionInpainter,
630
+ *,
631
+ num_spans: int,
632
+ ):
633
+ super().__init__()
634
+ self.inpainter = inpainter
635
+ self.num_spans = num_spans
636
+
637
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
638
+ half_length = start.shape[2] // 2
639
+
640
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
641
+ # Inpaint second half from first half
642
+ inpaint = torch.zeros_like(start)
643
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
644
+ inpaint_mask = sequential_mask(like=start, start=half_length)
645
+
646
+ for i in range(self.num_spans):
647
+ # Inpaint second half
648
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
649
+ # Replace first half with generated second half
650
+ second_half = span[:, :, half_length:]
651
+ inpaint[:, :, :half_length] = second_half
652
+ # Save generated span
653
+ spans.append(second_half)
654
+
655
+ return torch.cat(spans, dim=2)
656
+
657
+
658
+ class XDiffusion(nn.Module):
659
+ def __init__(self, type: str, net: nn.Module, **kwargs):
660
+ super().__init__()
661
+
662
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
663
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
664
+ message = f"type='{type}' must be one of {*aliases,}"
665
+ assert type in aliases, message
666
+ self.net = net
667
+
668
+ for XDiffusion in diffusion_classes:
669
+ if XDiffusion.alias == type: # type: ignore
670
+ self.diffusion = XDiffusion(net=net, **kwargs)
671
+
672
+ def forward(self, *args, **kwargs) -> Tensor:
673
+ return self.diffusion(*args, **kwargs)
674
+
675
+ def sample(
676
+ self,
677
+ noise: Tensor,
678
+ num_steps: int,
679
+ sigma_schedule: Schedule,
680
+ sampler: Sampler,
681
+ clamp: bool,
682
+ **kwargs,
683
+ ) -> Tensor:
684
+ diffusion_sampler = DiffusionSampler(
685
+ diffusion=self.diffusion,
686
+ sampler=sampler,
687
+ sigma_schedule=sigma_schedule,
688
+ num_steps=num_steps,
689
+ clamp=clamp,
690
+ )
691
+ return diffusion_sampler(noise, **kwargs)
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
51
+ def rand_bool(shape, proba, device = None):
52
+ if proba == 1:
53
+ return torch.ones(shape, device=device, dtype=torch.bool)
54
+ elif proba == 0:
55
+ return torch.zeros(shape, device=device, dtype=torch.bool)
56
+ else:
57
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
58
+
59
+
60
+ """
61
+ Kwargs Utils
62
+ """
63
+
64
+
65
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
66
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
67
+ for key in d.keys():
68
+ no_prefix = int(not key.startswith(prefix))
69
+ return_dicts[no_prefix][key] = d[key]
70
+ return return_dicts
71
+
72
+
73
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
74
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
75
+ if keep_prefix:
76
+ return kwargs_with_prefix, kwargs
77
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
78
+ return kwargs_no_prefix, kwargs
79
+
80
+
81
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
82
+ return {prefix + str(k): v for k, v in d.items()}
Modules/hifigan.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+ class AdaIN1d(nn.Module):
15
+ def __init__(self, style_dim, num_features):
16
+ super().__init__()
17
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
18
+ self.fc = nn.Linear(style_dim, num_features*2)
19
+
20
+ def forward(self, x, s):
21
+ h = self.fc(s)
22
+ h = h.view(h.size(0), h.size(1), 1)
23
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
24
+ return (1 + gamma) * self.norm(x) + beta
25
+
26
+ class AdaINResBlock1(torch.nn.Module):
27
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
28
+ super(AdaINResBlock1, self).__init__()
29
+ self.convs1 = nn.ModuleList([
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
31
+ padding=get_padding(kernel_size, dilation[0]))),
32
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
33
+ padding=get_padding(kernel_size, dilation[1]))),
34
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
35
+ padding=get_padding(kernel_size, dilation[2])))
36
+ ])
37
+ self.convs1.apply(init_weights)
38
+
39
+ self.convs2 = nn.ModuleList([
40
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
41
+ padding=get_padding(kernel_size, 1))),
42
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
43
+ padding=get_padding(kernel_size, 1))),
44
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
45
+ padding=get_padding(kernel_size, 1)))
46
+ ])
47
+ self.convs2.apply(init_weights)
48
+
49
+ self.adain1 = nn.ModuleList([
50
+ AdaIN1d(style_dim, channels),
51
+ AdaIN1d(style_dim, channels),
52
+ AdaIN1d(style_dim, channels),
53
+ ])
54
+
55
+ self.adain2 = nn.ModuleList([
56
+ AdaIN1d(style_dim, channels),
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ ])
60
+
61
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
62
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
63
+
64
+
65
+ def forward(self, x, s):
66
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
67
+ xt = n1(x, s)
68
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
69
+ xt = c1(xt)
70
+ xt = n2(xt, s)
71
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
72
+ xt = c2(xt)
73
+ x = xt + x
74
+ return x
75
+
76
+ def remove_weight_norm(self):
77
+ for l in self.convs1:
78
+ remove_weight_norm(l)
79
+ for l in self.convs2:
80
+ remove_weight_norm(l)
81
+
82
+ class SineGen(torch.nn.Module):
83
+ """ Definition of sine generator
84
+ SineGen(samp_rate, harmonic_num = 0,
85
+ sine_amp = 0.1, noise_std = 0.003,
86
+ voiced_threshold = 0,
87
+ flag_for_pulse=False)
88
+ samp_rate: sampling rate in Hz
89
+ harmonic_num: number of harmonic overtones (default 0)
90
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
91
+ noise_std: std of Gaussian noise (default 0.003)
92
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
93
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
94
+ Note: when flag_for_pulse is True, the first time step of a voiced
95
+ segment is always sin(np.pi) or cos(0)
96
+ """
97
+
98
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
99
+ sine_amp=0.1, noise_std=0.003,
100
+ voiced_threshold=0,
101
+ flag_for_pulse=False):
102
+ super(SineGen, self).__init__()
103
+ self.sine_amp = sine_amp
104
+ self.noise_std = noise_std
105
+ self.harmonic_num = harmonic_num
106
+ self.dim = self.harmonic_num + 1
107
+ self.sampling_rate = samp_rate
108
+ self.voiced_threshold = voiced_threshold
109
+ self.flag_for_pulse = flag_for_pulse
110
+ self.upsample_scale = upsample_scale
111
+
112
+ def _f02uv(self, f0):
113
+ # generate uv signal
114
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
115
+ return uv
116
+
117
+ def _f02sine(self, f0_values):
118
+ """ f0_values: (batchsize, length, dim)
119
+ where dim indicates fundamental tone and overtones
120
+ """
121
+ # convert to F0 in rad. The interger part n can be ignored
122
+ # because 2 * np.pi * n doesn't affect phase
123
+ rad_values = (f0_values / self.sampling_rate) % 1
124
+
125
+ # initial phase noise (no noise for fundamental component)
126
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
127
+ device=f0_values.device)
128
+ rand_ini[:, 0] = 0
129
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
130
+
131
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
132
+ if not self.flag_for_pulse:
133
+ # # for normal case
134
+
135
+ # # To prevent torch.cumsum numerical overflow,
136
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
137
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
138
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
139
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
140
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
141
+ # cumsum_shift = torch.zeros_like(rad_values)
142
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
143
+
144
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
145
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
146
+ scale_factor=1/self.upsample_scale,
147
+ mode="linear").transpose(1, 2)
148
+
149
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
150
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
151
+ # cumsum_shift = torch.zeros_like(rad_values)
152
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
153
+
154
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
155
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
156
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
157
+ sines = torch.sin(phase)
158
+
159
+ else:
160
+ # If necessary, make sure that the first time step of every
161
+ # voiced segments is sin(pi) or cos(0)
162
+ # This is used for pulse-train generation
163
+
164
+ # identify the last time step in unvoiced segments
165
+ uv = self._f02uv(f0_values)
166
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
167
+ uv_1[:, -1, :] = 1
168
+ u_loc = (uv < 1) * (uv_1 > 0)
169
+
170
+ # get the instantanouse phase
171
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
172
+ # different batch needs to be processed differently
173
+ for idx in range(f0_values.shape[0]):
174
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
175
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
176
+ # stores the accumulation of i.phase within
177
+ # each voiced segments
178
+ tmp_cumsum[idx, :, :] = 0
179
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
180
+
181
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
182
+ # within the previous voiced segment.
183
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
184
+
185
+ # get the sines
186
+ sines = torch.cos(i_phase * 2 * np.pi)
187
+ return sines
188
+
189
+ def forward(self, f0):
190
+ """ sine_tensor, uv = forward(f0)
191
+ input F0: tensor(batchsize=1, length, dim=1)
192
+ f0 for unvoiced steps should be 0
193
+ output sine_tensor: tensor(batchsize=1, length, dim)
194
+ output uv: tensor(batchsize=1, length, 1)
195
+ """
196
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
197
+ device=f0.device)
198
+ # fundamental component
199
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
200
+
201
+ # generate sine waveforms
202
+ sine_waves = self._f02sine(fn) * self.sine_amp
203
+
204
+ # generate uv signal
205
+ # uv = torch.ones(f0.shape)
206
+ # uv = uv * (f0 > self.voiced_threshold)
207
+ uv = self._f02uv(f0)
208
+
209
+ # noise: for unvoiced should be similar to sine_amp
210
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
211
+ # . for voiced regions is self.noise_std
212
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
213
+ noise = noise_amp * torch.randn_like(sine_waves)
214
+
215
+ # first: set the unvoiced part to 0 by uv
216
+ # then: additive noise
217
+ sine_waves = sine_waves * uv + noise
218
+ return sine_waves, uv, noise
219
+
220
+
221
+ class SourceModuleHnNSF(torch.nn.Module):
222
+ """ SourceModule for hn-nsf
223
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
224
+ add_noise_std=0.003, voiced_threshod=0)
225
+ sampling_rate: sampling_rate in Hz
226
+ harmonic_num: number of harmonic above F0 (default: 0)
227
+ sine_amp: amplitude of sine source signal (default: 0.1)
228
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
229
+ note that amplitude of noise in unvoiced is decided
230
+ by sine_amp
231
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
232
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
233
+ F0_sampled (batchsize, length, 1)
234
+ Sine_source (batchsize, length, 1)
235
+ noise_source (batchsize, length 1)
236
+ uv (batchsize, length, 1)
237
+ """
238
+
239
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
240
+ add_noise_std=0.003, voiced_threshod=0):
241
+ super(SourceModuleHnNSF, self).__init__()
242
+
243
+ self.sine_amp = sine_amp
244
+ self.noise_std = add_noise_std
245
+
246
+ # to produce sine waveforms
247
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
248
+ sine_amp, add_noise_std, voiced_threshod)
249
+
250
+ # to merge source harmonics into a single excitation
251
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
252
+ self.l_tanh = torch.nn.Tanh()
253
+
254
+ def forward(self, x):
255
+ """
256
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
257
+ F0_sampled (batchsize, length, 1)
258
+ Sine_source (batchsize, length, 1)
259
+ noise_source (batchsize, length 1)
260
+ """
261
+ # source for harmonic branch
262
+ with torch.no_grad():
263
+ sine_wavs, uv, _ = self.l_sin_gen(x)
264
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
265
+
266
+ # source for noise branch, in the same shape as uv
267
+ noise = torch.randn_like(uv) * self.sine_amp / 3
268
+ return sine_merge, noise, uv
269
+ def padDiff(x):
270
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
271
+
272
+ class Generator(torch.nn.Module):
273
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
274
+ super(Generator, self).__init__()
275
+ self.num_kernels = len(resblock_kernel_sizes)
276
+ self.num_upsamples = len(upsample_rates)
277
+ resblock = AdaINResBlock1
278
+
279
+ self.m_source = SourceModuleHnNSF(
280
+ sampling_rate=24000,
281
+ upsample_scale=np.prod(upsample_rates),
282
+ harmonic_num=8, voiced_threshod=10)
283
+
284
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
285
+ self.noise_convs = nn.ModuleList()
286
+ self.ups = nn.ModuleList()
287
+ self.noise_res = nn.ModuleList()
288
+
289
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
290
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
291
+
292
+ self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel//(2**i),
293
+ upsample_initial_channel//(2**(i+1)),
294
+ k, u, padding=(u//2 + u%2), output_padding=u%2)))
295
+
296
+ if i + 1 < len(upsample_rates): #
297
+ stride_f0 = np.prod(upsample_rates[i + 1:])
298
+ self.noise_convs.append(Conv1d(
299
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
300
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
301
+ else:
302
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
303
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
304
+
305
+ self.resblocks = nn.ModuleList()
306
+
307
+ self.alphas = nn.ParameterList()
308
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
309
+
310
+ for i in range(len(self.ups)):
311
+ ch = upsample_initial_channel//(2**(i+1))
312
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
313
+
314
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
315
+ self.resblocks.append(resblock(ch, k, d, style_dim))
316
+
317
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
318
+ self.ups.apply(init_weights)
319
+ self.conv_post.apply(init_weights)
320
+
321
+ def forward(self, x, s, f0):
322
+
323
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
324
+
325
+ har_source, noi_source, uv = self.m_source(f0)
326
+ har_source = har_source.transpose(1, 2)
327
+
328
+ for i in range(self.num_upsamples):
329
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
330
+ x_source = self.noise_convs[i](har_source)
331
+ x_source = self.noise_res[i](x_source, s)
332
+
333
+ x = self.ups[i](x)
334
+ x = x + x_source
335
+
336
+ xs = None
337
+ for j in range(self.num_kernels):
338
+ if xs is None:
339
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
340
+ else:
341
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
342
+ x = xs / self.num_kernels
343
+ x = x + (1 / self.alphas[i+1]) * (torch.sin(self.alphas[i+1] * x) ** 2)
344
+ x = self.conv_post(x)
345
+ x = torch.tanh(x)
346
+
347
+ return x
348
+
349
+ def remove_weight_norm(self):
350
+ print('Removing weight norm...')
351
+ for l in self.ups:
352
+ remove_weight_norm(l)
353
+ for l in self.resblocks:
354
+ l.remove_weight_norm()
355
+ remove_weight_norm(self.conv_pre)
356
+ remove_weight_norm(self.conv_post)
357
+
358
+
359
+ class AdainResBlk1d(nn.Module):
360
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
361
+ upsample='none', dropout_p=0.0):
362
+ super().__init__()
363
+ self.actv = actv
364
+ self.upsample_type = upsample
365
+ self.upsample = UpSample1d(upsample)
366
+ self.learned_sc = dim_in != dim_out
367
+ self._build_weights(dim_in, dim_out, style_dim)
368
+ self.dropout = nn.Dropout(dropout_p)
369
+
370
+ if upsample == 'none':
371
+ self.pool = nn.Identity()
372
+ else:
373
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
374
+
375
+
376
+ def _build_weights(self, dim_in, dim_out, style_dim):
377
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
378
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
379
+ self.norm1 = AdaIN1d(style_dim, dim_in)
380
+ self.norm2 = AdaIN1d(style_dim, dim_out)
381
+ if self.learned_sc:
382
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
383
+
384
+ def _shortcut(self, x):
385
+ x = self.upsample(x)
386
+ if self.learned_sc:
387
+ x = self.conv1x1(x)
388
+ return x
389
+
390
+ def _residual(self, x, s):
391
+ x = self.norm1(x, s)
392
+ x = self.actv(x)
393
+ x = self.pool(x)
394
+ x = self.conv1(self.dropout(x))
395
+ x = self.norm2(x, s)
396
+ x = self.actv(x)
397
+ x = self.conv2(self.dropout(x))
398
+ return x
399
+
400
+ def forward(self, x, s):
401
+ out = self._residual(x, s)
402
+ out = (out + self._shortcut(x)) / math.sqrt(2)
403
+ return out
404
+
405
+ class UpSample1d(nn.Module):
406
+ def __init__(self, layer_type):
407
+ super().__init__()
408
+ self.layer_type = layer_type
409
+
410
+ def forward(self, x):
411
+ if self.layer_type == 'none':
412
+ return x
413
+ else:
414
+ return F.interpolate(x, scale_factor=2, mode='nearest')
415
+
416
+ class Decoder(nn.Module):
417
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
418
+ resblock_kernel_sizes = [3,7,11],
419
+ upsample_rates = [10,5,3,2],
420
+ upsample_initial_channel=512,
421
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
422
+ upsample_kernel_sizes=[20,10,6,4]):
423
+ super().__init__()
424
+
425
+ self.decode = nn.ModuleList()
426
+
427
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
428
+
429
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
430
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
431
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
432
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
433
+
434
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
435
+
436
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
437
+
438
+ self.asr_res = nn.Sequential(
439
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
440
+ )
441
+
442
+
443
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)
444
+
445
+
446
+ def forward(self, asr, F0_curve, N, s):
447
+ if self.training:
448
+ downlist = [0, 3, 7]
449
+ F0_down = downlist[random.randint(0, 2)]
450
+ downlist = [0, 3, 7, 15]
451
+ N_down = downlist[random.randint(0, 3)]
452
+ if F0_down:
453
+ F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
454
+ if N_down:
455
+ N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
456
+
457
+
458
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
459
+ N = self.N_conv(N.unsqueeze(1))
460
+
461
+ x = torch.cat([asr, F0, N], axis=1)
462
+ x = self.encode(x, s)
463
+
464
+ asr_res = self.asr_res(asr)
465
+
466
+ res = True
467
+ for block in self.decode:
468
+ if res:
469
+ x = torch.cat([x, asr_res, F0, N], axis=1)
470
+ x = block(x, s)
471
+ if block.upsample_type != "none":
472
+ res = False
473
+
474
+ x = self.generator(x, s, F0_curve)
475
+ return x
476
+
477
+
Modules/utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def init_weights(m, mean=0.0, std=0.01):
2
+ classname = m.__class__.__name__
3
+ if classname.find("Conv") != -1:
4
+ m.weight.data.normal_(mean, std)
5
+
6
+
7
+ def apply_weight_norm(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ weight_norm(m)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size*dilation - dilation)/2)
Utils/ASR/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
Utils/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
Utils/ASR/models.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+ class ASRCNN(nn.Module):
9
+ def __init__(self,
10
+ input_dim=80,
11
+ hidden_dim=256,
12
+ n_token=35,
13
+ n_layers=6,
14
+ token_embedding_dim=256,
15
+
16
+ ):
17
+ super().__init__()
18
+ self.n_token = n_token
19
+ self.n_down = 1
20
+ self.to_mfcc = MFCC()
21
+ self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
22
+ self.cnns = nn.Sequential(
23
+ *[nn.Sequential(
24
+ ConvBlock(hidden_dim),
25
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
26
+ ) for n in range(n_layers)])
27
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
28
+ self.ctc_linear = nn.Sequential(
29
+ LinearNorm(hidden_dim//2, hidden_dim),
30
+ nn.ReLU(),
31
+ LinearNorm(hidden_dim, n_token))
32
+ self.asr_s2s = ASRS2S(
33
+ embedding_dim=token_embedding_dim,
34
+ hidden_dim=hidden_dim//2,
35
+ n_token=n_token)
36
+
37
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
38
+ x = self.to_mfcc(x)
39
+ x = self.init_cnn(x)
40
+ x = self.cnns(x)
41
+ x = self.projection(x)
42
+ x = x.transpose(1, 2)
43
+ ctc_logit = self.ctc_linear(x)
44
+ if text_input is not None:
45
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
46
+ return ctc_logit, s2s_logit, s2s_attn
47
+ else:
48
+ return ctc_logit
49
+
50
+ def get_feature(self, x):
51
+ x = self.to_mfcc(x.squeeze(1))
52
+ x = self.init_cnn(x)
53
+ x = self.cnns(x)
54
+ x = self.projection(x)
55
+ return x
56
+
57
+ def length_to_mask(self, lengths):
58
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
59
+ mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
60
+ return mask
61
+
62
+ def get_future_mask(self, out_length, unmask_future_steps=0):
63
+ """
64
+ Args:
65
+ out_length (int): returned mask shape is (out_length, out_length).
66
+ unmask_futre_steps (int): unmasking future step size.
67
+ Return:
68
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
69
+ """
70
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
71
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
72
+ return mask
73
+
74
+ class ASRS2S(nn.Module):
75
+ def __init__(self,
76
+ embedding_dim=256,
77
+ hidden_dim=512,
78
+ n_location_filters=32,
79
+ location_kernel_size=63,
80
+ n_token=40):
81
+ super(ASRS2S, self).__init__()
82
+ self.embedding = nn.Embedding(n_token, embedding_dim)
83
+ val_range = math.sqrt(6 / hidden_dim)
84
+ self.embedding.weight.data.uniform_(-val_range, val_range)
85
+
86
+ self.decoder_rnn_dim = hidden_dim
87
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
+ self.attention_layer = Attention(
89
+ self.decoder_rnn_dim,
90
+ hidden_dim,
91
+ hidden_dim,
92
+ n_location_filters,
93
+ location_kernel_size
94
+ )
95
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
+ self.project_to_hidden = nn.Sequential(
97
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
+ nn.Tanh())
99
+ self.sos = 1
100
+ self.eos = 2
101
+
102
+ def initialize_decoder_states(self, memory, mask):
103
+ """
104
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
+ """
106
+ B, L, H = memory.shape
107
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
112
+ self.memory = memory
113
+ self.processed_memory = self.attention_layer.memory_layer(memory)
114
+ self.mask = mask
115
+ self.unk_index = 3
116
+ self.random_mask = 0.1
117
+
118
+ def forward(self, memory, memory_mask, text_input):
119
+ """
120
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
+ moemory_mask.shape = (B, L, )
122
+ texts_input.shape = (B, T)
123
+ """
124
+ self.initialize_decoder_states(memory, memory_mask)
125
+ # text random mask
126
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
+ _text_input = text_input.clone()
128
+ _text_input.masked_fill_(random_mask, self.unk_index)
129
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
+ start_embedding = self.embedding(
131
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
+
134
+ hidden_outputs, logit_outputs, alignments = [], [], []
135
+ while len(hidden_outputs) < decoder_inputs.size(0):
136
+
137
+ decoder_input = decoder_inputs[len(hidden_outputs)]
138
+ hidden, logit, attention_weights = self.decode(decoder_input)
139
+ hidden_outputs += [hidden]
140
+ logit_outputs += [logit]
141
+ alignments += [attention_weights]
142
+
143
+ hidden_outputs, logit_outputs, alignments = \
144
+ self.parse_decoder_outputs(
145
+ hidden_outputs, logit_outputs, alignments)
146
+
147
+ return hidden_outputs, logit_outputs, alignments
148
+
149
+
150
+ def decode(self, decoder_input):
151
+
152
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
+ cell_input,
155
+ (self.decoder_hidden, self.decoder_cell))
156
+
157
+ attention_weights_cat = torch.cat(
158
+ (self.attention_weights.unsqueeze(1),
159
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
160
+
161
+ self.attention_context, self.attention_weights = self.attention_layer(
162
+ self.decoder_hidden,
163
+ self.memory,
164
+ self.processed_memory,
165
+ attention_weights_cat,
166
+ self.mask)
167
+
168
+ self.attention_weights_cum += self.attention_weights
169
+
170
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
+ hidden = self.project_to_hidden(hidden_and_context)
172
+
173
+ # dropout to increasing g
174
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
+
176
+ return hidden, logit, self.attention_weights
177
+
178
+ def parse_decoder_outputs(self, hidden, logit, alignments):
179
+
180
+ # -> [B, T_out + 1, max_time]
181
+ alignments = torch.stack(alignments).transpose(0,1)
182
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
184
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
+
186
+ return hidden, logit, alignments
Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/JDC/bst.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
Utils/JDC/model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of model from:
3
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
+ Convolutional Recurrent Neural Networks" (2019)
5
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
+ """
7
+ import torch
8
+ from torch import nn
9
+
10
+ class JDCNet(nn.Module):
11
+ """
12
+ Joint Detection and Classification Network model for singing voice melody.
13
+ """
14
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
15
+ super().__init__()
16
+ self.num_class = num_class
17
+
18
+ # input = (b, 1, 31, 513), b = batch size
19
+ self.conv_block = nn.Sequential(
20
+ nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), # out: (b, 64, 31, 513)
21
+ nn.BatchNorm2d(num_features=64),
22
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
23
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
24
+ )
25
+
26
+ # res blocks
27
+ self.res_block1 = ResBlock(in_channels=64, out_channels=128) # (b, 128, 31, 128)
28
+ self.res_block2 = ResBlock(in_channels=128, out_channels=192) # (b, 192, 31, 32)
29
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
30
+
31
+ # pool block
32
+ self.pool_block = nn.Sequential(
33
+ nn.BatchNorm2d(num_features=256),
34
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
35
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
36
+ nn.Dropout(p=0.2),
37
+ )
38
+
39
+ # maxpool layers (for auxiliary network inputs)
40
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
41
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
42
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
43
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
44
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
45
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
46
+
47
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
48
+ self.detector_conv = nn.Sequential(
49
+ nn.Conv2d(640, 256, 1, bias=False),
50
+ nn.BatchNorm2d(256),
51
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
52
+ nn.Dropout(p=0.2),
53
+ )
54
+
55
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
56
+ self.bilstm_classifier = nn.LSTM(
57
+ input_size=512, hidden_size=256,
58
+ batch_first=True, bidirectional=True) # (b, 31, 512)
59
+
60
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
61
+ self.bilstm_detector = nn.LSTM(
62
+ input_size=512, hidden_size=256,
63
+ batch_first=True, bidirectional=True) # (b, 31, 512)
64
+
65
+ # input: (b * 31, 512)
66
+ self.classifier = nn.Linear(in_features=512, out_features=self.num_class) # (b * 31, num_class)
67
+
68
+ # input: (b * 31, 512)
69
+ self.detector = nn.Linear(in_features=512, out_features=2) # (b * 31, 2) - binary classifier
70
+
71
+ # initialize weights
72
+ self.apply(self.init_weights)
73
+
74
+ def get_feature_GAN(self, x):
75
+ seq_len = x.shape[-2]
76
+ x = x.float().transpose(-1, -2)
77
+
78
+ convblock_out = self.conv_block(x)
79
+
80
+ resblock1_out = self.res_block1(convblock_out)
81
+ resblock2_out = self.res_block2(resblock1_out)
82
+ resblock3_out = self.res_block3(resblock2_out)
83
+ poolblock_out = self.pool_block[0](resblock3_out)
84
+ poolblock_out = self.pool_block[1](poolblock_out)
85
+
86
+ return poolblock_out.transpose(-1, -2)
87
+
88
+ def get_feature(self, x):
89
+ seq_len = x.shape[-2]
90
+ x = x.float().transpose(-1, -2)
91
+
92
+ convblock_out = self.conv_block(x)
93
+
94
+ resblock1_out = self.res_block1(convblock_out)
95
+ resblock2_out = self.res_block2(resblock1_out)
96
+ resblock3_out = self.res_block3(resblock2_out)
97
+ poolblock_out = self.pool_block[0](resblock3_out)
98
+ poolblock_out = self.pool_block[1](poolblock_out)
99
+
100
+ return self.pool_block[2](poolblock_out)
101
+
102
+ def forward(self, x):
103
+ """
104
+ Returns:
105
+ classification_prediction, detection_prediction
106
+ sizes: (b, 31, 722), (b, 31, 2)
107
+ """
108
+ ###############################
109
+ # forward pass for classifier #
110
+ ###############################
111
+ seq_len = x.shape[-1]
112
+ x = x.float().transpose(-1, -2)
113
+
114
+ convblock_out = self.conv_block(x)
115
+
116
+ resblock1_out = self.res_block1(convblock_out)
117
+ resblock2_out = self.res_block2(resblock1_out)
118
+ resblock3_out = self.res_block3(resblock2_out)
119
+
120
+
121
+ poolblock_out = self.pool_block[0](resblock3_out)
122
+ poolblock_out = self.pool_block[1](poolblock_out)
123
+ GAN_feature = poolblock_out.transpose(-1, -2)
124
+ poolblock_out = self.pool_block[2](poolblock_out)
125
+
126
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
127
+ classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
128
+ classifier_out, _ = self.bilstm_classifier(classifier_out) # ignore the hidden states
129
+
130
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
131
+ classifier_out = self.classifier(classifier_out)
132
+ classifier_out = classifier_out.view((-1, seq_len, self.num_class)) # (b, 31, num_class)
133
+
134
+ # sizes: (b, 31, 722), (b, 31, 2)
135
+ # classifier output consists of predicted pitch classes per frame
136
+ # detector output consists of: (isvoice, notvoice) estimates per frame
137
+ return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
138
+
139
+ @staticmethod
140
+ def init_weights(m):
141
+ if isinstance(m, nn.Linear):
142
+ nn.init.kaiming_uniform_(m.weight)
143
+ if m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.Conv2d):
146
+ nn.init.xavier_normal_(m.weight)
147
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
148
+ for p in m.parameters():
149
+ if p.data is None:
150
+ continue
151
+
152
+ if len(p.shape) >= 2:
153
+ nn.init.orthogonal_(p.data)
154
+ else:
155
+ nn.init.normal_(p.data)
156
+
157
+
158
+ class ResBlock(nn.Module):
159
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
160
+ super().__init__()
161
+ self.downsample = in_channels != out_channels
162
+
163
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
164
+ self.pre_conv = nn.Sequential(
165
+ nn.BatchNorm2d(num_features=in_channels),
166
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
167
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
168
+ )
169
+
170
+ # conv layers
171
+ self.conv = nn.Sequential(
172
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
173
+ kernel_size=3, padding=1, bias=False),
174
+ nn.BatchNorm2d(out_channels),
175
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
176
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
177
+ )
178
+
179
+ # 1 x 1 convolution layer to match the feature dimensions
180
+ self.conv1by1 = None
181
+ if self.downsample:
182
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
183
+
184
+ def forward(self, x):
185
+ x = self.pre_conv(x)
186
+ if self.downsample:
187
+ x = self.conv(x) + self.conv1by1(x)
188
+ else:
189
+ x = self.conv(x) + x
190
+ return x
Utils/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Checkpoint"
2
+ mixed_precision: "fp16"
3
+ data_folder: "wikipedia_20220301.en.processed"
4
+ batch_size: 192
5
+ save_interval: 5000
6
+ log_interval: 10
7
+ num_process: 1 # number of GPUs
8
+ num_steps: 1000000
9
+
10
+ dataset_params:
11
+ tokenizer: "transfo-xl-wt103"
12
+ token_separator: " " # token used for phoneme separator (space)
13
+ token_mask: "M" # token used for phoneme mask (M)
14
+ word_separator: 3039 # token used for word separator (<formula>)
15
+ token_maps: "token_maps.pkl" # token map path
16
+
17
+ max_mel_length: 512 # max phoneme length
18
+
19
+ word_mask_prob: 0.15 # probability to mask the entire word
20
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
21
+ replace_prob: 0.2 # probablity to replace phonemes
22
+
23
+ model_params:
24
+ vocab_size: 178
25
+ hidden_size: 768
26
+ num_attention_heads: 12
27
+ intermediate_size: 2048
28
+ max_position_embeddings: 512
29
+ num_hidden_layers: 12
30
+ dropout: 0.1
Utils/PLBERT/step_1000000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
3
+ size 25185187
Utils/PLBERT/util.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
+ class CustomAlbert(AlbertModel):
7
+ def forward(self, *args, **kwargs):
8
+ # Call the original forward method
9
+ outputs = super().forward(*args, **kwargs)
10
+
11
+ # Only return the last_hidden_state
12
+ return outputs.last_hidden_state
13
+
14
+
15
+ def load_plbert(log_dir):
16
+ config_path = os.path.join(log_dir, "config.yml")
17
+ plbert_config = yaml.safe_load(open(config_path))
18
+
19
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
20
+ bert = CustomAlbert(albert_base_configuration)
21
+
22
+ files = os.listdir(log_dir)
23
+ ckpts = []
24
+ for f in os.listdir(log_dir):
25
+ if f.startswith("step_"): ckpts.append(f)
26
+
27
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
+ iters = sorted(iters)[-1]
29
+
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".pth", map_location='cpu')
31
+ state_dict = checkpoint['net']
32
+ from collections import OrderedDict
33
+ new_state_dict = OrderedDict()
34
+ for k, v in state_dict.items():
35
+ name = k[7:] # remove `module.`
36
+ if name.startswith('encoder.'):
37
+ name = name[8:] # remove `encoder.`
38
+ new_state_dict[name] = v
39
+ del new_state_dict["embeddings.position_ids"]
40
+ bert.load_state_dict(new_state_dict, strict=False)
41
+
42
+ return bert
Utils/Utils2/ASR/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/Utils2/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
Utils/Utils2/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils/Utils2/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
Utils/Utils2/ASR/models.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+ class ASRCNN(nn.Module):
9
+ def __init__(self,
10
+ input_dim=80,
11
+ hidden_dim=256,
12
+ n_token=35,
13
+ n_layers=6,
14
+ token_embedding_dim=256,
15
+
16
+ ):
17
+ super().__init__()
18
+ self.n_token = n_token
19
+ self.n_down = 1
20
+ self.to_mfcc = MFCC()
21
+ self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
22
+ self.cnns = nn.Sequential(
23
+ *[nn.Sequential(
24
+ ConvBlock(hidden_dim),
25
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
26
+ ) for n in range(n_layers)])
27
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
28
+ self.ctc_linear = nn.Sequential(
29
+ LinearNorm(hidden_dim//2, hidden_dim),
30
+ nn.ReLU(),
31
+ LinearNorm(hidden_dim, n_token))
32
+ self.asr_s2s = ASRS2S(
33
+ embedding_dim=token_embedding_dim,
34
+ hidden_dim=hidden_dim//2,
35
+ n_token=n_token)
36
+
37
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
38
+ x = self.to_mfcc(x)
39
+ x = self.init_cnn(x)
40
+ x = self.cnns(x)
41
+ x = self.projection(x)
42
+ x = x.transpose(1, 2)
43
+ ctc_logit = self.ctc_linear(x)
44
+ if text_input is not None:
45
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
46
+ return ctc_logit, s2s_logit, s2s_attn
47
+ else:
48
+ return ctc_logit
49
+
50
+ def get_feature(self, x):
51
+ x = self.to_mfcc(x.squeeze(1))
52
+ x = self.init_cnn(x)
53
+ x = self.cnns(x)
54
+ x = self.projection(x)
55
+ return x
56
+
57
+ def length_to_mask(self, lengths):
58
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
59
+ mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
60
+ return mask
61
+
62
+ def get_future_mask(self, out_length, unmask_future_steps=0):
63
+ """
64
+ Args:
65
+ out_length (int): returned mask shape is (out_length, out_length).
66
+ unmask_futre_steps (int): unmasking future step size.
67
+ Return:
68
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
69
+ """
70
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
71
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
72
+ return mask
73
+
74
+ class ASRS2S(nn.Module):
75
+ def __init__(self,
76
+ embedding_dim=256,
77
+ hidden_dim=512,
78
+ n_location_filters=32,
79
+ location_kernel_size=63,
80
+ n_token=40):
81
+ super(ASRS2S, self).__init__()
82
+ self.embedding = nn.Embedding(n_token, embedding_dim)
83
+ val_range = math.sqrt(6 / hidden_dim)
84
+ self.embedding.weight.data.uniform_(-val_range, val_range)
85
+
86
+ self.decoder_rnn_dim = hidden_dim
87
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
+ self.attention_layer = Attention(
89
+ self.decoder_rnn_dim,
90
+ hidden_dim,
91
+ hidden_dim,
92
+ n_location_filters,
93
+ location_kernel_size
94
+ )
95
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
+ self.project_to_hidden = nn.Sequential(
97
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
+ nn.Tanh())
99
+ self.sos = 1
100
+ self.eos = 2
101
+
102
+ def initialize_decoder_states(self, memory, mask):
103
+ """
104
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
+ """
106
+ B, L, H = memory.shape
107
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
112
+ self.memory = memory
113
+ self.processed_memory = self.attention_layer.memory_layer(memory)
114
+ self.mask = mask
115
+ self.unk_index = 3
116
+ self.random_mask = 0.1
117
+
118
+ def forward(self, memory, memory_mask, text_input):
119
+ """
120
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
+ moemory_mask.shape = (B, L, )
122
+ texts_input.shape = (B, T)
123
+ """
124
+ self.initialize_decoder_states(memory, memory_mask)
125
+ # text random mask
126
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
+ _text_input = text_input.clone()
128
+ _text_input.masked_fill_(random_mask, self.unk_index)
129
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
+ start_embedding = self.embedding(
131
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
+
134
+ hidden_outputs, logit_outputs, alignments = [], [], []
135
+ while len(hidden_outputs) < decoder_inputs.size(0):
136
+
137
+ decoder_input = decoder_inputs[len(hidden_outputs)]
138
+ hidden, logit, attention_weights = self.decode(decoder_input)
139
+ hidden_outputs += [hidden]
140
+ logit_outputs += [logit]
141
+ alignments += [attention_weights]
142
+
143
+ hidden_outputs, logit_outputs, alignments = \
144
+ self.parse_decoder_outputs(
145
+ hidden_outputs, logit_outputs, alignments)
146
+
147
+ return hidden_outputs, logit_outputs, alignments
148
+
149
+
150
+ def decode(self, decoder_input):
151
+
152
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
+ cell_input,
155
+ (self.decoder_hidden, self.decoder_cell))
156
+
157
+ attention_weights_cat = torch.cat(
158
+ (self.attention_weights.unsqueeze(1),
159
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
160
+
161
+ self.attention_context, self.attention_weights = self.attention_layer(
162
+ self.decoder_hidden,
163
+ self.memory,
164
+ self.processed_memory,
165
+ attention_weights_cat,
166
+ self.mask)
167
+
168
+ self.attention_weights_cum += self.attention_weights
169
+
170
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
+ hidden = self.project_to_hidden(hidden_and_context)
172
+
173
+ # dropout to increasing g
174
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
+
176
+ return hidden, logit, self.attention_weights
177
+
178
+ def parse_decoder_outputs(self, hidden, logit, alignments):
179
+
180
+ # -> [B, T_out + 1, max_time]
181
+ alignments = torch.stack(alignments).transpose(0,1)
182
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
184
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
+
186
+ return hidden, logit, alignments
Utils/Utils2/JDC/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/Utils2/JDC/bst.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
Utils/Utils2/JDC/model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of model from:
3
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
+ Convolutional Recurrent Neural Networks" (2019)
5
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
+ """
7
+ import torch
8
+ from torch import nn
9
+
10
+ class JDCNet(nn.Module):
11
+ """
12
+ Joint Detection and Classification Network model for singing voice melody.
13
+ """
14
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
15
+ super().__init__()
16
+ self.num_class = num_class
17
+
18
+ # input = (b, 1, 31, 513), b = batch size
19
+ self.conv_block = nn.Sequential(
20
+ nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), # out: (b, 64, 31, 513)
21
+ nn.BatchNorm2d(num_features=64),
22
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
23
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
24
+ )
25
+
26
+ # res blocks
27
+ self.res_block1 = ResBlock(in_channels=64, out_channels=128) # (b, 128, 31, 128)
28
+ self.res_block2 = ResBlock(in_channels=128, out_channels=192) # (b, 192, 31, 32)
29
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
30
+
31
+ # pool block
32
+ self.pool_block = nn.Sequential(
33
+ nn.BatchNorm2d(num_features=256),
34
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
35
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
36
+ nn.Dropout(p=0.2),
37
+ )
38
+
39
+ # maxpool layers (for auxiliary network inputs)
40
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
41
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
42
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
43
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
44
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
45
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
46
+
47
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
48
+ self.detector_conv = nn.Sequential(
49
+ nn.Conv2d(640, 256, 1, bias=False),
50
+ nn.BatchNorm2d(256),
51
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
52
+ nn.Dropout(p=0.2),
53
+ )
54
+
55
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
56
+ self.bilstm_classifier = nn.LSTM(
57
+ input_size=512, hidden_size=256,
58
+ batch_first=True, bidirectional=True) # (b, 31, 512)
59
+
60
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
61
+ self.bilstm_detector = nn.LSTM(
62
+ input_size=512, hidden_size=256,
63
+ batch_first=True, bidirectional=True) # (b, 31, 512)
64
+
65
+ # input: (b * 31, 512)
66
+ self.classifier = nn.Linear(in_features=512, out_features=self.num_class) # (b * 31, num_class)
67
+
68
+ # input: (b * 31, 512)
69
+ self.detector = nn.Linear(in_features=512, out_features=2) # (b * 31, 2) - binary classifier
70
+
71
+ # initialize weights
72
+ self.apply(self.init_weights)
73
+
74
+ def get_feature_GAN(self, x):
75
+ seq_len = x.shape[-2]
76
+ x = x.float().transpose(-1, -2)
77
+
78
+ convblock_out = self.conv_block(x)
79
+
80
+ resblock1_out = self.res_block1(convblock_out)
81
+ resblock2_out = self.res_block2(resblock1_out)
82
+ resblock3_out = self.res_block3(resblock2_out)
83
+ poolblock_out = self.pool_block[0](resblock3_out)
84
+ poolblock_out = self.pool_block[1](poolblock_out)
85
+
86
+ return poolblock_out.transpose(-1, -2)
87
+
88
+ def get_feature(self, x):
89
+ seq_len = x.shape[-2]
90
+ x = x.float().transpose(-1, -2)
91
+
92
+ convblock_out = self.conv_block(x)
93
+
94
+ resblock1_out = self.res_block1(convblock_out)
95
+ resblock2_out = self.res_block2(resblock1_out)
96
+ resblock3_out = self.res_block3(resblock2_out)
97
+ poolblock_out = self.pool_block[0](resblock3_out)
98
+ poolblock_out = self.pool_block[1](poolblock_out)
99
+
100
+ return self.pool_block[2](poolblock_out)
101
+
102
+ def forward(self, x):
103
+ """
104
+ Returns:
105
+ classification_prediction, detection_prediction
106
+ sizes: (b, 31, 722), (b, 31, 2)
107
+ """
108
+ ###############################
109
+ # forward pass for classifier #
110
+ ###############################
111
+ seq_len = x.shape[-1]
112
+ x = x.float().transpose(-1, -2)
113
+
114
+ convblock_out = self.conv_block(x)
115
+
116
+ resblock1_out = self.res_block1(convblock_out)
117
+ resblock2_out = self.res_block2(resblock1_out)
118
+ resblock3_out = self.res_block3(resblock2_out)
119
+
120
+
121
+ poolblock_out = self.pool_block[0](resblock3_out)
122
+ poolblock_out = self.pool_block[1](poolblock_out)
123
+ GAN_feature = poolblock_out.transpose(-1, -2)
124
+ poolblock_out = self.pool_block[2](poolblock_out)
125
+
126
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
127
+ classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
128
+ classifier_out, _ = self.bilstm_classifier(classifier_out) # ignore the hidden states
129
+
130
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
131
+ classifier_out = self.classifier(classifier_out)
132
+ classifier_out = classifier_out.view((-1, seq_len, self.num_class)) # (b, 31, num_class)
133
+
134
+ # sizes: (b, 31, 722), (b, 31, 2)
135
+ # classifier output consists of predicted pitch classes per frame
136
+ # detector output consists of: (isvoice, notvoice) estimates per frame
137
+ return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
138
+
139
+ @staticmethod
140
+ def init_weights(m):
141
+ if isinstance(m, nn.Linear):
142
+ nn.init.kaiming_uniform_(m.weight)
143
+ if m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.Conv2d):
146
+ nn.init.xavier_normal_(m.weight)
147
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
148
+ for p in m.parameters():
149
+ if p.data is None:
150
+ continue
151
+
152
+ if len(p.shape) >= 2:
153
+ nn.init.orthogonal_(p.data)
154
+ else:
155
+ nn.init.normal_(p.data)
156
+
157
+
158
+ class ResBlock(nn.Module):
159
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
160
+ super().__init__()
161
+ self.downsample = in_channels != out_channels
162
+
163
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
164
+ self.pre_conv = nn.Sequential(
165
+ nn.BatchNorm2d(num_features=in_channels),
166
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
167
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
168
+ )
169
+
170
+ # conv layers
171
+ self.conv = nn.Sequential(
172
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
173
+ kernel_size=3, padding=1, bias=False),
174
+ nn.BatchNorm2d(out_channels),
175
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
176
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
177
+ )
178
+
179
+ # 1 x 1 convolution layer to match the feature dimensions
180
+ self.conv1by1 = None
181
+ if self.downsample:
182
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
183
+
184
+ def forward(self, x):
185
+ x = self.pre_conv(x)
186
+ if self.downsample:
187
+ x = self.conv(x) + self.conv1by1(x)
188
+ else:
189
+ x = self.conv(x) + x
190
+ return x
Utils/Utils2/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Checkpoint"
2
+ mixed_precision: "fp16"
3
+ data_folder: "wikipedia_20220301.en.processed"
4
+ batch_size: 192
5
+ save_interval: 5000
6
+ log_interval: 10
7
+ num_process: 1 # number of GPUs
8
+ num_steps: 1000000
9
+
10
+ dataset_params:
11
+ tokenizer: "transfo-xl-wt103"
12
+ token_separator: " " # token used for phoneme separator (space)
13
+ token_mask: "M" # token used for phoneme mask (M)
14
+ word_separator: 3039 # token used for word separator (<formula>)
15
+ token_maps: "token_maps.pkl" # token map path
16
+
17
+ max_mel_length: 512 # max phoneme length
18
+
19
+ word_mask_prob: 0.15 # probability to mask the entire word
20
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
21
+ replace_prob: 0.2 # probablity to replace phonemes
22
+
23
+ model_params:
24
+ vocab_size: 178
25
+ hidden_size: 768
26
+ num_attention_heads: 12
27
+ intermediate_size: 2048
28
+ max_position_embeddings: 512
29
+ num_hidden_layers: 12
30
+ dropout: 0.1
Utils/Utils2/PLBERT/step_1000000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
3
+ size 25185187
Utils/Utils2/PLBERT/util.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
+ class CustomAlbert(AlbertModel):
7
+ def forward(self, *args, **kwargs):
8
+ # Call the original forward method
9
+ outputs = super().forward(*args, **kwargs)
10
+
11
+ # Only return the last_hidden_state
12
+ return outputs.last_hidden_state
13
+
14
+
15
+ def load_plbert(log_dir):
16
+ config_path = os.path.join(log_dir, "config.yml")
17
+ plbert_config = yaml.safe_load(open(config_path))
18
+
19
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
20
+ bert = CustomAlbert(albert_base_configuration)
21
+
22
+ files = os.listdir(log_dir)
23
+ ckpts = []
24
+ for f in os.listdir(log_dir):
25
+ if f.startswith("step_"): ckpts.append(f)
26
+
27
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
+ iters = sorted(iters)[-1]
29
+
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
31
+ state_dict = checkpoint['net']
32
+ from collections import OrderedDict
33
+ new_state_dict = OrderedDict()
34
+ for k, v in state_dict.items():
35
+ name = k[7:] # remove `module.`
36
+ if name.startswith('encoder.'):
37
+ name = name[8:] # remove `encoder.`
38
+ new_state_dict[name] = v
39
+ del new_state_dict["embeddings.position_ids"]
40
+ bert.load_state_dict(new_state_dict, strict=False)
41
+
42
+ return bert
Utils/Utils2/config.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {ASR_config: Utils/ASR/config.yml, ASR_path: Utils/ASR/epoch_00080.pth, F0_path: Utils/JDC/bst.t7,
2
+ PLBERT_dir: Utils/PLBERT/, batch_size: 8, data_params: {OOD_data: Data/OOD_texts.txt,
3
+ min_length: 50, root_path: '', train_data: Data/train_list.txt, val_data: Data/val_list.txt},
4
+ device: cuda, epochs_1st: 40, epochs_2nd: 25, first_stage_path: first_stage.pth,
5
+ load_only_params: false, log_dir: Models/LibriTTS, log_interval: 10, loss_params: {
6
+ TMA_epoch: 4, diff_epoch: 0, joint_epoch: 0, lambda_F0: 1.0, lambda_ce: 20.0,
7
+ lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0, lambda_mono: 1.0,
8
+ lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 300,
9
+ model_params: {decoder: {resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3,
10
+ 5]], resblock_kernel_sizes: [3, 7, 11], type: hifigan, upsample_initial_channel: 512,
11
+ upsample_kernel_sizes: [20, 10, 6, 4], upsample_rates: [10, 5, 3, 2]}, diffusion: {
12
+ dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.19926648961191362,
13
+ std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2,
14
+ num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512,
15
+ max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178,
16
+ slm: {hidden: 768, initial_channel: 64, model: microsoft/wavlm-base-plus, nlayers: 13,
17
+ sr: 16000}, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05,
18
+ lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 300, n_fft: 2048,
19
+ win_length: 1200}, sr: 24000}, pretrained_model: Models/LibriTTS/epoch_2nd_00002.pth,
20
+ save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
21
+ iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
Utils/Utils2/engineer_style_vectors_v2.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ import shutil
4
+ import csv
5
+ import io
6
+ import os
7
+ import typing
8
+ import wave
9
+ import sys
10
+ from mimic3_tts.__main__ import (CommandLineInterfaceState,
11
+ get_args,
12
+ initialize_args,
13
+ initialize_tts,
14
+ # print_voices,
15
+ # process_lines,
16
+ shutdown_tts,
17
+ OutputNaming,
18
+ process_line)
19
+
20
+
21
+ def process_lines(state: CommandLineInterfaceState, wav_path=None):
22
+ '''MIMIC3 INTERNAL CALL that yields the sigh sound'''
23
+
24
+ args = state.args
25
+
26
+ result_idx = 0
27
+ print(f'why waitings in the for loop LIN {state.texts=}\n')
28
+ for line in state.texts:
29
+ print(f'LIN {line=}\n') # prints \n so is empty not getting the predifne text of state.texts
30
+ line_voice: typing.Optional[str] = None
31
+ line_id = ""
32
+ line = line.strip()
33
+ # if not line:
34
+ # continue
35
+
36
+ if args.output_naming == OutputNaming.ID:
37
+ # Line has the format id|text instead of just text
38
+ with io.StringIO(line) as line_io:
39
+ reader = csv.reader(line_io, delimiter=args.csv_delimiter)
40
+ row = next(reader)
41
+ line_id, line = row[0], row[-1]
42
+ if args.csv_voice:
43
+ line_voice = row[1]
44
+
45
+ process_line(line, state, line_id=line_id, line_voice=line_voice)
46
+ result_idx += 1
47
+
48
+ print('\nARRive at All Audio writing\n\n\n\n')
49
+ # -------------------------------------------------------------------------
50
+
51
+ # Write combined audio to stdout
52
+ if state.all_audio:
53
+ # _LOGGER.debug("Writing WAV audio to stdout")
54
+
55
+ if sys.stdout.isatty() and (not state.args.stdout):
56
+ with io.BytesIO() as wav_io:
57
+ wav_file_play: wave.Wave_write = wave.open(wav_io, "wb")
58
+ with wav_file_play:
59
+ wav_file_play.setframerate(state.sample_rate_hz)
60
+ wav_file_play.setsampwidth(state.sample_width_bytes)
61
+ wav_file_play.setnchannels(state.num_channels)
62
+ wav_file_play.writeframes(state.all_audio)
63
+
64
+ # play_wav_bytes(state.args, wav_io.getvalue())
65
+ # wav_path = '_direct_call_2.wav'
66
+ with open(wav_path, 'wb') as wav_file:
67
+ wav_file.write(wav_io.getvalue())
68
+ wav_file.seek(0)
69
+
70
+ # -----------------------------------------------------------------------------
71
+ # cat _tmp_ssml.txt | mimic3 --cuda --ssml --noise-w 0.90001 --length-scale 0.91 --noise-scale 0.04 > noise_w=0.90_en_happy_2.wav
72
+ # ======================================================================
73
+ out_dir = 'assets/'
74
+ reference_wav_directory = 'assets/wavs/style_vector_v2/'
75
+ Path(reference_wav_directory).mkdir(parents=True, exist_ok=True)
76
+ Path(out_dir).mkdir(parents=True, exist_ok=True)
77
+
78
+ wav_dir = 'assets/wavs/'
79
+ Path(wav_dir).mkdir(parents=True, exist_ok=True)
80
+ N_PIX = 11
81
+
82
+
83
+ # =======================================================================
84
+ # S T A R T G E N E R A T E png/wav
85
+ # =======================================================================
86
+
87
+ NOISE_SCALE = .667
88
+ NOISE_W = .9001 #.8 #.90001 # default .8 in __main__.py @ L697 IGNORED DUE TO ARTEfACTS - FOR NOW USE default
89
+
90
+ a = [
91
+ 'p239',
92
+ 'p236',
93
+ 'p264',
94
+ 'p250',
95
+ 'p259',
96
+ 'p247',
97
+ 'p261',
98
+ 'p263',
99
+ 'p283',
100
+ 'p274',
101
+ 'p286',
102
+ 'p276',
103
+ 'p270',
104
+ 'p281',
105
+ 'p277',
106
+ 'p231',
107
+ 'p238',
108
+ 'p271',
109
+ 'p257',
110
+ 'p273',
111
+ 'p284',
112
+ 'p329',
113
+ 'p361',
114
+ 'p287',
115
+ 'p360',
116
+ 'p374',
117
+ 'p376',
118
+ 'p310',
119
+ 'p304',
120
+ 'p340',
121
+ 'p347',
122
+ 'p330',
123
+ 'p308',
124
+ 'p314',
125
+ 'p317',
126
+ 'p339',
127
+ 'p311',
128
+ 'p294',
129
+ 'p305',
130
+ 'p266',
131
+ 'p335',
132
+ 'p334',
133
+ 'p318',
134
+ 'p323',
135
+ 'p351',
136
+ 'p333',
137
+ 'p313',
138
+ 'p316',
139
+ 'p244',
140
+ 'p307',
141
+ 'p363',
142
+ 'p336',
143
+ 'p312',
144
+ 'p267',
145
+ 'p297',
146
+ 'p275',
147
+ 'p295',
148
+ 'p288',
149
+ 'p258',
150
+ 'p301',
151
+ 'p232',
152
+ 'p292',
153
+ 'p272',
154
+ 'p278',
155
+ 'p280',
156
+ 'p341',
157
+ 'p268',
158
+ 'p298',
159
+ 'p299',
160
+ 'p279',
161
+ 'p285',
162
+ 'p326',
163
+ 'p300',
164
+ 's5',
165
+ 'p230',
166
+ 'p254',
167
+ 'p269',
168
+ 'p293',
169
+ 'p252',
170
+ 'p345',
171
+ 'p262',
172
+ 'p243',
173
+ 'p227',
174
+ 'p343',
175
+ 'p255',
176
+ 'p229',
177
+ 'p240',
178
+ 'p248',
179
+ 'p253',
180
+ 'p233',
181
+ 'p228',
182
+ 'p251',
183
+ 'p282',
184
+ 'p246',
185
+ 'p234',
186
+ 'p226',
187
+ 'p260',
188
+ 'p245',
189
+ 'p241',
190
+ 'p303',
191
+ 'p265',
192
+ 'p306',
193
+ 'p237',
194
+ 'p249',
195
+ 'p256',
196
+ 'p302',
197
+ 'p364',
198
+ 'p225',
199
+ 'p362']
200
+
201
+ print(len(a))
202
+
203
+ b = []
204
+
205
+ for row in a:
206
+ b.append(f'en_US/vctk_low#{row}')
207
+
208
+ # print(b)
209
+
210
+ # 00000000 arctic
211
+
212
+
213
+ a = [
214
+ 'awb' # comma
215
+ 'rms',
216
+ 'slt',
217
+ 'ksp',
218
+ 'clb',
219
+ 'aew',
220
+ 'bdl',
221
+ 'lnh',
222
+ 'jmk',
223
+ 'rxr',
224
+ 'fem',
225
+ 'ljm',
226
+ 'slp',
227
+ 'ahw',
228
+ 'axb',
229
+ 'aup',
230
+ 'eey',
231
+ 'gka',
232
+ ]
233
+
234
+
235
+ for row in a:
236
+ b.append(f'en_US/cmu-arctic_low#{row}')
237
+
238
+ # HIFItts
239
+
240
+ a = ['9017',
241
+ '6097',
242
+ '92']
243
+
244
+ for row in a:
245
+ b.append(f'en_US/hifi-tts_low#{row}')
246
+
247
+ a = [
248
+ 'elliot_miller',
249
+ 'judy_bieber',
250
+ 'mary_ann']
251
+
252
+ for row in a:
253
+ b.append(f'en_US/m-ailabs_low#{row}')
254
+
255
+ # LJspeech - single speaker
256
+
257
+ b.append(f'en_US/ljspeech_low')
258
+
259
+ # en_UK apope - only speaker
260
+
261
+ b.append(f'en_UK/apope_low')
262
+
263
+ all_names = b
264
+
265
+
266
+ VOICES = {}
267
+ for _id, _voice in enumerate(all_names):
268
+
269
+ # If GitHub Quota exceded copy mimic-voices from local copies
270
+ #
271
+ # https://github.com/MycroftAI/mimic3-voices
272
+ #
273
+ home_voice_dir = f'/home/audeering.local/dkounadis/.local/share/mycroft/mimic3/voices/{_voice.split("#")[0]}/'
274
+ Path(home_voice_dir).mkdir(parents=True, exist_ok=True)
275
+ speaker_free_voice_name = _voice.split("#")[0] if '#' in _voice else _voice
276
+ if not os.path.isfile(home_voice_dir + 'generator.onnx'):
277
+ shutil.copyfile(
278
+ f'/data/dkounadis/cache/mimic3-voices/voices/{speaker_free_voice_name}/generator.onnx',
279
+ home_voice_dir + 'generator.onnx') # 'en_US incl. voice
280
+
281
+ prepare_file = _voice.replace('/', '_').replace('#', '_').replace('_low', '')
282
+ if 'cmu-arctic' in prepare_file:
283
+ prepare_file = prepare_file.replace('cmu-arctic', 'cmu_arctic') + '.wav'
284
+ else:
285
+ prepare_file = prepare_file + '.wav' # [...cmu-arctic...](....cmu_arctic....wav)
286
+
287
+ file_true = prepare_file.split('.wav')[0] + '_true_.wav'
288
+ file_false = prepare_file.split('.wav')[0] + '_false_.wav'
289
+ print(prepare_file, file_false, file_true)
290
+
291
+
292
+ reference_wav = reference_wav_directory + prepare_file
293
+ rate = 4 # high speed sounds nice if used as speaker-reference audio for StyleTTS2
294
+ _ssml = (
295
+ '<speak>'
296
+ '<prosody volume=\'64\'>'
297
+ f'<prosody rate=\'{rate}\'>'
298
+ f'<voice name=\'{_voice}\'>'
299
+ '<s>'
300
+ 'Sweet dreams are made of this, .. !!! # I travel the world and the seven seas.'
301
+ '</s>'
302
+ '</voice>'
303
+ '</prosody>'
304
+ '</prosody>'
305
+ '</speak>'
306
+ )
307
+ with open('_tmp_ssml.txt', 'w') as f:
308
+ f.write(_ssml)
309
+
310
+
311
+ # ps = subprocess.Popen(f'cat _tmp_ssml.txt | mimic3 --ssml > {reference_wav}', shell=True)
312
+ # ps.wait() # using ps to call mimic3 because samples dont have time to be written in stdout buffer
313
+ args = get_args()
314
+ args.ssml = True
315
+ args.text = [_ssml] #['aa', 'bb'] #txt
316
+ args.interactive = False
317
+ # args.output_naming = OutputNaming.TIME
318
+
319
+ state = CommandLineInterfaceState(args=args)
320
+ initialize_args(state)
321
+ initialize_tts(state)
322
+ # args.texts = [txt] #['aa', 'bb'] #txt
323
+ # state.stdout = '.' #None #'makeme.wav'
324
+ # state.output_dir = '.noopy'
325
+ # state.interactive = False
326
+ # state.output_naming = OutputNaming.TIME
327
+ # # state.ssml = 1234546575
328
+ # state.stdout = True
329
+ # state.tts = True
330
+ process_lines(state, wav_path=reference_wav)
331
+ shutdown_tts(state)
Utils/config.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {ASR_config: Utils/ASR/config.yml, ASR_path: Utils/ASR/epoch_00080.pth, F0_path: Utils/JDC/bst.t7,
2
+ PLBERT_dir: Utils/PLBERT/, batch_size: 8, data_params: {OOD_data: Data/OOD_texts.txt,
3
+ min_length: 50, root_path: '', train_data: Data/train_list.txt, val_data: Data/val_list.txt},
4
+ device: cuda, epochs_1st: 40, epochs_2nd: 25, first_stage_path: first_stage.pth,
5
+ load_only_params: false, log_dir: Models/LibriTTS, log_interval: 10, loss_params: {
6
+ TMA_epoch: 4, diff_epoch: 0, joint_epoch: 0, lambda_F0: 1.0, lambda_ce: 20.0,
7
+ lambda_diff: 1.0, lambda_dur: 1.0, lambda_gen: 1.0, lambda_mel: 5.0, lambda_mono: 1.0,
8
+ lambda_norm: 1.0, lambda_s2s: 1.0, lambda_slm: 1.0, lambda_sty: 1.0}, max_len: 300,
9
+ model_params: {decoder: {resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3,
10
+ 5]], resblock_kernel_sizes: [3, 7, 11], type: hifigan, upsample_initial_channel: 512,
11
+ upsample_kernel_sizes: [20, 10, 6, 4], upsample_rates: [10, 5, 3, 2]}, diffusion: {
12
+ dist: {estimate_sigma_data: true, mean: -3.0, sigma_data: 0.19926648961191362,
13
+ std: 1.0}, embedding_mask_proba: 0.1, transformer: {head_features: 64, multiplier: 2,
14
+ num_heads: 8, num_layers: 3}}, dim_in: 64, dropout: 0.2, hidden_dim: 512,
15
+ max_conv_dim: 512, max_dur: 50, multispeaker: true, n_layer: 3, n_mels: 80, n_token: 178,
16
+ slm: {hidden: 768, initial_channel: 64, model: microsoft/wavlm-base-plus, nlayers: 13,
17
+ sr: 16000}, style_dim: 128}, optimizer_params: {bert_lr: 1.0e-05, ft_lr: 1.0e-05,
18
+ lr: 0.0001}, preprocess_params: {spect_params: {hop_length: 300, n_fft: 2048,
19
+ win_length: 1200}, sr: 24000}, pretrained_model: Models/LibriTTS/epoch_2nd_00002.pth,
20
+ save_freq: 1, second_stage_load_pretrained: true, slmadv_params: {batch_percentage: 0.5,
21
+ iter: 20, max_len: 500, min_len: 400, scale: 0.01, sig: 1.5, thresh: 5}}
Utils/engineer_style_vectors_v2.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ import shutil
4
+ import csv
5
+ import io
6
+ import os
7
+ import typing
8
+ import wave
9
+ import sys
10
+ from mimic3_tts.__main__ import (CommandLineInterfaceState,
11
+ get_args,
12
+ initialize_args,
13
+ initialize_tts,
14
+ # print_voices,
15
+ # process_lines,
16
+ shutdown_tts,
17
+ OutputNaming,
18
+ process_line)
19
+
20
+
21
+ def process_lines(state: CommandLineInterfaceState, wav_path=None):
22
+ '''MIMIC3 INTERNAL CALL that yields the sigh sound'''
23
+
24
+ args = state.args
25
+
26
+ result_idx = 0
27
+ print(f'why waitings in the for loop LIN {state.texts=}\n')
28
+ for line in state.texts:
29
+ print(f'LIN {line=}\n') # prints \n so is empty not getting the predifne text of state.texts
30
+ line_voice: typing.Optional[str] = None
31
+ line_id = ""
32
+ line = line.strip()
33
+ # if not line:
34
+ # continue
35
+
36
+ if args.output_naming == OutputNaming.ID:
37
+ # Line has the format id|text instead of just text
38
+ with io.StringIO(line) as line_io:
39
+ reader = csv.reader(line_io, delimiter=args.csv_delimiter)
40
+ row = next(reader)
41
+ line_id, line = row[0], row[-1]
42
+ if args.csv_voice:
43
+ line_voice = row[1]
44
+
45
+ process_line(line, state, line_id=line_id, line_voice=line_voice)
46
+ result_idx += 1
47
+
48
+ print('\nARRive at All Audio writing\n\n\n\n')
49
+ # -------------------------------------------------------------------------
50
+
51
+ # Write combined audio to stdout
52
+ if state.all_audio:
53
+ # _LOGGER.debug("Writing WAV audio to stdout")
54
+
55
+ if sys.stdout.isatty() and (not state.args.stdout):
56
+ with io.BytesIO() as wav_io:
57
+ wav_file_play: wave.Wave_write = wave.open(wav_io, "wb")
58
+ with wav_file_play:
59
+ wav_file_play.setframerate(state.sample_rate_hz)
60
+ wav_file_play.setsampwidth(state.sample_width_bytes)
61
+ wav_file_play.setnchannels(state.num_channels)
62
+ wav_file_play.writeframes(state.all_audio)
63
+
64
+ # play_wav_bytes(state.args, wav_io.getvalue())
65
+ # wav_path = '_direct_call_2.wav'
66
+ with open(wav_path, 'wb') as wav_file:
67
+ wav_file.write(wav_io.getvalue())
68
+ wav_file.seek(0)
69
+
70
+ # -----------------------------------------------------------------------------
71
+ # cat _tmp_ssml.txt | mimic3 --cuda --ssml --noise-w 0.90001 --length-scale 0.91 --noise-scale 0.04 > noise_w=0.90_en_happy_2.wav
72
+ # ======================================================================
73
+ out_dir = 'assets/'
74
+ reference_wav_directory = 'assets/wavs/style_vector_v2/'
75
+ Path(reference_wav_directory).mkdir(parents=True, exist_ok=True)
76
+ Path(out_dir).mkdir(parents=True, exist_ok=True)
77
+
78
+ wav_dir = 'assets/wavs/'
79
+ Path(wav_dir).mkdir(parents=True, exist_ok=True)
80
+ N_PIX = 11
81
+
82
+
83
+ # =======================================================================
84
+ # S T A R T G E N E R A T E png/wav
85
+ # =======================================================================
86
+
87
+ NOISE_SCALE = .667
88
+ NOISE_W = .9001 #.8 #.90001 # default .8 in __main__.py @ L697 IGNORED DUE TO ARTEfACTS - FOR NOW USE default
89
+
90
+ a = [
91
+ 'p239',
92
+ 'p236',
93
+ 'p264',
94
+ 'p250',
95
+ 'p259',
96
+ 'p247',
97
+ 'p261',
98
+ 'p263',
99
+ 'p283',
100
+ 'p274',
101
+ 'p286',
102
+ 'p276',
103
+ 'p270',
104
+ 'p281',
105
+ 'p277',
106
+ 'p231',
107
+ 'p238',
108
+ 'p271',
109
+ 'p257',
110
+ 'p273',
111
+ 'p284',
112
+ 'p329',
113
+ 'p361',
114
+ 'p287',
115
+ 'p360',
116
+ 'p374',
117
+ 'p376',
118
+ 'p310',
119
+ 'p304',
120
+ 'p340',
121
+ 'p347',
122
+ 'p330',
123
+ 'p308',
124
+ 'p314',
125
+ 'p317',
126
+ 'p339',
127
+ 'p311',
128
+ 'p294',
129
+ 'p305',
130
+ 'p266',
131
+ 'p335',
132
+ 'p334',
133
+ 'p318',
134
+ 'p323',
135
+ 'p351',
136
+ 'p333',
137
+ 'p313',
138
+ 'p316',
139
+ 'p244',
140
+ 'p307',
141
+ 'p363',
142
+ 'p336',
143
+ 'p312',
144
+ 'p267',
145
+ 'p297',
146
+ 'p275',
147
+ 'p295',
148
+ 'p288',
149
+ 'p258',
150
+ 'p301',
151
+ 'p232',
152
+ 'p292',
153
+ 'p272',
154
+ 'p278',
155
+ 'p280',
156
+ 'p341',
157
+ 'p268',
158
+ 'p298',
159
+ 'p299',
160
+ 'p279',
161
+ 'p285',
162
+ 'p326',
163
+ 'p300',
164
+ 's5',
165
+ 'p230',
166
+ 'p254',
167
+ 'p269',
168
+ 'p293',
169
+ 'p252',
170
+ 'p345',
171
+ 'p262',
172
+ 'p243',
173
+ 'p227',
174
+ 'p343',
175
+ 'p255',
176
+ 'p229',
177
+ 'p240',
178
+ 'p248',
179
+ 'p253',
180
+ 'p233',
181
+ 'p228',
182
+ 'p251',
183
+ 'p282',
184
+ 'p246',
185
+ 'p234',
186
+ 'p226',
187
+ 'p260',
188
+ 'p245',
189
+ 'p241',
190
+ 'p303',
191
+ 'p265',
192
+ 'p306',
193
+ 'p237',
194
+ 'p249',
195
+ 'p256',
196
+ 'p302',
197
+ 'p364',
198
+ 'p225',
199
+ 'p362']
200
+
201
+ print(len(a))
202
+
203
+ b = []
204
+
205
+ for row in a:
206
+ b.append(f'en_US/vctk_low#{row}')
207
+
208
+ # print(b)
209
+
210
+ # 00000000 arctic
211
+
212
+
213
+ a = [
214
+ 'awb' # comma
215
+ 'rms',
216
+ 'slt',
217
+ 'ksp',
218
+ 'clb',
219
+ 'aew',
220
+ 'bdl',
221
+ 'lnh',
222
+ 'jmk',
223
+ 'rxr',
224
+ 'fem',
225
+ 'ljm',
226
+ 'slp',
227
+ 'ahw',
228
+ 'axb',
229
+ 'aup',
230
+ 'eey',
231
+ 'gka',
232
+ ]
233
+
234
+
235
+ for row in a:
236
+ b.append(f'en_US/cmu-arctic_low#{row}')
237
+
238
+ # HIFItts
239
+
240
+ a = ['9017',
241
+ '6097',
242
+ '92']
243
+
244
+ for row in a:
245
+ b.append(f'en_US/hifi-tts_low#{row}')
246
+
247
+ a = [
248
+ 'elliot_miller',
249
+ 'judy_bieber',
250
+ 'mary_ann']
251
+
252
+ for row in a:
253
+ b.append(f'en_US/m-ailabs_low#{row}')
254
+
255
+ # LJspeech - single speaker
256
+
257
+ b.append(f'en_US/ljspeech_low')
258
+
259
+ # en_UK apope - only speaker
260
+
261
+ b.append(f'en_UK/apope_low')
262
+
263
+ all_names = b
264
+
265
+
266
+ VOICES = {}
267
+ for _id, _voice in enumerate(all_names):
268
+
269
+ # If GitHub Quota exceded copy mimic-voices from local copies
270
+ #
271
+ # https://github.com/MycroftAI/mimic3-voices
272
+ #
273
+ home_voice_dir = f'/home/audeering.local/dkounadis/.local/share/mycroft/mimic3/voices/{_voice.split("#")[0]}/'
274
+ Path(home_voice_dir).mkdir(parents=True, exist_ok=True)
275
+ speaker_free_voice_name = _voice.split("#")[0] if '#' in _voice else _voice
276
+ if not os.path.isfile(home_voice_dir + 'generator.onnx'):
277
+ shutil.copyfile(
278
+ f'/data/dkounadis/cache/mimic3-voices/voices/{speaker_free_voice_name}/generator.onnx',
279
+ home_voice_dir + 'generator.onnx') # 'en_US incl. voice
280
+
281
+ prepare_file = _voice.replace('/', '_').replace('#', '_').replace('_low', '')
282
+ if 'cmu-arctic' in prepare_file:
283
+ prepare_file = prepare_file.replace('cmu-arctic', 'cmu_arctic') + '.wav'
284
+ else:
285
+ prepare_file = prepare_file + '.wav' # [...cmu-arctic...](....cmu_arctic....wav)
286
+
287
+ file_true = prepare_file.split('.wav')[0] + '_true_.wav'
288
+ file_false = prepare_file.split('.wav')[0] + '_false_.wav'
289
+ print(prepare_file, file_false, file_true)
290
+
291
+
292
+ reference_wav = reference_wav_directory + prepare_file
293
+ rate = 4 # high speed sounds nice if used as speaker-reference audio for StyleTTS2
294
+ _ssml = (
295
+ '<speak>'
296
+ '<prosody volume=\'64\'>'
297
+ f'<prosody rate=\'{rate}\'>'
298
+ f'<voice name=\'{_voice}\'>'
299
+ '<s>'
300
+ 'Sweet dreams are made of this, .. !!! # I travel the world and the seven seas.'
301
+ '</s>'
302
+ '</voice>'
303
+ '</prosody>'
304
+ '</prosody>'
305
+ '</speak>'
306
+ )
307
+ with open('_tmp_ssml.txt', 'w') as f:
308
+ f.write(_ssml)
309
+
310
+
311
+ # ps = subprocess.Popen(f'cat _tmp_ssml.txt | mimic3 --ssml > {reference_wav}', shell=True)
312
+ # ps.wait() # using ps to call mimic3 because samples dont have time to be written in stdout buffer
313
+ args = get_args()
314
+ args.ssml = True
315
+ args.text = [_ssml] #['aa', 'bb'] #txt
316
+ args.interactive = False
317
+ # args.output_naming = OutputNaming.TIME
318
+
319
+ state = CommandLineInterfaceState(args=args)
320
+ initialize_args(state)
321
+ initialize_tts(state)
322
+ # args.texts = [txt] #['aa', 'bb'] #txt
323
+ # state.stdout = '.' #None #'makeme.wav'
324
+ # state.output_dir = '.noopy'
325
+ # state.interactive = False
326
+ # state.output_naming = OutputNaming.TIME
327
+ # # state.ssml = 1234546575
328
+ # state.stdout = True
329
+ # state.tts = True
330
+ process_lines(state, wav_path=reference_wav)
331
+ shutdown_tts(state)
models.py CHANGED
@@ -517,7 +517,7 @@ def load_F0_models(path):
517
 
518
  F0_model = JDCNet(num_class=1, seq_len=192)
519
  print(path, 'WHAT ARE YOU TRYING TO LOAD F0 L520')
520
- path.replace('.t7', '.pth')
521
  params = torch.load(path, map_location='cpu')['net']
522
  F0_model.load_state_dict(params)
523
  _ = F0_model.train()
 
517
 
518
  F0_model = JDCNet(num_class=1, seq_len=192)
519
  print(path, 'WHAT ARE YOU TRYING TO LOAD F0 L520')
520
+ path = path.replace('.t7', '.pth')
521
  params = torch.load(path, map_location='cpu')['net']
522
  F0_model.load_state_dict(params)
523
  _ = F0_model.train()