Plachta commited on
Commit
a83a03b
β€’
1 Parent(s): a1e9282

Update modules/length_regulator.py

Browse files
Files changed (1) hide show
  1. modules/length_regulator.py +141 -118
modules/length_regulator.py CHANGED
@@ -1,118 +1,141 @@
1
- from typing import Tuple
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from modules.commons import sequence_mask
6
- import numpy as np
7
-
8
- # f0_bin = 256
9
- f0_max = 1100.0
10
- f0_min = 50.0
11
- f0_mel_min = 1127 * np.log(1 + f0_min / 700)
12
- f0_mel_max = 1127 * np.log(1 + f0_max / 700)
13
-
14
- def f0_to_coarse(f0, f0_bin):
15
- f0_mel = 1127 * (1 + f0 / 700).log()
16
- a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
17
- b = f0_mel_min * a - 1.
18
- f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
19
- # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
20
- f0_coarse = torch.round(f0_mel).long()
21
- f0_coarse = f0_coarse * (f0_coarse > 0)
22
- f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
23
- f0_coarse = f0_coarse * (f0_coarse < f0_bin)
24
- f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
25
- return f0_coarse
26
-
27
- class InterpolateRegulator(nn.Module):
28
- def __init__(
29
- self,
30
- channels: int,
31
- sampling_ratios: Tuple,
32
- is_discrete: bool = False,
33
- codebook_size: int = 1024, # for discrete only
34
- out_channels: int = None,
35
- groups: int = 1,
36
- token_dropout_prob: float = 0.5, # randomly drop out input tokens
37
- token_dropout_range: float = 0.5, # randomly drop out input tokens
38
- n_codebooks: int = 1, # number of codebooks
39
- quantizer_dropout: float = 0.0, # dropout for quantizer
40
- f0_condition: bool = False,
41
- n_f0_bins: int = 512,
42
- ):
43
- super().__init__()
44
- self.sampling_ratios = sampling_ratios
45
- out_channels = out_channels or channels
46
- model = nn.ModuleList([])
47
- if len(sampling_ratios) > 0:
48
- for _ in sampling_ratios:
49
- module = nn.Conv1d(channels, channels, 3, 1, 1)
50
- norm = nn.GroupNorm(groups, channels)
51
- act = nn.Mish()
52
- model.extend([module, norm, act])
53
- model.append(
54
- nn.Conv1d(channels, out_channels, 1, 1)
55
- )
56
- self.model = nn.Sequential(*model)
57
- self.embedding = nn.Embedding(codebook_size, channels)
58
- self.is_discrete = is_discrete
59
-
60
- self.mask_token = nn.Parameter(torch.zeros(1, channels))
61
-
62
- self.n_codebooks = n_codebooks
63
- if n_codebooks > 1:
64
- self.extra_codebooks = nn.ModuleList([
65
- nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
66
- ])
67
- self.token_dropout_prob = token_dropout_prob
68
- self.token_dropout_range = token_dropout_range
69
- self.quantizer_dropout = quantizer_dropout
70
-
71
- if f0_condition:
72
- self.f0_embedding = nn.Embedding(n_f0_bins, channels)
73
- self.f0_condition = f0_condition
74
- self.n_f0_bins = n_f0_bins
75
- self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
76
- self.f0_mask = nn.Parameter(torch.zeros(1, channels))
77
- else:
78
- self.f0_condition = False
79
-
80
- def forward(self, x, ylens=None, n_quantizers=None, f0=None):
81
- # apply token drop
82
- if self.training:
83
- n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
84
- dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
85
- n_dropout = int(x.shape[0] * self.quantizer_dropout)
86
- n_quantizers[:n_dropout] = dropout[:n_dropout]
87
- n_quantizers = n_quantizers.to(x.device)
88
- # decide whether to drop for each sample in batch
89
- else:
90
- n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
91
- if self.is_discrete:
92
- if self.n_codebooks > 1:
93
- assert len(x.size()) == 3
94
- x_emb = self.embedding(x[:, 0])
95
- for i, emb in enumerate(self.extra_codebooks):
96
- x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
97
- x = x_emb
98
- elif self.n_codebooks == 1:
99
- if len(x.size()) == 2:
100
- x = self.embedding(x)
101
- else:
102
- x = self.embedding(x[:, 0])
103
- # x in (B, T, D)
104
- mask = sequence_mask(ylens).unsqueeze(-1)
105
- x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
106
- if self.f0_condition:
107
- if f0 is None:
108
- x = x + self.f0_mask.unsqueeze(-1)
109
- else:
110
- # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
111
- quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
112
- quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
113
- f0_emb = self.f0_embedding(quantized_f0)
114
- f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
115
- x = x + f0_emb
116
- out = self.model(x).transpose(1, 2).contiguous()
117
- olens = ylens
118
- return out * mask, olens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from modules.commons import sequence_mask
6
+ import numpy as np
7
+ from dac.nn.quantize import VectorQuantize
8
+
9
+ # f0_bin = 256
10
+ f0_max = 1100.0
11
+ f0_min = 50.0
12
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
13
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
14
+
15
+ def f0_to_coarse(f0, f0_bin):
16
+ f0_mel = 1127 * (1 + f0 / 700).log()
17
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
18
+ b = f0_mel_min * a - 1.
19
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
20
+ # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
21
+ f0_coarse = torch.round(f0_mel).long()
22
+ f0_coarse = f0_coarse * (f0_coarse > 0)
23
+ f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
24
+ f0_coarse = f0_coarse * (f0_coarse < f0_bin)
25
+ f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
26
+ return f0_coarse
27
+
28
+ class InterpolateRegulator(nn.Module):
29
+ def __init__(
30
+ self,
31
+ channels: int,
32
+ sampling_ratios: Tuple,
33
+ is_discrete: bool = False,
34
+ in_channels: int = None, # only applies to continuous input
35
+ vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
36
+ codebook_size: int = 1024, # for discrete only
37
+ out_channels: int = None,
38
+ groups: int = 1,
39
+ n_codebooks: int = 1, # number of codebooks
40
+ quantizer_dropout: float = 0.0, # dropout for quantizer
41
+ f0_condition: bool = False,
42
+ n_f0_bins: int = 512,
43
+ ):
44
+ super().__init__()
45
+ self.sampling_ratios = sampling_ratios
46
+ out_channels = out_channels or channels
47
+ model = nn.ModuleList([])
48
+ if len(sampling_ratios) > 0:
49
+ self.interpolate = True
50
+ for _ in sampling_ratios:
51
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
52
+ norm = nn.GroupNorm(groups, channels)
53
+ act = nn.Mish()
54
+ model.extend([module, norm, act])
55
+ else:
56
+ self.interpolate = False
57
+ model.append(
58
+ nn.Conv1d(channels, out_channels, 1, 1)
59
+ )
60
+ self.model = nn.Sequential(*model)
61
+ self.embedding = nn.Embedding(codebook_size, channels)
62
+ self.is_discrete = is_discrete
63
+
64
+ self.mask_token = nn.Parameter(torch.zeros(1, channels))
65
+
66
+ self.n_codebooks = n_codebooks
67
+ if n_codebooks > 1:
68
+ self.extra_codebooks = nn.ModuleList([
69
+ nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
70
+ ])
71
+ self.extra_codebook_mask_tokens = nn.ParameterList([
72
+ nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
73
+ ])
74
+ self.quantizer_dropout = quantizer_dropout
75
+
76
+ if f0_condition:
77
+ self.f0_embedding = nn.Embedding(n_f0_bins, channels)
78
+ self.f0_condition = f0_condition
79
+ self.n_f0_bins = n_f0_bins
80
+ self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
81
+ self.f0_mask = nn.Parameter(torch.zeros(1, channels))
82
+ else:
83
+ self.f0_condition = False
84
+
85
+ if not is_discrete:
86
+ self.content_in_proj = nn.Linear(in_channels, channels)
87
+ if vector_quantize:
88
+ self.vq = VectorQuantize(channels, codebook_size, 8)
89
+
90
+ def forward(self, x, ylens=None, n_quantizers=None, f0=None):
91
+ # apply token drop
92
+ if self.training:
93
+ n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
94
+ dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
95
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
96
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
97
+ n_quantizers = n_quantizers.to(x.device)
98
+ # decide whether to drop for each sample in batch
99
+ else:
100
+ n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
101
+ if self.is_discrete:
102
+ if self.n_codebooks > 1:
103
+ assert len(x.size()) == 3
104
+ x_emb = self.embedding(x[:, 0])
105
+ for i, emb in enumerate(self.extra_codebooks):
106
+ x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
107
+ # add mask token if not using this codebook
108
+ # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
109
+ x = x_emb
110
+ elif self.n_codebooks == 1:
111
+ if len(x.size()) == 2:
112
+ x = self.embedding(x)
113
+ else:
114
+ x = self.embedding(x[:, 0])
115
+ else:
116
+ x = self.content_in_proj(x)
117
+ # x in (B, T, D)
118
+ mask = sequence_mask(ylens).unsqueeze(-1)
119
+ if self.interpolate:
120
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
121
+ else:
122
+ x = x.transpose(1, 2).contiguous()
123
+ mask = mask[:, :x.size(2), :]
124
+ ylens = ylens.clamp(max=x.size(2)).long()
125
+ if self.f0_condition:
126
+ if f0 is None:
127
+ x = x + self.f0_mask.unsqueeze(-1)
128
+ else:
129
+ quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
130
+ #quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
131
+ #quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
132
+ f0_emb = self.f0_embedding(quantized_f0)
133
+ f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
134
+ x = x + f0_emb
135
+ out = self.model(x).transpose(1, 2).contiguous()
136
+ if hasattr(self, 'vq'):
137
+ out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
138
+ out_q = out_q.transpose(1, 2)
139
+ return out_q * mask, ylens, codes, commitment_loss, codebook_loss
140
+ olens = ylens
141
+ return out * mask, olens, None, None, None