Blane187 committed
Commit c3b58fa
1 Parent(s): 236b4c8

Upload 39 files

requirements.txt CHANGED
@@ -1,16 +1,15 @@
  joblib>=1.1.0
- numba==0.56.4
+ numba
  numpy==1.23.5
  scipy
  librosa==0.9.1
- llvmlite==0.39.0
- fairseq==0.12.2
- faiss-cpu==1.7.3
- gradio==3.34.0
+ llvmlite
+ fairseq
+ faiss-cpu
+ gradio
  Cython
  pydub>=0.25.1
  soundfile>=0.12.1
- ffmpeg-python>=0.2.0
  tensorboardX
  Jinja2>=3.1.2
  json5
@@ -41,8 +40,8 @@ httpx
  onnxruntime; sys_platform == 'darwin'
  onnxruntime-gpu; sys_platform != 'darwin'
  torchcrepe==0.0.20
- fastapi==0.88
+ fastapi
  torchfcpe
- ffmpy==0.3.1
  python-dotenv>=1.0.0
  av
+ pybase16384
rvc/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from . import ipex
+ import sys
+
+ del sys.modules["rvc.ipex"]
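The package init above exists only for its side effect: importing rvc.ipex runs the Intel XPU patching code once (see rvc/ipex/__init__.py later in this commit), and the cached module entry is then removed from sys.modules. A minimal sketch of the same pattern, using a hypothetical mypkg/patch.py; the names are illustrative and not part of this commit:

# mypkg/__init__.py
import sys

from . import patch  # module-level code in patch.py runs here, e.g. monkey-patching a library

# forget the cached module object; a later explicit `import mypkg.patch` would execute it again
del sys.modules["mypkg.patch"]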
rvc/f0/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .f0 import F0Predictor
+
+ from .crepe import CRePE
+ from .dio import Dio
+ from .fcpe import FCPE
+ from .harvest import Harvest
+ from .pm import PM
+ from .rmvpe import RMVPE
+
+ __all__ = ["F0Predictor", "CRePE", "Dio", "FCPE", "Harvest", "PM", "RMVPE"]
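Every exported class shares the F0Predictor interface (compute_f0 plus the interpolation and resizing helpers), so callers can swap pitch extractors without changing the surrounding code. A minimal usage sketch, assuming the listed dependencies (pyworld etc.) are installed and the input is a mono float waveform at the predictor's sampling rate; the values below are illustrative, not taken from this commit:

import numpy as np
from rvc.f0 import Dio, Harvest

wav = np.random.randn(2 * 44100).astype(np.float32)  # 2 s of noise standing in for audio

predictor = Harvest(hop_length=512, sampling_rate=44100)  # Dio(...) is a drop-in replacement
f0 = predictor.compute_f0(wav)   # one F0 value (Hz) per hop, unvoiced gaps interpolated
print(f0.shape)                  # (wav length // hop_length,) frames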
rvc/f0/crepe.py ADDED
@@ -0,0 +1,56 @@
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ import torch
+ import torchcrepe
+
+ from .f0 import F0Predictor
+
+
+ class CRePE(F0Predictor):
+     def __init__(
+         self,
+         hop_length=512,
+         f0_min=50,
+         f0_max=1100,
+         sampling_rate=44100,
+         device="cpu",
+     ):
+         if "privateuseone" in str(device):
+             device = "cpu"
+         super().__init__(
+             hop_length,
+             f0_min,
+             f0_max,
+             sampling_rate,
+             device,
+         )
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[Union[int, float]] = None,
+     ):
+         if p_len is None:
+             p_len = wav.shape[0] // self.hop_length
+         if not torch.is_tensor(wav):
+             wav = torch.from_numpy(wav)
+         # Pick a batch size that doesn't cause memory errors on your gpu
+         batch_size = 512
+         # Compute pitch using device 'device'
+         f0, pd = torchcrepe.predict(
+             wav.float().to(self.device).unsqueeze(dim=0),
+             self.sampling_rate,
+             self.hop_length,
+             self.f0_min,
+             self.f0_max,
+             batch_size=batch_size,
+             device=self.device,
+             return_periodicity=True,
+         )
+         pd = torchcrepe.filter.median(pd, 3)
+         f0 = torchcrepe.filter.mean(f0, 3)
+         f0[pd < 0.1] = 0
+         f0 = f0[0].cpu().numpy()
+         return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
rvc/f0/deepunet.py ADDED
@@ -0,0 +1,217 @@
1
+ from typing import List, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ConvBlockRes(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_channels: int,
11
+ out_channels: int,
12
+ momentum: float = 0.01,
13
+ ):
14
+ super(ConvBlockRes, self).__init__()
15
+ self.conv = nn.Sequential(
16
+ nn.Conv2d(
17
+ in_channels=in_channels,
18
+ out_channels=out_channels,
19
+ kernel_size=(3, 3),
20
+ stride=(1, 1),
21
+ padding=(1, 1),
22
+ bias=False,
23
+ ),
24
+ nn.BatchNorm2d(out_channels, momentum=momentum),
25
+ nn.ReLU(),
26
+ nn.Conv2d(
27
+ in_channels=out_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=(3, 3),
30
+ stride=(1, 1),
31
+ padding=(1, 1),
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm2d(out_channels, momentum=momentum),
35
+ nn.ReLU(),
36
+ )
37
+ # self.shortcut:Optional[nn.Module] = None
38
+ if in_channels != out_channels:
39
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ if not hasattr(self, "shortcut"):
43
+ return self.conv(x) + x
44
+ else:
45
+ return self.conv(x) + self.shortcut(x)
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(
50
+ self,
51
+ in_channels: int,
52
+ in_size: int,
53
+ n_encoders: int,
54
+ kernel_size: Tuple[int, int],
55
+ n_blocks: int,
56
+ out_channels=16,
57
+ momentum=0.01,
58
+ ):
59
+ super(Encoder, self).__init__()
60
+ self.n_encoders = n_encoders
61
+
62
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
63
+ self.layers = nn.ModuleList()
64
+ for _ in range(self.n_encoders):
65
+ self.layers.append(
66
+ ResEncoderBlock(
67
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
68
+ )
69
+ )
70
+ in_channels = out_channels
71
+ out_channels *= 2
72
+ in_size //= 2
73
+ self.out_size = in_size
74
+ self.out_channel = out_channels
75
+
76
+ def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
77
+ return super().__call__(x)
78
+
79
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
80
+ concat_tensors: List[torch.Tensor] = []
81
+ x = self.bn(x)
82
+ for layer in self.layers:
83
+ t, x = layer(x)
84
+ concat_tensors.append(t)
85
+ return x, concat_tensors
86
+
87
+
88
+ class ResEncoderBlock(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_channels: int,
92
+ out_channels: int,
93
+ kernel_size: Tuple[int, int],
94
+ n_blocks=1,
95
+ momentum=0.01,
96
+ ):
97
+ super(ResEncoderBlock, self).__init__()
98
+ self.n_blocks = n_blocks
99
+ self.kernel_size = kernel_size
100
+
101
+ self.conv = nn.ModuleList()
102
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
103
+ for _ in range(n_blocks - 1):
104
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
105
+
106
+ if self.kernel_size is not None:
107
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
108
+
109
+ def forward(
110
+ self,
111
+ x: torch.Tensor,
112
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
113
+ for conv in self.conv:
114
+ x = conv(x)
115
+ if self.kernel_size is not None:
116
+ return x, self.pool(x)
117
+ return x
118
+
119
+
120
+ class Intermediate(nn.Module):
121
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
122
+ super(Intermediate, self).__init__()
123
+
124
+ self.layers = nn.ModuleList()
125
+ self.layers.append(
126
+ ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
127
+ )
128
+ for _ in range(n_inters - 1):
129
+ self.layers.append(
130
+ ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
131
+ )
132
+
133
+ def forward(self, x):
134
+ for layer in self.layers:
135
+ x = layer(x)
136
+ return x
137
+
138
+
139
+ class ResDecoderBlock(nn.Module):
140
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
141
+ super(ResDecoderBlock, self).__init__()
142
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
143
+
144
+ self.conv1 = nn.Sequential(
145
+ nn.ConvTranspose2d(
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ kernel_size=(3, 3),
149
+ stride=stride,
150
+ padding=(1, 1),
151
+ output_padding=out_padding,
152
+ bias=False,
153
+ ),
154
+ nn.BatchNorm2d(out_channels, momentum=momentum),
155
+ nn.ReLU(),
156
+ )
157
+ self.conv2 = nn.ModuleList()
158
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
159
+ for _ in range(n_blocks - 1):
160
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
161
+
162
+ def forward(self, x, concat_tensor):
163
+ x = self.conv1(x)
164
+ x = torch.cat((x, concat_tensor), dim=1)
165
+ for conv2 in self.conv2:
166
+ x = conv2(x)
167
+ return x
168
+
169
+
170
+ class Decoder(nn.Module):
171
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
172
+ super(Decoder, self).__init__()
173
+
174
+ self.layers = nn.ModuleList()
175
+ self.n_decoders = n_decoders
176
+ for _ in range(self.n_decoders):
177
+ out_channels = in_channels // 2
178
+ self.layers.append(
179
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
180
+ )
181
+ in_channels = out_channels
182
+
183
+ def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
184
+ for i, layer in enumerate(self.layers):
185
+ x = layer(x, concat_tensors[-1 - i])
186
+ return x
187
+
188
+
189
+ class DeepUnet(nn.Module):
190
+ def __init__(
191
+ self,
192
+ kernel_size: Tuple[int, int],
193
+ n_blocks: int,
194
+ en_de_layers=5,
195
+ inter_layers=4,
196
+ in_channels=1,
197
+ en_out_channels=16,
198
+ ):
199
+ super(DeepUnet, self).__init__()
200
+ self.encoder = Encoder(
201
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
202
+ )
203
+ self.intermediate = Intermediate(
204
+ self.encoder.out_channel // 2,
205
+ self.encoder.out_channel,
206
+ inter_layers,
207
+ n_blocks,
208
+ )
209
+ self.decoder = Decoder(
210
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
211
+ )
212
+
213
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
214
+ x, concat_tensors = self.encoder(x)
215
+ x = self.intermediate(x)
216
+ x = self.decoder(x, concat_tensors)
217
+ return x
rvc/f0/dio.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ import pyworld
+
+ from .f0 import F0Predictor
+
+
+ class Dio(F0Predictor):
+     def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
+         super().__init__(hop_length, f0_min, f0_max, sampling_rate)
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[Union[int, float]] = None,
+     ):
+         if p_len is None:
+             p_len = wav.shape[0] // self.hop_length
+         f0, t = pyworld.dio(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_floor=self.f0_min,
+             f0_ceil=self.f0_max,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         for index, pitch in enumerate(f0):
+             f0[index] = round(pitch, 1)
+         return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
rvc/f0/e2e.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+
5
+ from .deepunet import DeepUnet
6
+
7
+
8
+ class E2E(nn.Module):
9
+ def __init__(
10
+ self,
11
+ n_blocks: int,
12
+ n_gru: int,
13
+ kernel_size: Tuple[int, int],
14
+ en_de_layers=5,
15
+ inter_layers=4,
16
+ in_channels=1,
17
+ en_out_channels=16,
18
+ ):
19
+ super(E2E, self).__init__()
20
+
21
+ self.unet = DeepUnet(
22
+ kernel_size,
23
+ n_blocks,
24
+ en_de_layers,
25
+ inter_layers,
26
+ in_channels,
27
+ en_out_channels,
28
+ )
29
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
30
+ if n_gru:
31
+ self.fc = nn.Sequential(
32
+ self.BiGRU(3 * 128, 256, n_gru),
33
+ nn.Linear(512, 360),
34
+ nn.Dropout(0.25),
35
+ nn.Sigmoid(),
36
+ )
37
+ else:
38
+ self.fc = nn.Sequential(
39
+ nn.Linear(3 * 128, 360),  # torch.nn has no N_MELS / N_CLASS; use the same sizes as the n_gru branch
40
+ nn.Dropout(0.25),
41
+ nn.Sigmoid(),
42
+ )
43
+
44
+ def forward(self, mel):
45
+ mel = mel.transpose(-1, -2).unsqueeze(1)
46
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
47
+ x = self.fc(x)
48
+ return x
49
+
50
+ class BiGRU(nn.Module):
51
+ def __init__(
52
+ self,
53
+ input_features: int,
54
+ hidden_features: int,
55
+ num_layers: int,
56
+ ):
57
+ super().__init__()
58
+ self.gru = nn.GRU(
59
+ input_features,
60
+ hidden_features,
61
+ num_layers=num_layers,
62
+ batch_first=True,
63
+ bidirectional=True,
64
+ )
65
+
66
+ def forward(self, x):
67
+ return self.gru(x)[0]
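As a sanity check of the shapes, the network can be run with random weights. This is a sketch using the same hyper-parameters that get_rmvpe() in rvc/f0/models.py passes, with the frame count already a multiple of 32 as the RMVPE wrapper ensures by padding:

import torch
from rvc.f0.e2e import E2E

model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2)).eval()
mel = torch.randn(1, 128, 64)      # (batch, 128 mel bins, 64 frames)
with torch.no_grad():
    salience = model(mel)
print(salience.shape)              # expected: torch.Size([1, 64, 360]) - one 360-bin salience row per frame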
rvc/f0/f0.py ADDED
@@ -0,0 +1,78 @@
+ from typing import Optional, Union
+
+ import torch
+ import numpy as np
+
+
+ class F0Predictor(object):
+     def __init__(
+         self,
+         hop_length=512,
+         f0_min=50,
+         f0_max=1100,
+         sampling_rate=44100,
+         device: Optional[str] = None,
+     ):
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.sampling_rate = sampling_rate
+         if device is None:
+             device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         self.device = device
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[Union[int, float]] = None,
+     ): ...
+
+     def _interpolate_f0(self, f0: np.ndarray):
+         """
+         Interpolate F0 across unvoiced frames
+         """
+
+         data = np.reshape(f0, (f0.size, 1))
+
+         vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
+         vuv_vector[data > 0.0] = 1.0
+         vuv_vector[data <= 0.0] = 0.0
+
+         ip_data = data
+
+         frame_number = data.size
+         last_value = 0.0
+         for i in range(frame_number):
+             if data[i] <= 0.0:
+                 j = i + 1
+                 for j in range(i + 1, frame_number):
+                     if data[j] > 0.0:
+                         break
+                 if j < frame_number - 1:
+                     if last_value > 0.0:
+                         step = (data[j] - data[i - 1]) / float(j - i)
+                         for k in range(i, j):
+                             ip_data[k] = data[i - 1] + step * (k - i + 1)
+                     else:
+                         for k in range(i, j):
+                             ip_data[k] = data[j]
+                 else:
+                     for k in range(i, frame_number):
+                         ip_data[k] = last_value
+             else:
+                 ip_data[i] = data[i]  # this copy may be unnecessary
+                 last_value = data[i]
+
+         return ip_data[:, 0], vuv_vector[:, 0]
+
+     def _resize_f0(self, x: np.ndarray, target_len: int):
+         source = np.array(x)
+         source[source < 0.001] = np.nan
+         target = np.interp(
+             np.arange(0, len(source) * target_len, len(source)) / target_len,
+             np.arange(0, len(source)),
+             source,
+         )
+         res = np.nan_to_num(target)
+         return res
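A quick sketch of the two shared helpers on a toy contour; the printed values reflect my reading of the interpolation loop (leading/trailing unvoiced frames are held at the nearest voiced value, interior gaps are filled linearly):

import numpy as np
from rvc.f0.f0 import F0Predictor

helper = F0Predictor()  # only the protected helpers are exercised here
contour = np.array([0, 100, 0, 0, 0, 130, 0], dtype=np.float32)

f0, vuv = helper._interpolate_f0(contour)
print(f0)    # -> [100. 100. 110. 120. 130. 130. 130.]
print(vuv)   # -> [0. 1. 0. 0. 0. 1. 0.]  voiced flags from the original contour
print(helper._resize_f0(f0, 14).shape)   # -> (14,) linear re-sampling to a target frame count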
rvc/f0/fcpe.py ADDED
@@ -0,0 +1,53 @@
+ from typing import Optional, Union
+
+ import numpy as np
+ import torch
+
+ from .f0 import F0Predictor
+
+
+ class FCPE(F0Predictor):
+     def __init__(
+         self,
+         hop_length=512,
+         f0_min=50,
+         f0_max=1100,
+         sampling_rate=44100,
+         device="cpu",
+     ):
+         super().__init__(
+             hop_length,
+             f0_min,
+             f0_max,
+             sampling_rate,
+             device,
+         )
+
+         from torchfcpe import (
+             spawn_bundled_infer_model,
+         )  # must be imported here, otherwise it crashes fairseq during training
+
+         self.model = spawn_bundled_infer_model(self.device)
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[Union[int, float]] = 0.006,
+     ):
+         if p_len is None:
+             p_len = wav.shape[0] // self.hop_length
+         if not torch.is_tensor(wav):
+             wav = torch.from_numpy(wav)
+         f0 = (
+             self.model.infer(
+                 wav.float().to(self.device).unsqueeze(0),
+                 sr=self.sampling_rate,
+                 decoder_mode="local_argmax",
+                 threshold=filter_radius,
+             )
+             .squeeze()
+             .cpu()
+             .numpy()
+         )
+         return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
rvc/f0/harvest.py ADDED
@@ -0,0 +1,32 @@
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ import pyworld
+ from scipy import signal
+
+ from .f0 import F0Predictor
+
+
+ class Harvest(F0Predictor):
+     def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
+         super().__init__(hop_length, f0_min, f0_max, sampling_rate)
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[Union[int, float]] = None,
+     ):
+         if p_len is None:
+             p_len = wav.shape[0] // self.hop_length
+         f0, t = pyworld.harvest(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_ceil=self.f0_max,
+             f0_floor=self.f0_min,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         if filter_radius is not None and filter_radius > 2:
+             f0 = signal.medfilt(f0, filter_radius)
+         return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
rvc/f0/mel.py ADDED
@@ -0,0 +1,71 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import numpy as np
5
+ from librosa.filters import mel
6
+
7
+ from .stft import STFT
8
+
9
+
10
+ class MelSpectrogram(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ is_half: bool,
14
+ n_mel_channels: int,
15
+ sampling_rate: int,
16
+ win_length: int,
17
+ hop_length: int,
18
+ n_fft: Optional[int] = None,
19
+ mel_fmin: int = 0,
20
+ mel_fmax: int = None,
21
+ clamp: float = 1e-5,
22
+ device=torch.device("cpu"),
23
+ ):
24
+ super().__init__()
25
+ if n_fft is None:
26
+ n_fft = win_length
27
+ mel_basis = mel(
28
+ sr=sampling_rate,
29
+ n_fft=n_fft,
30
+ n_mels=n_mel_channels,
31
+ fmin=mel_fmin,
32
+ fmax=mel_fmax,
33
+ htk=True,
34
+ )
35
+ mel_basis = torch.from_numpy(mel_basis).float()
36
+ self.register_buffer("mel_basis", mel_basis)
37
+ self.n_fft = n_fft
38
+ self.hop_length = hop_length
39
+ self.win_length = win_length
40
+ self.clamp = clamp
41
+ self.is_half = is_half
42
+
43
+ self.stft = STFT(
44
+ filter_length=n_fft,
45
+ hop_length=hop_length,
46
+ win_length=win_length,
47
+ window="hann",
48
+ use_torch_stft="privateuseone" not in str(device),
49
+ ).to(device)
50
+
51
+ def forward(
52
+ self,
53
+ audio: torch.Tensor,
54
+ keyshift=0,
55
+ speed=1,
56
+ center=True,
57
+ ):
58
+ factor = 2 ** (keyshift / 12)
59
+ win_length_new = int(np.round(self.win_length * factor))
60
+ magnitude = self.stft(audio, keyshift, speed, center)
61
+ if keyshift != 0:
62
+ size = self.n_fft // 2 + 1
63
+ resize = magnitude.size(1)
64
+ if resize < size:
65
+ magnitude = torch.nn.functional.pad(magnitude, (0, 0, 0, size - resize))
66
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
67
+ mel_output = torch.matmul(self.mel_basis, magnitude)
68
+ if self.is_half:
69
+ mel_output = mel_output.half()
70
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
71
+ return log_mel_spec
rvc/f0/models.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+
+ def get_rmvpe(
+     model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False
+ ):
+     from rvc.f0.e2e import E2E
+
+     model = E2E(4, 1, (2, 2))
+     ckpt = torch.load(model_path, map_location=device)
+     model.load_state_dict(ckpt)
+     model.eval()
+     if is_half:
+         model = model.half()
+     model = model.to(device)
+     return model
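A short loading sketch; it assumes the RMVPE checkpoint has already been downloaded to the default path used above (the weight file itself is not part of this commit):

import torch
from rvc.f0.models import get_rmvpe

model = get_rmvpe("assets/rmvpe/rmvpe.pt", device=torch.device("cpu"), is_half=False)
mel = torch.randn(1, 128, 64)   # (batch, mel bins, frames), frames a multiple of 32
with torch.no_grad():
    salience = model(mel)       # (1, 64, 360), same shape as the random-weight sketch shown earlier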
rvc/f0/pm.py ADDED
@@ -0,0 +1,39 @@
+ from typing import Any, Optional
+
+ import numpy as np
+ import parselmouth
+
+ from .f0 import F0Predictor
+
+
+ class PM(F0Predictor):
+     def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
+         super().__init__(hop_length, f0_min, f0_max, sampling_rate)
+
+     def compute_f0(
+         self,
+         wav: np.ndarray,
+         p_len: Optional[int] = None,
+         filter_radius: Optional[int] = None,
+     ):
+         x = wav
+         if p_len is None:
+             p_len = x.shape[0] // self.hop_length
+         else:
+             assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
+         time_step = self.hop_length / self.sampling_rate * 1000
+         f0 = (
+             parselmouth.Sound(x, self.sampling_rate)
+             .to_pitch_ac(
+                 time_step=time_step / 1000,
+                 voicing_threshold=0.6,
+                 pitch_floor=self.f0_min,
+                 pitch_ceiling=self.f0_max,
+             )
+             .selected_array["frequency"]
+         )
+
+         pad_size = (p_len - len(f0) + 1) // 2
+         if pad_size > 0 or p_len - len(f0) - pad_size > 0:
+             f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
+         return self._interpolate_f0(f0)[0]
rvc/f0/rmvpe.py ADDED
@@ -0,0 +1,162 @@
1
+ from io import BytesIO
2
+ import os
3
+ from typing import Any, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
10
+
11
+ from .mel import MelSpectrogram
12
+ from .f0 import F0Predictor
13
+ from .models import get_rmvpe
14
+
15
+
16
+ def rmvpe_jit_export(
17
+ model_path: str,
18
+ mode: str = "script",
19
+ inputs_path: str = None,
20
+ save_path: str = None,
21
+ device=torch.device("cpu"),
22
+ is_half=False,
23
+ ):
24
+ if not save_path:
25
+ save_path = model_path.rstrip(".pth")
26
+ save_path += ".half.jit" if is_half else ".jit"
27
+ if "cuda" in str(device) and ":" not in str(device):
28
+ device = torch.device("cuda:0")
29
+
30
+ model = get_rmvpe(model_path, device, is_half)
31
+ inputs = None
32
+ if mode == "trace":
33
+ inputs = load_inputs(inputs_path, device, is_half)
34
+ ckpt = export_jit_model(model, mode, inputs, device, is_half)
35
+ ckpt["device"] = str(device)
36
+ save_pickle(ckpt, save_path)
37
+ return ckpt
38
+
39
+
40
+ class RMVPE(F0Predictor):
41
+ def __init__(
42
+ self,
43
+ model_path: str,
44
+ is_half: bool,
45
+ device: str,
46
+ use_jit=False,
47
+ ):
48
+ hop_length = 160
49
+ f0_min = 30
50
+ f0_max = 8000
51
+ sampling_rate = 16000
52
+
53
+ super().__init__(
54
+ hop_length,
55
+ f0_min,
56
+ f0_max,
57
+ sampling_rate,
58
+ device,
59
+ )
60
+
61
+ self.is_half = is_half
62
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
63
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
64
+
65
+ self.mel_extractor = MelSpectrogram(
66
+ is_half=is_half,
67
+ n_mel_channels=128,
68
+ sampling_rate=sampling_rate,
69
+ win_length=1024,
70
+ hop_length=hop_length,
71
+ mel_fmin=f0_min,
72
+ mel_fmax=f0_max,
73
+ device=self.device,
74
+ ).to(self.device)
75
+
76
+ if "privateuseone" in str(self.device):
77
+ import onnxruntime as ort
78
+
79
+ self.model = ort.InferenceSession(
80
+ "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
81
+ providers=["DmlExecutionProvider"],
82
+ )
83
+ else:
84
+
85
+ def rmvpe_jit_model():
86
+ ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
87
+ model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
88
+ model = model.to(self.device)
89
+ return model
90
+
91
+ if use_jit and not (is_half and "cpu" in str(self.device)):
92
+ self.model = rmvpe_jit_model()
93
+ else:
94
+ self.model = get_rmvpe(model_path, self.device, is_half)
95
+
96
+ def compute_f0(
97
+ self,
98
+ wav: np.ndarray,
99
+ p_len: Optional[int] = None,
100
+ filter_radius: Optional[Union[int, float]] = None,
101
+ ):
102
+ if p_len is None:
103
+ p_len = wav.shape[0] // self.hop_length
104
+ if not torch.is_tensor(wav):
105
+ wav = torch.from_numpy(wav)
106
+ mel = self.mel_extractor(wav.float().to(self.device).unsqueeze(0), center=True)
107
+ hidden = self._mel2hidden(mel)
108
+ if "privateuseone" not in str(self.device):
109
+ hidden = hidden.squeeze(0).cpu().numpy()
110
+ else:
111
+ hidden = hidden[0]
112
+ if self.is_half:
113
+ hidden = hidden.astype("float32")
114
+
115
+ f0 = self._decode(hidden, thred=filter_radius)
116
+
117
+ return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
118
+
119
+ def _to_local_average_cents(self, salience, threshold=0.05):
120
+ center = np.argmax(salience, axis=1)  # (n_frames,) index of the peak bin
121
+ salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
122
+ center += 4
123
+ todo_salience = []
124
+ todo_cents_mapping = []
125
+ starts = center - 4
126
+ ends = center + 5
127
+ for idx in range(salience.shape[0]):
128
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
129
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
130
+ todo_salience = np.array(todo_salience)  # (n_frames, 9)
131
+ todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
132
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
133
+ weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
134
+ divided = product_sum / weight_sum  # (n_frames,)
135
+ maxx = np.max(salience, axis=1)  # (n_frames,)
136
+ divided[maxx <= threshold] = 0
137
+ return divided
138
+
139
+ def _mel2hidden(self, mel):
140
+ with torch.no_grad():
141
+ n_frames = mel.shape[-1]
142
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
143
+ if n_pad > 0:
144
+ mel = F.pad(mel, (0, n_pad), mode="constant")
145
+ if "privateuseone" in str(self.device):
146
+ onnx_input_name = self.model.get_inputs()[0].name
147
+ onnx_outputs_names = self.model.get_outputs()[0].name
148
+ hidden = self.model.run(
149
+ [onnx_outputs_names],
150
+ input_feed={onnx_input_name: mel.cpu().numpy()},
151
+ )[0]
152
+ else:
153
+ mel = mel.half() if self.is_half else mel.float()
154
+ hidden = self.model(mel)
155
+ return hidden[:, :n_frames]
156
+
157
+ def _decode(self, hidden, thred=0.03):
158
+ cents_pred = self._to_local_average_cents(hidden, threshold=thred)
159
+ f0 = 10 * (2 ** (cents_pred / 1200))
160
+ f0[f0 == 10] = 0
161
+ # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
162
+ return f0
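The decoder maps each frame's salience peak to a weighted-average position on a 20-cent grid (cents_mapping = 20 * i + 1997.38, so bin 0 sits around 31.7 Hz, just above the 30 Hz floor) and then converts cents to Hz with f0 = 10 * 2**(cents / 1200); frames whose peak salience does not exceed the threshold are set to 0. A worked number, assuming the weighted average lands at 6551.3 cents:

f0 = 10 * (2 ** (6551.3 / 1200))   # ~= 440 Hz (concert A)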
rvc/f0/stft.py ADDED
@@ -0,0 +1,194 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from librosa.util import pad_center
7
+ from scipy.signal import get_window
8
+
9
+
10
+ class STFT(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ filter_length=1024,
14
+ hop_length=512,
15
+ win_length: Optional[int] = None,
16
+ window="hann",
17
+ use_torch_stft=True,
18
+ ):
19
+ """
20
+ This module implements an STFT using 1D convolution and 1D transpose convolutions.
21
+ This is a bit tricky so there are some cases that probably won't work as working
22
+ out the same sizes before and after in all overlap add setups is tough. Right now,
23
+ this code should work with hop lengths that are half the filter length (50% overlap
24
+ between frames).
25
+
26
+ Keyword Arguments:
27
+ filter_length {int} -- Length of filters used (default: {1024})
28
+ hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
29
+ win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
30
+ equals the filter length). (default: {None})
31
+ window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
32
+ (default: {'hann'})
33
+ """
34
+ super(STFT, self).__init__()
35
+ self.filter_length = filter_length
36
+ self.hop_length = hop_length
37
+ self.pad_amount = int(self.filter_length / 2)
38
+ self.win_length = win_length
39
+ self.hann_window = {}
40
+ self.use_torch_stft = use_torch_stft
41
+
42
+ if use_torch_stft:
43
+ return
44
+
45
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
46
+
47
+ cutoff = int((self.filter_length / 2 + 1))
48
+ fourier_basis = np.vstack(
49
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
50
+ )
51
+ forward_basis = torch.FloatTensor(fourier_basis)
52
+ inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
53
+
54
+ if win_length is None or not win_length:
55
+ win_length = filter_length
56
+ assert filter_length >= win_length
57
+
58
+ # get window and zero center pad it to filter_length
59
+ fft_window = get_window(window, win_length, fftbins=True)
60
+ fft_window = pad_center(fft_window, size=filter_length)
61
+ fft_window = torch.from_numpy(fft_window).float()
62
+
63
+ # window the bases
64
+ forward_basis *= fft_window
65
+ inverse_basis = (inverse_basis.T * fft_window).T
66
+
67
+ self.register_buffer("forward_basis", forward_basis.float())
68
+ self.register_buffer("inverse_basis", inverse_basis.float())
69
+ self.register_buffer("fft_window", fft_window.float())
70
+
71
+ def __call__(
72
+ self,
73
+ input_data: torch.Tensor,
74
+ keyshift: int = 0,
75
+ speed: int = 1,
76
+ center: bool = True,
77
+ ) -> torch.Tensor:
78
+ return super().__call__(input_data, keyshift, speed, center)
79
+
80
+ def transform(
81
+ self,
82
+ input_data: torch.Tensor,
83
+ return_phase=False,
84
+ ) -> Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]:
85
+ """Take input data (audio) to STFT domain.
86
+
87
+ Arguments:
88
+ input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
89
+
90
+ Returns:
91
+ magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
92
+ num_frequencies, num_frames)
93
+ phase {tensor} -- Phase of STFT with shape (num_batch,
94
+ num_frequencies, num_frames)
95
+ """
96
+ input_data = F.pad(
97
+ input_data,
98
+ (self.pad_amount, self.pad_amount),
99
+ mode="reflect",
100
+ )
101
+ forward_transform = input_data.unfold(
102
+ 1, self.filter_length, self.hop_length
103
+ ).permute(0, 2, 1)
104
+ forward_transform = torch.matmul(self.forward_basis, forward_transform)
105
+ cutoff = int((self.filter_length / 2) + 1)
106
+ real_part = forward_transform[:, :cutoff, :]
107
+ imag_part = forward_transform[:, cutoff:, :]
108
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
109
+ if return_phase:
110
+ phase = torch.atan2(imag_part.data, real_part.data)
111
+ return magnitude, phase
112
+ else:
113
+ return magnitude
114
+
115
+ def inverse(
116
+ self,
117
+ magnitude: torch.Tensor,
118
+ phase: torch.Tensor,
119
+ ) -> torch.Tensor:
120
+ """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
121
+ by the ```transform``` function.
122
+
123
+ Arguments:
124
+ magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
125
+ num_frequencies, num_frames)
126
+ phase {tensor} -- Phase of STFT with shape (num_batch,
127
+ num_frequencies, num_frames)
128
+
129
+ Returns:
130
+ inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
131
+ shape (num_batch, num_samples)
132
+ """
133
+ cat = torch.cat(
134
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
135
+ )
136
+ fold = torch.nn.Fold(
137
+ output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
138
+ kernel_size=(1, self.filter_length),
139
+ stride=(1, self.hop_length),
140
+ )
141
+ inverse_transform = torch.matmul(self.inverse_basis, cat)
142
+ inverse_transform: torch.Tensor = fold(inverse_transform)[
143
+ :, 0, 0, self.pad_amount : -self.pad_amount
144
+ ]
145
+ window_square_sum = (
146
+ self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
147
+ )
148
+ window_square_sum = fold(window_square_sum)[
149
+ :, 0, 0, self.pad_amount : -self.pad_amount
150
+ ]
151
+ inverse_transform /= window_square_sum
152
+ return inverse_transform
153
+
154
+ def forward(
155
+ self,
156
+ input_data: torch.Tensor,
157
+ keyshift: int = 0,
158
+ speed: int = 1,
159
+ center: bool = True,
160
+ ) -> torch.Tensor:
161
+ factor = 2 ** (keyshift / 12)
162
+ n_fft_new = int(np.round(self.filter_length * factor))
163
+ win_length_new = int(np.round(self.win_length * factor))
164
+ hop_length_new = int(np.round(self.hop_length * speed))
165
+ if self.use_torch_stft:
166
+ keyshift_key = str(keyshift) + "_" + str(input_data.device)
167
+ if keyshift_key not in self.hann_window:
168
+ self.hann_window[keyshift_key] = torch.hann_window(
169
+ self.win_length,
170
+ ).to(input_data.device)
171
+ fft = torch.stft(
172
+ input_data,
173
+ n_fft=n_fft_new,
174
+ hop_length=hop_length_new,
175
+ win_length=win_length_new,
176
+ window=self.hann_window[keyshift_key],
177
+ center=center,
178
+ return_complex=True,
179
+ )
180
+ return torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
181
+ return self.transform(input_data)
182
+ """Take input data (audio) to STFT domain and then back to audio.
183
+
184
+ Arguments:
185
+ input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
186
+
187
+ Returns:
188
+ reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
189
+ shape (num_batch, num_samples)
190
+ reconstruction = self.inverse(
191
+ self.transform(input_data, return_phase=True),
192
+ )
193
+ return reconstruction
194
+ """
rvc/hubert.py ADDED
@@ -0,0 +1,339 @@
1
+ import math
2
+ import random
3
+ from typing import Optional, Tuple
4
+
5
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task
6
+ from fairseq.utils import index_put
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ # @torch.jit.script
13
+ def pad_to_multiple(x, multiple, dim=-1, value=0):
14
+ # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
15
+ if x is None:
16
+ return None, 0
17
+ tsz = x.size(dim)
18
+ m = tsz / multiple
19
+ remainder = math.ceil(m) * multiple - tsz
20
+ if int(tsz % multiple) == 0:
21
+ return x, 0
22
+ pad_offset = (0,) * (-1 - dim) * 2
23
+
24
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
25
+
26
+
27
+ def extract_features(
28
+ self,
29
+ x,
30
+ padding_mask=None,
31
+ tgt_layer=None,
32
+ min_layer=0,
33
+ ):
34
+ if padding_mask is not None:
35
+ x = index_put(x, padding_mask, 0)
36
+
37
+ x_conv = self.pos_conv(x.transpose(1, 2))
38
+ x_conv = x_conv.transpose(1, 2)
39
+ x = x + x_conv
40
+
41
+ if not self.layer_norm_first:
42
+ x = self.layer_norm(x)
43
+
44
+ # pad to the sequence length dimension
45
+ x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
46
+ if pad_length > 0 and padding_mask is None:
47
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
48
+ padding_mask[:, -pad_length:] = True
49
+ else:
50
+ padding_mask, _ = pad_to_multiple(
51
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
52
+ )
53
+ x = F.dropout(x, p=self.dropout, training=self.training)
54
+
55
+ # B x T x C -> T x B x C
56
+ x = x.transpose(0, 1)
57
+
58
+ layer_results = []
59
+ r = None
60
+ for i, layer in enumerate(self.layers):
61
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
62
+ if not self.training or (dropout_probability > self.layerdrop):
63
+ x, (z, lr) = layer(
64
+ x, self_attn_padding_mask=padding_mask, need_weights=False
65
+ )
66
+ if i >= min_layer:
67
+ layer_results.append((x, z, lr))
68
+ if i == tgt_layer:
69
+ r = x
70
+ break
71
+
72
+ if r is not None:
73
+ x = r
74
+
75
+ # T x B x C -> B x T x C
76
+ x = x.transpose(0, 1)
77
+
78
+ # undo padding
79
+ if pad_length > 0:
80
+ x = x[:, :-pad_length]
81
+
82
+ def undo_pad(a, b, c):
83
+ return (
84
+ a[:-pad_length],
85
+ b[:-pad_length] if b is not None else b,
86
+ c[:-pad_length],
87
+ )
88
+
89
+ layer_results = [undo_pad(*u) for u in layer_results]
90
+
91
+ return x, layer_results
92
+
93
+
94
+ def compute_mask_indices(
95
+ shape: Tuple[int, int],
96
+ padding_mask: Optional[torch.Tensor],
97
+ mask_prob: float,
98
+ mask_length: int,
99
+ mask_type: str = "static",
100
+ mask_other: float = 0.0,
101
+ min_masks: int = 0,
102
+ no_overlap: bool = False,
103
+ min_space: int = 0,
104
+ require_same_masks: bool = True,
105
+ mask_dropout: float = 0.0,
106
+ ) -> torch.Tensor:
107
+ """
108
+ Computes random mask spans for a given shape
109
+
110
+ Args:
111
+ shape: the shape for which to compute masks.
112
+ should be of size 2 where first element is batch size and 2nd is timesteps
113
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
114
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
115
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
116
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
117
+ mask_type: how to compute mask lengths
118
+ static = fixed size
119
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
120
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
121
+ poisson = sample from a Poisson distribution with lambda = mask length
122
+ min_masks: minimum number of masked spans
123
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
124
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
125
+ require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
126
+ mask_dropout: randomly dropout this percentage of masks in each example
127
+ """
128
+
129
+ bsz, all_sz = shape
130
+ mask = torch.full((bsz, all_sz), False)
131
+
132
+ all_num_mask = int(
133
+ # add a random number for probabilistic rounding
134
+ mask_prob * all_sz / float(mask_length)
135
+ + torch.rand([1]).item()
136
+ )
137
+
138
+ all_num_mask = max(min_masks, all_num_mask)
139
+
140
+ mask_idcs = []
141
+ for i in range(bsz):
142
+ if padding_mask is not None:
143
+ sz = all_sz - padding_mask[i].long().sum().item()
144
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
145
+ num_mask = max(min_masks, num_mask)
146
+ else:
147
+ sz = all_sz
148
+ num_mask = all_num_mask
149
+
150
+ if mask_type == "static":
151
+ lengths = torch.full([num_mask], mask_length)
152
+ elif mask_type == "uniform":
153
+ lengths = torch.randint(mask_other, mask_length * 2 + 1, size=[num_mask])
154
+ elif mask_type == "normal":
155
+ lengths = torch.normal(mask_length, mask_other, size=[num_mask])
156
+ lengths = [max(1, int(round(x))) for x in lengths]
157
+ else:
158
+ raise Exception("unknown mask selection " + mask_type)
159
+
160
+ if sum(lengths) == 0:
161
+ lengths[0] = min(mask_length, sz - 1)
162
+
163
+ if no_overlap:
164
+ mask_idc = []
165
+
166
+ def arrange(s, e, length, keep_length):
167
+ span_start = torch.randint(low=s, high=e - length, size=[1]).item()
168
+ mask_idc.extend(span_start + i for i in range(length))
169
+
170
+ new_parts = []
171
+ if span_start - s - min_space >= keep_length:
172
+ new_parts.append((s, span_start - min_space + 1))
173
+ if e - span_start - length - min_space > keep_length:
174
+ new_parts.append((span_start + length + min_space, e))
175
+ return new_parts
176
+
177
+ parts = [(0, sz)]
178
+ min_length = min(lengths)
179
+ for length in sorted(lengths, reverse=True):
180
+ t = [e - s if e - s >= length + min_space else 0 for s, e in parts]
181
+ lens = torch.asarray(t, dtype=torch.int)
182
+ l_sum = torch.sum(lens)
183
+ if l_sum == 0:
184
+ break
185
+ probs = lens / torch.sum(lens)
186
+ c = torch.multinomial(probs.float(), len(parts)).item()
187
+ s, e = parts.pop(c)
188
+ parts.extend(arrange(s, e, length, min_length))
189
+ mask_idc = torch.asarray(mask_idc)
190
+ else:
191
+ min_len = min(lengths)
192
+ if sz - min_len <= num_mask:
193
+ min_len = sz - num_mask - 1
194
+ mask_idc = torch.asarray(
195
+ random.sample([i for i in range(sz - min_len)], num_mask)
196
+ )
197
+ mask_idc = torch.asarray(
198
+ [
199
+ mask_idc[j] + offset
200
+ for j in range(len(mask_idc))
201
+ for offset in range(lengths[j])
202
+ ]
203
+ )
204
+
205
+ mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
206
+
207
+ min_len = min([len(m) for m in mask_idcs])
208
+ for i, mask_idc in enumerate(mask_idcs):
209
+ if isinstance(mask_idc, torch.Tensor):
210
+ mask_idc = torch.asarray(mask_idc, dtype=torch.float)
211
+ if len(mask_idc) > min_len and require_same_masks:
212
+ mask_idc = torch.asarray(
213
+ random.sample(mask_idc.tolist(), min_len)  # range(mask_idc) would fail on a tensor; sample values instead
214
+ )
215
+ if mask_dropout > 0:
216
+ num_holes = int(round(len(mask_idc) * mask_dropout))
217
+ mask_idc = torch.asarray(
218
+ random.sample(mask_idc.tolist(), len(mask_idc) - num_holes)  # same fix as above
219
+ )
220
+
221
+ mask[i, mask_idc.int()] = True
222
+
223
+ return mask
224
+
225
+
226
+ def apply_mask(self, x, padding_mask, target_list):
227
+ B, T, C = x.shape
228
+ torch.zeros_like(x)
229
+ if self.mask_prob > 0:
230
+ mask_indices = compute_mask_indices(
231
+ (B, T),
232
+ padding_mask,
233
+ self.mask_prob,
234
+ self.mask_length,
235
+ self.mask_selection,
236
+ self.mask_other,
237
+ min_masks=2,
238
+ no_overlap=self.no_mask_overlap,
239
+ min_space=self.mask_min_space,
240
+ )
241
+ mask_indices = mask_indices.to(x.device)
242
+ x[mask_indices] = self.mask_emb
243
+ else:
244
+ mask_indices = None
245
+
246
+ if self.mask_channel_prob > 0:
247
+ mask_channel_indices = compute_mask_indices(
248
+ (B, C),
249
+ None,
250
+ self.mask_channel_prob,
251
+ self.mask_channel_length,
252
+ self.mask_channel_selection,
253
+ self.mask_channel_other,
254
+ no_overlap=self.no_mask_channel_overlap,
255
+ min_space=self.mask_channel_min_space,
256
+ )
257
+ mask_channel_indices = (
258
+ mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
259
+ )
260
+ x[mask_channel_indices] = 0
261
+
262
+ return x, mask_indices
263
+
264
+
265
+ def get_hubert(model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")):
266
+ models, _, _ = load_model_ensemble_and_task(
267
+ [model_path],
268
+ suffix="",
269
+ )
270
+ hubert_model = models[0]
271
+ hubert_model = hubert_model.to(device)
272
+
273
+ def _apply_mask(x, padding_mask, target_list):
274
+ return apply_mask(hubert_model, x, padding_mask, target_list)
275
+
276
+ hubert_model.apply_mask = _apply_mask
277
+
278
+ def _extract_features(
279
+ x,
280
+ padding_mask=None,
281
+ tgt_layer=None,
282
+ min_layer=0,
283
+ ):
284
+ return extract_features(
285
+ hubert_model.encoder,
286
+ x,
287
+ padding_mask=padding_mask,
288
+ tgt_layer=tgt_layer,
289
+ min_layer=min_layer,
290
+ )
291
+
292
+ hubert_model.encoder.extract_features = _extract_features
293
+
294
+ hubert_model._forward = hubert_model.forward
295
+
296
+ def hubert_extract_features(
297
+ self,
298
+ source: torch.Tensor,
299
+ padding_mask: Optional[torch.Tensor] = None,
300
+ mask: bool = False,
301
+ ret_conv: bool = False,
302
+ output_layer: Optional[int] = None,
303
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
304
+ res = self._forward(
305
+ source,
306
+ padding_mask=padding_mask,
307
+ mask=mask,
308
+ features_only=True,
309
+ output_layer=output_layer,
310
+ )
311
+ feature = res["features"] if ret_conv else res["x"]
312
+ return feature, res["padding_mask"]
313
+
314
+ def _hubert_extract_features(
315
+ source: torch.Tensor,
316
+ padding_mask: Optional[torch.Tensor] = None,
317
+ mask: bool = False,
318
+ ret_conv: bool = False,
319
+ output_layer: Optional[int] = None,
320
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
321
+ return hubert_extract_features(
322
+ hubert_model, source, padding_mask, mask, ret_conv, output_layer
323
+ )
324
+
325
+ hubert_model.extract_features = _hubert_extract_features
326
+
327
+ def infer(source, padding_mask, output_layer: torch.Tensor):
328
+ output_layer = output_layer.item()
329
+ logits = hubert_model.extract_features(
330
+ source=source, padding_mask=padding_mask, output_layer=output_layer
331
+ )
332
+ feats = hubert_model.final_proj(logits[0]) if output_layer == 9 else logits[0]
333
+ return feats
334
+
335
+ hubert_model.infer = infer
336
+ # hubert_model.forward=infer
337
+ # hubert_model.forward
338
+
339
+ return hubert_model
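A minimal feature-extraction sketch; it assumes the HuBERT checkpoint has already been placed at the default path (it is not included in this commit) and that it is the standard hubert_base checkpoint, whose final_proj maps the 768-dim layer-9 features down to 256 dims:

import torch
from rvc.hubert import get_hubert

hubert = get_hubert("assets/hubert/hubert_base.pt").eval()
wav = torch.randn(1, 16000)                               # 1 s of 16 kHz mono audio
padding_mask = torch.zeros_like(wav, dtype=torch.bool)    # no padded samples

with torch.no_grad():
    feats = hubert.infer(wav, padding_mask, torch.tensor(9))    # layer 9 -> 256-dim (v1-style models)
    # feats = hubert.infer(wav, padding_mask, torch.tensor(12)) # layer 12 -> raw 768-dim features
print(feats.shape)   # roughly (1, 49, 256): one vector per 20 ms frame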
rvc/ipex/__init__.py ADDED
@@ -0,0 +1,10 @@
+ try:
+     import torch
+
+     if torch.xpu.is_available():
+         from .init import ipex_init
+
+         ipex_init()
+         from .gradscaler import gradscaler_init
+ except Exception:  # pylint: disable=broad-exception-caught
+     pass
rvc/ipex/attention.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch
2
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
3
+
4
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
5
+
6
+ original_torch_bmm = torch.bmm
7
+
8
+
9
+ def torch_bmm(input, mat2, *, out=None):
10
+ if input.dtype != mat2.dtype:
11
+ mat2 = mat2.to(input.dtype)
12
+
13
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
14
+ batch_size_attention, input_tokens, mat2_shape = (
15
+ input.shape[0],
16
+ input.shape[1],
17
+ mat2.shape[2],
18
+ )
19
+ block_multiply = input.element_size()
20
+ slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
21
+ block_size = batch_size_attention * slice_block_size
22
+
23
+ split_slice_size = batch_size_attention
24
+ if block_size > 4:
25
+ do_split = True
26
+ # Find something divisible with the input_tokens
27
+ while (split_slice_size * slice_block_size) > 4:
28
+ split_slice_size = split_slice_size // 2
29
+ if split_slice_size <= 1:
30
+ split_slice_size = 1
31
+ break
32
+ else:
33
+ do_split = False
34
+
35
+ split_2_slice_size = input_tokens
36
+ if split_slice_size * slice_block_size > 4:
37
+ slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
38
+ do_split_2 = True
39
+ # Find something divisible with the input_tokens
40
+ while (split_2_slice_size * slice_block_size2) > 4:
41
+ split_2_slice_size = split_2_slice_size // 2
42
+ if split_2_slice_size <= 1:
43
+ split_2_slice_size = 1
44
+ break
45
+ else:
46
+ do_split_2 = False
47
+
48
+ if do_split:
49
+ hidden_states = torch.zeros(
50
+ input.shape[0],
51
+ input.shape[1],
52
+ mat2.shape[2],
53
+ device=input.device,
54
+ dtype=input.dtype,
55
+ )
56
+ for i in range(batch_size_attention // split_slice_size):
57
+ start_idx = i * split_slice_size
58
+ end_idx = (i + 1) * split_slice_size
59
+ if do_split_2:
60
+ for i2 in range(
61
+ input_tokens // split_2_slice_size
62
+ ): # pylint: disable=invalid-name
63
+ start_idx_2 = i2 * split_2_slice_size
64
+ end_idx_2 = (i2 + 1) * split_2_slice_size
65
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
66
+ original_torch_bmm(
67
+ input[start_idx:end_idx, start_idx_2:end_idx_2],
68
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2],
69
+ out=out,
70
+ )
71
+ )
72
+ else:
73
+ hidden_states[start_idx:end_idx] = original_torch_bmm(
74
+ input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
75
+ )
76
+ else:
77
+ return original_torch_bmm(input, mat2, out=out)
78
+ return hidden_states
79
+
80
+
81
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
82
+
83
+
84
+ def scaled_dot_product_attention(
85
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
86
+ ):
87
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
88
+ if len(query.shape) == 3:
89
+ batch_size_attention, query_tokens, shape_four = query.shape
90
+ shape_one = 1
91
+ no_shape_one = True
92
+ else:
93
+ shape_one, batch_size_attention, query_tokens, shape_four = query.shape
94
+ no_shape_one = False
95
+
96
+ block_multiply = query.element_size()
97
+ slice_block_size = (
98
+ shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
99
+ )
100
+ block_size = batch_size_attention * slice_block_size
101
+
102
+ split_slice_size = batch_size_attention
103
+ if block_size > 4:
104
+ do_split = True
105
+ # Find something divisible with the shape_one
106
+ while (split_slice_size * slice_block_size) > 4:
107
+ split_slice_size = split_slice_size // 2
108
+ if split_slice_size <= 1:
109
+ split_slice_size = 1
110
+ break
111
+ else:
112
+ do_split = False
113
+
114
+ split_2_slice_size = query_tokens
115
+ if split_slice_size * slice_block_size > 4:
116
+ slice_block_size2 = (
117
+ shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
118
+ )
119
+ do_split_2 = True
120
+ # Find something divisible with the batch_size_attention
121
+ while (split_2_slice_size * slice_block_size2) > 4:
122
+ split_2_slice_size = split_2_slice_size // 2
123
+ if split_2_slice_size <= 1:
124
+ split_2_slice_size = 1
125
+ break
126
+ else:
127
+ do_split_2 = False
128
+
129
+ if do_split:
130
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
131
+ for i in range(batch_size_attention // split_slice_size):
132
+ start_idx = i * split_slice_size
133
+ end_idx = (i + 1) * split_slice_size
134
+ if do_split_2:
135
+ for i2 in range(
136
+ query_tokens // split_2_slice_size
137
+ ): # pylint: disable=invalid-name
138
+ start_idx_2 = i2 * split_2_slice_size
139
+ end_idx_2 = (i2 + 1) * split_2_slice_size
140
+ if no_shape_one:
141
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = (
142
+ original_scaled_dot_product_attention(
143
+ query[start_idx:end_idx, start_idx_2:end_idx_2],
144
+ key[start_idx:end_idx, start_idx_2:end_idx_2],
145
+ value[start_idx:end_idx, start_idx_2:end_idx_2],
146
+ attn_mask=(
147
+ attn_mask[start_idx:end_idx, start_idx_2:end_idx_2]
148
+ if attn_mask is not None
149
+ else attn_mask
150
+ ),
151
+ dropout_p=dropout_p,
152
+ is_causal=is_causal,
153
+ )
154
+ )
155
+ else:
156
+ hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = (
157
+ original_scaled_dot_product_attention(
158
+ query[:, start_idx:end_idx, start_idx_2:end_idx_2],
159
+ key[:, start_idx:end_idx, start_idx_2:end_idx_2],
160
+ value[:, start_idx:end_idx, start_idx_2:end_idx_2],
161
+ attn_mask=(
162
+ attn_mask[
163
+ :, start_idx:end_idx, start_idx_2:end_idx_2
164
+ ]
165
+ if attn_mask is not None
166
+ else attn_mask
167
+ ),
168
+ dropout_p=dropout_p,
169
+ is_causal=is_causal,
170
+ )
171
+ )
172
+ else:
173
+ if no_shape_one:
174
+ hidden_states[start_idx:end_idx] = (
175
+ original_scaled_dot_product_attention(
176
+ query[start_idx:end_idx],
177
+ key[start_idx:end_idx],
178
+ value[start_idx:end_idx],
179
+ attn_mask=(
180
+ attn_mask[start_idx:end_idx]
181
+ if attn_mask is not None
182
+ else attn_mask
183
+ ),
184
+ dropout_p=dropout_p,
185
+ is_causal=is_causal,
186
+ )
187
+ )
188
+ else:
189
+ hidden_states[:, start_idx:end_idx] = (
190
+ original_scaled_dot_product_attention(
191
+ query[:, start_idx:end_idx],
192
+ key[:, start_idx:end_idx],
193
+ value[:, start_idx:end_idx],
194
+ attn_mask=(
195
+ attn_mask[:, start_idx:end_idx]
196
+ if attn_mask is not None
197
+ else attn_mask
198
+ ),
199
+ dropout_p=dropout_p,
200
+ is_causal=is_causal,
201
+ )
202
+ )
203
+ else:
204
+ return original_scaled_dot_product_attention(
205
+ query,
206
+ key,
207
+ value,
208
+ attn_mask=attn_mask,
209
+ dropout_p=dropout_p,
210
+ is_causal=is_causal,
211
+ )
212
+ return hidden_states
213
+
214
+
215
+ def attention_init():
216
+ # ARC GPUs can't allocate more than 4GB to a single block:
217
+ torch.bmm = torch_bmm
218
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
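Both wrappers use the same bookkeeping: estimate the size of one batch slice from element_size() and the tensor shape, then keep halving the slice count until a chunk falls under the hard-coded threshold of 4 used above. A toy rerun of that arithmetic, with made-up tensor sizes:

batch, tokens, dim = 32, 1024, 512
element_size = 2                                           # bytes per element for float16
slice_block = tokens * dim / 1024 / 1024 * element_size    # size of one batch entry, as computed above
split = batch
while split * slice_block > 4:
    split //= 2
    if split <= 1:
        split = 1
        break
print(split)   # -> 4: the bmm / attention call is issued in chunks of 4 batch entries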
rvc/ipex/gradscaler.py ADDED
@@ -0,0 +1,188 @@
1
+ from collections import defaultdict
2
+
3
+ import torch
4
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
5
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
6
+
7
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
8
+
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = (
12
+ ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
13
+ )
14
+
15
+
16
+ def _unscale_grads_(
17
+ self, optimizer, inv_scale, found_inf, allow_fp16
18
+ ): # pylint: disable=unused-argument
19
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
20
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
21
+
22
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
23
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
24
+ # However, we don't know their devices or dtypes in advance.
25
+
26
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
27
+ # Google says mypy struggles with defaultdicts type annotations.
28
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
29
+ # sync grad to master weight
30
+ if hasattr(optimizer, "sync_grad"):
31
+ optimizer.sync_grad()
32
+ with torch.no_grad():
33
+ for group in optimizer.param_groups:
34
+ for param in group["params"]:
35
+ if param.grad is None:
36
+ continue
37
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
38
+ raise ValueError("Attempting to unscale FP16 gradients.")
39
+ if param.grad.is_sparse:
40
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
41
+ # coalesce() deduplicates indices and adds all values that have the same index.
42
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
43
+ # so we should check the coalesced _values().
44
+ if param.grad.dtype is torch.float16:
45
+ param.grad = param.grad.coalesce()
46
+ to_unscale = param.grad._values()
47
+ else:
48
+ to_unscale = param.grad
49
+
50
+ # -: is there a way to split by device and dtype without appending in the inner loop?
51
+ to_unscale = to_unscale.to("cpu")
52
+ per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
53
+ to_unscale
54
+ )
55
+
56
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
57
+ for grads in per_dtype_grads.values():
58
+ core._amp_foreach_non_finite_check_and_unscale_(
59
+ grads,
60
+ per_device_found_inf.get("cpu"),
61
+ per_device_inv_scale.get("cpu"),
62
+ )
63
+
64
+ return per_device_found_inf._per_device_tensors
65
+
66
+
67
+ def unscale_(self, optimizer):
68
+ """
69
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
70
+ :meth:`unscale_` is optional, serving cases where you need to
71
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
72
+ between the backward pass(es) and :meth:`step`.
73
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
74
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
75
+ ...
76
+ scaler.scale(loss).backward()
77
+ scaler.unscale_(optimizer)
78
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
79
+ scaler.step(optimizer)
80
+ scaler.update()
81
+ Args:
82
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
83
+ .. warning::
84
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
85
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
86
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
87
+ .. warning::
88
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
89
+ """
90
+ if not self._enabled:
91
+ return
92
+
93
+ self._check_scale_growth_tracker("unscale_")
94
+
95
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
96
+
97
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
98
+ raise RuntimeError(
99
+ "unscale_() has already been called on this optimizer since the last update()."
100
+ )
101
+ elif optimizer_state["stage"] is OptState.STEPPED:
102
+ raise RuntimeError("unscale_() is being called after step().")
103
+
104
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
105
+ assert self._scale is not None
106
+ inv_scale = (
107
+ self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
108
+ )
109
+ found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
110
+
111
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
112
+ optimizer, inv_scale, found_inf, False
113
+ )
114
+ optimizer_state["stage"] = OptState.UNSCALED
115
+
116
+
117
+ def update(self, new_scale=None):
118
+ """
119
+ Updates the scale factor.
120
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
121
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
122
+ the scale is multiplied by ``growth_factor`` to increase it.
123
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
124
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
125
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
126
+ affect the scale GradScaler uses internally.)
127
+ Args:
128
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
129
+ .. warning::
130
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
131
+ been invoked for all optimizers used this iteration.
132
+ """
133
+ if not self._enabled:
134
+ return
135
+
136
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
137
+
138
+ if new_scale is not None:
139
+ # Accept a new user-defined scale.
140
+ if isinstance(new_scale, float):
141
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
142
+ else:
143
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
144
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
145
+ assert new_scale.numel() == 1, reason
146
+ assert new_scale.requires_grad is False, reason
147
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
148
+ else:
149
+ # Consume shared inf/nan data collected from optimizers to update the scale.
150
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
151
+ found_infs = [
152
+ found_inf.to(device="cpu", non_blocking=True)
153
+ for state in self._per_optimizer_states.values()
154
+ for found_inf in state["found_inf_per_device"].values()
155
+ ]
156
+
157
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
158
+
159
+ found_inf_combined = found_infs[0]
160
+ if len(found_infs) > 1:
161
+ for i in range(1, len(found_infs)):
162
+ found_inf_combined += found_infs[i]
163
+
164
+ to_device = _scale.device
165
+ _scale = _scale.to("cpu")
166
+ _growth_tracker = _growth_tracker.to("cpu")
167
+
168
+ core._amp_update_scale_(
169
+ _scale,
170
+ _growth_tracker,
171
+ found_inf_combined,
172
+ self._growth_factor,
173
+ self._backoff_factor,
174
+ self._growth_interval,
175
+ )
176
+
177
+ _scale = _scale.to(to_device)
178
+ _growth_tracker = _growth_tracker.to(to_device)
179
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
180
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
181
+
182
+
183
+ def gradscaler_init():
184
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
185
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
186
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
187
+ torch.xpu.amp.GradScaler.update = update
188
+ return torch.xpu.amp.GradScaler
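
The patched _unscale_grads_/unscale_/update trio above keeps the familiar CUDA-style AMP loop working on XPU by routing the non-finite checks and scale updates through IPEX's CPU kernels. Below is a minimal sketch of how the patched scaler would be obtained and used; it assumes an Intel XPU build of PyTorch with intel_extension_for_pytorch installed, and the tiny model and data are illustrative placeholders, not part of RVC.

import torch
from rvc.ipex.gradscaler import gradscaler_init  # the module shown above

GradScaler = gradscaler_init()        # installs the patched methods on torch.xpu.amp.GradScaler
scaler = GradScaler()

model = torch.nn.Linear(16, 1).to("xpu")          # toy model, illustrative only
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(8, 16, device="xpu")
    optimizer.zero_grad()
    with torch.autocast("xpu"):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)        # optional, e.g. before gradient clipping
    scaler.step(optimizer)
    scaler.update()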
rvc/ipex/hijacks.py ADDED
@@ -0,0 +1,366 @@
1
+ import contextlib
2
+ import importlib
3
+
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+
7
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
8
+
9
+
10
+ class CondFunc: # pylint: disable=missing-class-docstring
11
+ def __new__(cls, orig_func, sub_func, cond_func):
12
+ self = super(CondFunc, cls).__new__(cls)
13
+ if isinstance(orig_func, str):
14
+ func_path = orig_func.split(".")
15
+ for i in range(len(func_path) - 1, -1, -1):
16
+ try:
17
+ resolved_obj = importlib.import_module(".".join(func_path[:i]))
18
+ break
19
+ except ImportError:
20
+ pass
21
+ for attr_name in func_path[i:-1]:
22
+ resolved_obj = getattr(resolved_obj, attr_name)
23
+ orig_func = getattr(resolved_obj, func_path[-1])
24
+ setattr(
25
+ resolved_obj,
26
+ func_path[-1],
27
+ lambda *args, **kwargs: self(*args, **kwargs),
28
+ )
29
+ self.__init__(orig_func, sub_func, cond_func)
30
+ return lambda *args, **kwargs: self(*args, **kwargs)
31
+
32
+ def __init__(self, orig_func, sub_func, cond_func):
33
+ self.__orig_func = orig_func
34
+ self.__sub_func = sub_func
35
+ self.__cond_func = cond_func
36
+
37
+ def __call__(self, *args, **kwargs):
38
+ if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
39
+ return self.__sub_func(self.__orig_func, *args, **kwargs)
40
+ else:
41
+ return self.__orig_func(*args, **kwargs)
42
+
43
+
44
+ _utils = torch.utils.data._utils
45
+
46
+
47
+ def _shutdown_workers(self):
48
+ if (
49
+ torch.utils.data._utils is None
50
+ or torch.utils.data._utils.python_exit_status is True
51
+ or torch.utils.data._utils.python_exit_status is None
52
+ ):
53
+ return
54
+ if hasattr(self, "_shutdown") and not self._shutdown:
55
+ self._shutdown = True
56
+ try:
57
+ if hasattr(self, "_pin_memory_thread"):
58
+ self._pin_memory_thread_done_event.set()
59
+ self._worker_result_queue.put((None, None))
60
+ self._pin_memory_thread.join()
61
+ self._worker_result_queue.cancel_join_thread()
62
+ self._worker_result_queue.close()
63
+ self._workers_done_event.set()
64
+ for worker_id in range(len(self._workers)):
65
+ if self._persistent_workers or self._workers_status[worker_id]:
66
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
67
+ for w in self._workers: # pylint: disable=invalid-name
68
+ w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
69
+ for q in self._index_queues: # pylint: disable=invalid-name
70
+ q.cancel_join_thread()
71
+ q.close()
72
+ finally:
73
+ if self._worker_pids_set:
74
+ torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
75
+ self._worker_pids_set = False
76
+ for w in self._workers: # pylint: disable=invalid-name
77
+ if w.is_alive():
78
+ w.terminate()
79
+
80
+
81
+ class DummyDataParallel(
82
+ torch.nn.Module
83
+ ): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
84
+ def __new__(
85
+ cls, module, device_ids=None, output_device=None, dim=0
86
+ ): # pylint: disable=unused-argument
87
+ if isinstance(device_ids, list) and len(device_ids) > 1:
88
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
89
+ return module.to("xpu")
90
+
91
+
92
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
93
+ return contextlib.nullcontext()
94
+
95
+
96
+ def check_device(device):
97
+ return bool(
98
+ (isinstance(device, torch.device) and device.type == "cuda")
99
+ or (isinstance(device, str) and "cuda" in device)
100
+ or isinstance(device, int)
101
+ )
102
+
103
+
104
+ def return_xpu(device):
105
+ return (
106
+ f"xpu:{device[-1]}"
107
+ if isinstance(device, str) and ":" in device
108
+ else (
109
+ f"xpu:{device}"
110
+ if isinstance(device, int)
111
+ else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
112
+ )
113
+ )
114
+
115
+
116
+ def ipex_no_cuda(orig_func, *args, **kwargs):
117
+ torch.cuda.is_available = lambda: False
118
+ orig_func(*args, **kwargs)
119
+ torch.cuda.is_available = torch.xpu.is_available
120
+
121
+
122
+ original_autocast = torch.autocast
123
+
124
+
125
+ def ipex_autocast(*args, **kwargs):
126
+ if len(args) > 0 and args[0] == "cuda":
127
+ return original_autocast("xpu", *args[1:], **kwargs)
128
+ else:
129
+ return original_autocast(*args, **kwargs)
130
+
131
+
132
+ original_torch_cat = torch.cat
133
+
134
+
135
+ def torch_cat(tensor, *args, **kwargs):
136
+ if len(tensor) == 3 and (
137
+ tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype
138
+ ):
139
+ return original_torch_cat(
140
+ [tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)],
141
+ *args,
142
+ **kwargs,
143
+ )
144
+ else:
145
+ return original_torch_cat(tensor, *args, **kwargs)
146
+
147
+
148
+ original_interpolate = torch.nn.functional.interpolate
149
+
150
+
151
+ def interpolate(
152
+ tensor,
153
+ size=None,
154
+ scale_factor=None,
155
+ mode="nearest",
156
+ align_corners=None,
157
+ recompute_scale_factor=None,
158
+ antialias=False,
159
+ ): # pylint: disable=too-many-arguments
160
+ if antialias or align_corners is not None:
161
+ return_device = tensor.device
162
+ return_dtype = tensor.dtype
163
+ return original_interpolate(
164
+ tensor.to("cpu", dtype=torch.float32),
165
+ size=size,
166
+ scale_factor=scale_factor,
167
+ mode=mode,
168
+ align_corners=align_corners,
169
+ recompute_scale_factor=recompute_scale_factor,
170
+ antialias=antialias,
171
+ ).to(return_device, dtype=return_dtype)
172
+ else:
173
+ return original_interpolate(
174
+ tensor,
175
+ size=size,
176
+ scale_factor=scale_factor,
177
+ mode=mode,
178
+ align_corners=align_corners,
179
+ recompute_scale_factor=recompute_scale_factor,
180
+ antialias=antialias,
181
+ )
182
+
183
+
184
+ original_linalg_solve = torch.linalg.solve
185
+
186
+
187
+ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
188
+ if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
189
+ return_device = A.device
190
+ return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(
191
+ return_device
192
+ )
193
+ else:
194
+ return original_linalg_solve(A, B, *args, **kwargs)
195
+
196
+
197
+ def ipex_hijacks():
198
+ CondFunc(
199
+ "torch.Tensor.to",
200
+ lambda orig_func, self, device=None, *args, **kwargs: orig_func(
201
+ self, return_xpu(device), *args, **kwargs
202
+ ),
203
+ lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
204
+ )
205
+ CondFunc(
206
+ "torch.Tensor.cuda",
207
+ lambda orig_func, self, device=None, *args, **kwargs: orig_func(
208
+ self, return_xpu(device), *args, **kwargs
209
+ ),
210
+ lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
211
+ )
212
+ CondFunc(
213
+ "torch.empty",
214
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
215
+ *args, device=return_xpu(device), **kwargs
216
+ ),
217
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
218
+ )
219
+ CondFunc(
220
+ "torch.load",
221
+ lambda orig_func, *args, map_location=None, **kwargs: orig_func(
222
+ *args, return_xpu(map_location), **kwargs
223
+ ),
224
+ lambda orig_func, *args, map_location=None, **kwargs: map_location is None
225
+ or check_device(map_location),
226
+ )
227
+ CondFunc(
228
+ "torch.randn",
229
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
230
+ *args, device=return_xpu(device), **kwargs
231
+ ),
232
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
233
+ )
234
+ CondFunc(
235
+ "torch.ones",
236
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
237
+ *args, device=return_xpu(device), **kwargs
238
+ ),
239
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
240
+ )
241
+ CondFunc(
242
+ "torch.zeros",
243
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
244
+ *args, device=return_xpu(device), **kwargs
245
+ ),
246
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
247
+ )
248
+ CondFunc(
249
+ "torch.tensor",
250
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
251
+ *args, device=return_xpu(device), **kwargs
252
+ ),
253
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
254
+ )
255
+ CondFunc(
256
+ "torch.linspace",
257
+ lambda orig_func, *args, device=None, **kwargs: orig_func(
258
+ *args, device=return_xpu(device), **kwargs
259
+ ),
260
+ lambda orig_func, *args, device=None, **kwargs: check_device(device),
261
+ )
262
+
263
+ CondFunc(
264
+ "torch.Generator",
265
+ lambda orig_func, device=None: torch.xpu.Generator(device),
266
+ lambda orig_func, device=None: device is not None
267
+ and device != torch.device("cpu")
268
+ and device != "cpu",
269
+ )
270
+
271
+ CondFunc(
272
+ "torch.batch_norm",
273
+ lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
274
+ input,
275
+ (
276
+ weight
277
+ if weight is not None
278
+ else torch.ones(input.size()[1], device=input.device)
279
+ ),
280
+ (
281
+ bias
282
+ if bias is not None
283
+ else torch.zeros(input.size()[1], device=input.device)
284
+ ),
285
+ *args,
286
+ **kwargs,
287
+ ),
288
+ lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
289
+ )
290
+ CondFunc(
291
+ "torch.instance_norm",
292
+ lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
293
+ input,
294
+ (
295
+ weight
296
+ if weight is not None
297
+ else torch.ones(input.size()[1], device=input.device)
298
+ ),
299
+ (
300
+ bias
301
+ if bias is not None
302
+ else torch.zeros(input.size()[1], device=input.device)
303
+ ),
304
+ *args,
305
+ **kwargs,
306
+ ),
307
+ lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
308
+ )
309
+
310
+ # Functions with dtype errors:
311
+ CondFunc(
312
+ "torch.nn.modules.GroupNorm.forward",
313
+ lambda orig_func, self, input: orig_func(
314
+ self, input.to(self.weight.data.dtype)
315
+ ),
316
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
317
+ )
318
+ CondFunc(
319
+ "torch.nn.modules.linear.Linear.forward",
320
+ lambda orig_func, self, input: orig_func(
321
+ self, input.to(self.weight.data.dtype)
322
+ ),
323
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
324
+ )
325
+ CondFunc(
326
+ "torch.nn.modules.conv.Conv2d.forward",
327
+ lambda orig_func, self, input: orig_func(
328
+ self, input.to(self.weight.data.dtype)
329
+ ),
330
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
331
+ )
332
+ CondFunc(
333
+ "torch.nn.functional.layer_norm",
334
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(
335
+ input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs
336
+ ),
337
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight
338
+ is not None
339
+ and input.dtype != weight.data.dtype,
340
+ )
341
+
342
+ # Diffusers Float64 (ARC GPUs don't support double / Float64):
343
+ if not torch.xpu.has_fp64_dtype():
344
+ CondFunc(
345
+ "torch.from_numpy",
346
+ lambda orig_func, ndarray: orig_func(ndarray.astype("float32")),
347
+ lambda orig_func, ndarray: ndarray.dtype == float,
348
+ )
349
+
350
+ # Broken functions when torch.cuda.is_available is True:
351
+ CondFunc(
352
+ "torch.utils.data.dataloader._BaseDataLoaderIter.__init__",
353
+ lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
354
+ lambda orig_func, *args, **kwargs: True,
355
+ )
356
+
357
+ # Functions that make compile mad with CondFunc:
358
+ torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = (
359
+ _shutdown_workers
360
+ )
361
+ torch.nn.DataParallel = DummyDataParallel
362
+ torch.autocast = ipex_autocast
363
+ torch.cat = torch_cat
364
+ torch.linalg.solve = linalg_solve
365
+ torch.nn.functional.interpolate = interpolate
366
+ torch.backends.cuda.sdp_kernel = return_null_context
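
CondFunc is the workhorse of this module: given a dotted path, it swaps the target callable for a wrapper that reroutes the call only when a condition holds (here, whenever a CUDA device is requested). The snippet below re-implements the same idea on a plain local function so the pattern is visible without an IPEX install; make_cond_func and to_device are illustrative names, not part of the module.

# Simplified, self-contained illustration of the CondFunc pattern
# (the real class additionally resolves dotted paths such as "torch.Tensor.to").
def make_cond_func(orig_func, sub_func, cond_func):
    def wrapper(*args, **kwargs):
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)  # substituted behaviour
        return orig_func(*args, **kwargs)                # untouched original path
    return wrapper

def to_device(device):
    return f"tensor moved to {device}"

# Reroute "cuda" requests to "xpu", leave everything else untouched.
to_device = make_cond_func(
    to_device,
    lambda orig, device: orig("xpu"),
    lambda orig, device: isinstance(device, str) and "cuda" in device,
)

print(to_device("cpu"))   # tensor moved to cpu
print(to_device("cuda"))  # tensor moved to xpu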
rvc/ipex/init.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import sys
3
+ import contextlib
4
+
5
+ import torch
6
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
7
+
8
+ from .hijacks import ipex_hijacks
9
+ from .attention import attention_init
10
+
11
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
12
+
13
+
14
+ def ipex_init(): # pylint: disable=too-many-statements
15
+ try:
16
+ # Replace cuda with xpu:
17
+ torch.cuda.current_device = torch.xpu.current_device
18
+ torch.cuda.current_stream = torch.xpu.current_stream
19
+ torch.cuda.device = torch.xpu.device
20
+ torch.cuda.device_count = torch.xpu.device_count
21
+ torch.cuda.device_of = torch.xpu.device_of
22
+ torch.cuda.get_device_name = torch.xpu.get_device_name
23
+ torch.cuda.get_device_properties = torch.xpu.get_device_properties
24
+ torch.cuda.init = torch.xpu.init
25
+ torch.cuda.is_available = torch.xpu.is_available
26
+ torch.cuda.is_initialized = torch.xpu.is_initialized
27
+ torch.cuda.is_current_stream_capturing = lambda: False
28
+ torch.cuda.set_device = torch.xpu.set_device
29
+ torch.cuda.stream = torch.xpu.stream
30
+ torch.cuda.synchronize = torch.xpu.synchronize
31
+ torch.cuda.Event = torch.xpu.Event
32
+ torch.cuda.Stream = torch.xpu.Stream
33
+ torch.cuda.FloatTensor = torch.xpu.FloatTensor
34
+ torch.Tensor.cuda = torch.Tensor.xpu
35
+ torch.Tensor.is_cuda = torch.Tensor.is_xpu
36
+ torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
37
+ torch.cuda._initialized = torch.xpu.lazy_init._initialized
38
+ torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
39
+ torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
40
+ torch.cuda._tls = torch.xpu.lazy_init._tls
41
+ torch.cuda.threading = torch.xpu.lazy_init.threading
42
+ torch.cuda.traceback = torch.xpu.lazy_init.traceback
43
+ torch.cuda.Optional = torch.xpu.Optional
44
+ torch.cuda.__cached__ = torch.xpu.__cached__
45
+ torch.cuda.__loader__ = torch.xpu.__loader__
46
+ torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
47
+ torch.cuda.Tuple = torch.xpu.Tuple
48
+ torch.cuda.streams = torch.xpu.streams
49
+ torch.cuda._lazy_new = torch.xpu._lazy_new
50
+ torch.cuda.FloatStorage = torch.xpu.FloatStorage
51
+ torch.cuda.Any = torch.xpu.Any
52
+ torch.cuda.__doc__ = torch.xpu.__doc__
53
+ torch.cuda.default_generators = torch.xpu.default_generators
54
+ torch.cuda.HalfTensor = torch.xpu.HalfTensor
55
+ torch.cuda._get_device_index = torch.xpu._get_device_index
56
+ torch.cuda.__path__ = torch.xpu.__path__
57
+ torch.cuda.Device = torch.xpu.Device
58
+ torch.cuda.IntTensor = torch.xpu.IntTensor
59
+ torch.cuda.ByteStorage = torch.xpu.ByteStorage
60
+ torch.cuda.set_stream = torch.xpu.set_stream
61
+ torch.cuda.BoolStorage = torch.xpu.BoolStorage
62
+ torch.cuda.os = torch.xpu.os
63
+ torch.cuda.torch = torch.xpu.torch
64
+ torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
65
+ torch.cuda.Union = torch.xpu.Union
66
+ torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
67
+ torch.cuda.ShortTensor = torch.xpu.ShortTensor
68
+ torch.cuda.LongTensor = torch.xpu.LongTensor
69
+ torch.cuda.IntStorage = torch.xpu.IntStorage
70
+ torch.cuda.LongStorage = torch.xpu.LongStorage
71
+ torch.cuda.__annotations__ = torch.xpu.__annotations__
72
+ torch.cuda.__package__ = torch.xpu.__package__
73
+ torch.cuda.__builtins__ = torch.xpu.__builtins__
74
+ torch.cuda.CharTensor = torch.xpu.CharTensor
75
+ torch.cuda.List = torch.xpu.List
76
+ torch.cuda._lazy_init = torch.xpu._lazy_init
77
+ torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
78
+ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
79
+ torch.cuda.ByteTensor = torch.xpu.ByteTensor
80
+ torch.cuda.StreamContext = torch.xpu.StreamContext
81
+ torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
82
+ torch.cuda.ShortStorage = torch.xpu.ShortStorage
83
+ torch.cuda._lazy_call = torch.xpu._lazy_call
84
+ torch.cuda.HalfStorage = torch.xpu.HalfStorage
85
+ torch.cuda.random = torch.xpu.random
86
+ torch.cuda._device = torch.xpu._device
87
+ torch.cuda.classproperty = torch.xpu.classproperty
88
+ torch.cuda.__name__ = torch.xpu.__name__
89
+ torch.cuda._device_t = torch.xpu._device_t
90
+ torch.cuda.warnings = torch.xpu.warnings
91
+ torch.cuda.__spec__ = torch.xpu.__spec__
92
+ torch.cuda.BoolTensor = torch.xpu.BoolTensor
93
+ torch.cuda.CharStorage = torch.xpu.CharStorage
94
+ torch.cuda.__file__ = torch.xpu.__file__
95
+ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
96
+ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
97
+
98
+ # Memory:
99
+ torch.cuda.memory = torch.xpu.memory
100
+ if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
101
+ torch.xpu.empty_cache = lambda: None
102
+ torch.cuda.empty_cache = torch.xpu.empty_cache
103
+ torch.cuda.memory_stats = torch.xpu.memory_stats
104
+ torch.cuda.memory_summary = torch.xpu.memory_summary
105
+ torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
106
+ torch.cuda.memory_allocated = torch.xpu.memory_allocated
107
+ torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
108
+ torch.cuda.memory_reserved = torch.xpu.memory_reserved
109
+ torch.cuda.memory_cached = torch.xpu.memory_reserved
110
+ torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
111
+ torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
112
+ torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
113
+ torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
114
+ torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
115
+ torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
116
+ torch.cuda.reset_accumulated_memory_stats = (
117
+ torch.xpu.reset_accumulated_memory_stats
118
+ )
119
+
120
+ # RNG:
121
+ torch.cuda.get_rng_state = torch.xpu.get_rng_state
122
+ torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
123
+ torch.cuda.set_rng_state = torch.xpu.set_rng_state
124
+ torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
125
+ torch.cuda.manual_seed = torch.xpu.manual_seed
126
+ torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
127
+ torch.cuda.seed = torch.xpu.seed
128
+ torch.cuda.seed_all = torch.xpu.seed_all
129
+ torch.cuda.initial_seed = torch.xpu.initial_seed
130
+
131
+ # AMP:
132
+ torch.cuda.amp = torch.xpu.amp
133
+ if not hasattr(torch.cuda.amp, "common"):
134
+ torch.cuda.amp.common = contextlib.nullcontext()
135
+ torch.cuda.amp.common.amp_definitely_not_available = lambda: False
136
+ try:
137
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
138
+ except Exception: # pylint: disable=broad-exception-caught
139
+ try:
140
+ from .gradscaler import (
141
+ gradscaler_init,
142
+ ) # pylint: disable=import-outside-toplevel, import-error
143
+
144
+ gradscaler_init()
145
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
146
+ except Exception: # pylint: disable=broad-exception-caught
147
+ torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
148
+
149
+ # C
150
+ torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
151
+ ipex._C._DeviceProperties.major = 2023
152
+ ipex._C._DeviceProperties.minor = 2
153
+
154
+ # Fix functions with ipex:
155
+ torch.cuda.mem_get_info = lambda device=None: [
156
+ (
157
+ torch.xpu.get_device_properties(device).total_memory
158
+ - torch.xpu.memory_allocated(device)
159
+ ),
160
+ torch.xpu.get_device_properties(device).total_memory,
161
+ ]
162
+ torch._utils._get_available_device_type = lambda: "xpu"
163
+ torch.has_cuda = True
164
+ torch.cuda.has_half = True
165
+ torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
166
+ torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
167
+ torch.version.cuda = "11.7"
168
+ torch.cuda.get_device_capability = lambda *args, **kwargs: [11, 7]
169
+ torch.cuda.get_device_properties.major = 11
170
+ torch.cuda.get_device_properties.minor = 7
171
+ torch.cuda.ipc_collect = lambda *args, **kwargs: None
172
+ torch.cuda.utilization = lambda *args, **kwargs: 0
173
+ if hasattr(torch.xpu, "getDeviceIdListForCard"):
174
+ torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
175
+ torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
176
+ else:
177
+ torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
178
+ torch.cuda.get_device_id_list_per_card = (
179
+ torch.xpu.get_device_id_list_per_card
180
+ )
181
+
182
+ ipex_hijacks()
183
+ attention_init()
184
+ try:
185
+ from .diffusers import ipex_diffusers
186
+
187
+ ipex_diffusers()
188
+ except Exception: # pylint: disable=broad-exception-caught
189
+ pass
190
+ except Exception as e:
191
+ return False, e
192
+ return True, None
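
ipex_init() returns an (ok, error) pair rather than raising, so callers can fall back gracefully when the CUDA-to-XPU aliasing is unavailable. A minimal sketch of the intended call order follows; it assumes an Intel XPU environment with intel_extension_for_pytorch installed, and the module path simply mirrors the file layout above.

import torch
from rvc.ipex.init import ipex_init

ok, err = ipex_init()                  # must run before any torch.cuda.* usage
if ok:
    device = torch.device("cuda")      # transparently resolved to an XPU device by the hijacks
    x = torch.randn(4, 4, device=device)
    print(x.device)
else:
    print(f"IPEX initialisation failed, staying on CPU: {err}")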
rvc/jit/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .jit import load_inputs, get_jit_model, export_jit_model, save_pickle
rvc/jit/jit.py ADDED
@@ -0,0 +1,78 @@
1
+ import pickle
2
+ from io import BytesIO
3
+ from collections import OrderedDict
4
+ import os
5
+
6
+ import torch
7
+
8
+
9
+ def load_pickle(path: str):
10
+ with open(path, "rb") as f:
11
+ return pickle.load(f)
12
+
13
+
14
+ def save_pickle(ckpt: dict, save_path: str):
15
+ with open(save_path, "wb") as f:
16
+ pickle.dump(ckpt, f)
17
+
18
+
19
+ def load_inputs(path: torch.serialization.FILE_LIKE, device: str, is_half=False):
20
+ parm = torch.load(path, map_location=torch.device("cpu"))
21
+ for key in parm.keys():
22
+ parm[key] = parm[key].to(device)
23
+ if is_half and parm[key].dtype == torch.float32:
24
+ parm[key] = parm[key].half()
25
+ elif not is_half and parm[key].dtype == torch.float16:
26
+ parm[key] = parm[key].float()
27
+ return parm
28
+
29
+
30
+ def export_jit_model(
31
+ model: torch.nn.Module,
32
+ mode: str = "trace",
33
+ inputs: dict = None,
34
+ device=torch.device("cpu"),
35
+ is_half: bool = False,
36
+ ) -> dict:
37
+ model = model.half() if is_half else model.float()
38
+ model.eval()
39
+ if mode == "trace":
40
+ assert inputs is not None
41
+ model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
42
+ elif mode == "script":
43
+ model_jit = torch.jit.script(model)
44
+ model_jit.to(device)
45
+ model_jit = model_jit.half() if is_half else model_jit.float()
46
+ buffer = BytesIO()
47
+ # model_jit=model_jit.cpu()
48
+ torch.jit.save(model_jit, buffer)
49
+ del model_jit
50
+ cpt = OrderedDict()
51
+ cpt["model"] = buffer.getvalue()
52
+ cpt["is_half"] = is_half
53
+ return cpt
54
+
55
+
56
+ def get_jit_model(model_path: str, is_half: bool, device: str, exporter):
57
+ jit_model_path = model_path[:-4] if model_path.endswith(".pth") else model_path # strip the ".pth" suffix
58
+ jit_model_path += ".half.jit" if is_half else ".jit"
59
+ ckpt = None
60
+
61
+ if os.path.exists(jit_model_path):
62
+ ckpt = load_pickle(jit_model_path)
63
+ model_device = ckpt["device"]
64
+ if model_device != str(device):
65
+ del ckpt
66
+ ckpt = None
67
+
68
+ if ckpt is None:
69
+ ckpt = exporter(
70
+ model_path=model_path,
71
+ mode="script",
72
+ inputs_path=None,
73
+ save_path=jit_model_path,
74
+ device=device,
75
+ is_half=is_half,
76
+ )
77
+
78
+ return ckpt
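
export_jit_model() scripts (or traces) a module and returns a plain dict holding the serialized TorchScript bytes, which save_pickle() writes to disk; get_jit_model() later reuses that file when its stored device matches, otherwise it calls the supplied exporter. A minimal sketch of the export side with a throwaway module (Toy is illustrative, not part of RVC; imports assume this repository's environment):

import torch
from rvc.jit import export_jit_model, save_pickle

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

ckpt = export_jit_model(Toy(), mode="script", device=torch.device("cpu"), is_half=False)
ckpt["device"] = "cpu"         # get_jit_model() compares this key against the requested device
save_pickle(ckpt, "toy.jit")   # the "model" entry holds the torch.jit.save() bytes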
rvc/layers/__init__.py ADDED
File without changes
rvc/layers/attentions.py ADDED
@@ -0,0 +1,292 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class MultiHeadAttention(nn.Module):
10
+ def __init__(
11
+ self,
12
+ channels: int,
13
+ out_channels: int,
14
+ n_heads: int,
15
+ p_dropout: float = 0.0,
16
+ window_size: Optional[int] = None,
17
+ heads_share: bool = True,
18
+ block_length: Optional[int] = None,
19
+ proximal_bias: bool = False,
20
+ proximal_init: bool = False,
21
+ ):
22
+ super(MultiHeadAttention, self).__init__()
23
+ assert channels % n_heads == 0
24
+
25
+ self.channels = channels
26
+ self.out_channels = out_channels
27
+ self.n_heads = n_heads
28
+ self.p_dropout = p_dropout
29
+ self.window_size = window_size
30
+ self.heads_share = heads_share
31
+ self.block_length = block_length
32
+ self.proximal_bias = proximal_bias
33
+ self.proximal_init = proximal_init
34
+ self.attn = None
35
+
36
+ self.k_channels = channels // n_heads
37
+ self.conv_q = nn.Conv1d(channels, channels, 1)
38
+ self.conv_k = nn.Conv1d(channels, channels, 1)
39
+ self.conv_v = nn.Conv1d(channels, channels, 1)
40
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
41
+ self.drop = nn.Dropout(p_dropout)
42
+
43
+ if window_size is not None:
44
+ n_heads_rel = 1 if heads_share else n_heads
45
+ rel_stddev = self.k_channels**-0.5
46
+ self.emb_rel_k = nn.Parameter(
47
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
48
+ * rel_stddev
49
+ )
50
+ self.emb_rel_v = nn.Parameter(
51
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
52
+ * rel_stddev
53
+ )
54
+
55
+ nn.init.xavier_uniform_(self.conv_q.weight)
56
+ nn.init.xavier_uniform_(self.conv_k.weight)
57
+ nn.init.xavier_uniform_(self.conv_v.weight)
58
+ if proximal_init:
59
+ with torch.no_grad():
60
+ self.conv_k.weight.copy_(self.conv_q.weight)
61
+ self.conv_k.bias.copy_(self.conv_q.bias)
62
+
63
+ def __call__(
64
+ self,
65
+ x: torch.Tensor,
66
+ c: torch.Tensor,
67
+ attn_mask: Optional[torch.Tensor] = None,
68
+ ) -> torch.Tensor:
69
+ return super().__call__(x, c, attn_mask=attn_mask)
70
+
71
+ def forward(
72
+ self,
73
+ x: torch.Tensor,
74
+ c: torch.Tensor,
75
+ attn_mask: Optional[torch.Tensor] = None,
76
+ ) -> torch.Tensor:
77
+ q = self.conv_q(x)
78
+ k = self.conv_k(c)
79
+ v = self.conv_v(c)
80
+
81
+ x, _ = self._attention(q, k, v, mask=attn_mask)
82
+
83
+ x = self.conv_o(x)
84
+ return x
85
+
86
+ def _attention(
87
+ self,
88
+ query: torch.Tensor,
89
+ key: torch.Tensor,
90
+ value: torch.Tensor,
91
+ mask: Optional[torch.Tensor] = None,
92
+ ):
93
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
94
+ b, d, t_s = key.size()
95
+ t_t = query.size(2)
96
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
97
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
98
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
99
+
100
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
101
+ if self.window_size is not None:
102
+ assert (
103
+ t_s == t_t
104
+ ), "Relative attention is only available for self-attention."
105
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
106
+ rel_logits = self._matmul_with_relative_keys(
107
+ query / math.sqrt(self.k_channels), key_relative_embeddings
108
+ )
109
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
110
+ scores = scores + scores_local
111
+ if self.proximal_bias:
112
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
113
+ scores = scores + self._attention_bias_proximal(t_s).to(
114
+ device=scores.device, dtype=scores.dtype
115
+ )
116
+ if mask is not None:
117
+ scores = scores.masked_fill(mask == 0, -1e4)
118
+ if self.block_length is not None:
119
+ assert (
120
+ t_s == t_t
121
+ ), "Local attention is only available for self-attention."
122
+ block_mask = (
123
+ torch.ones_like(scores)
124
+ .triu(-self.block_length)
125
+ .tril(self.block_length)
126
+ )
127
+ scores = scores.masked_fill(block_mask == 0, -1e4)
128
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
129
+ p_attn = self.drop(p_attn)
130
+ output = torch.matmul(p_attn, value)
131
+ if self.window_size is not None:
132
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
133
+ value_relative_embeddings = self._get_relative_embeddings(
134
+ self.emb_rel_v, t_s
135
+ )
136
+ output = output + self._matmul_with_relative_values(
137
+ relative_weights, value_relative_embeddings
138
+ )
139
+ output = (
140
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
141
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
142
+ return output, p_attn
143
+
144
+ def _matmul_with_relative_values(self, x, y):
145
+ """
146
+ x: [b, h, l, m]
147
+ y: [h or 1, m, d]
148
+ ret: [b, h, l, d]
149
+ """
150
+ ret = torch.matmul(x, y.unsqueeze(0))
151
+ return ret
152
+
153
+ def _matmul_with_relative_keys(self, x, y):
154
+ """
155
+ x: [b, h, l, d]
156
+ y: [h or 1, m, d]
157
+ ret: [b, h, l, m]
158
+ """
159
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
160
+ return ret
161
+
162
+ def _get_relative_embeddings(self, relative_embeddings, length: int):
163
+ # max_relative_position = 2 * self.window_size + 1
164
+ # Pad first before slice to avoid using cond ops.
165
+ pad_length: int = max(length - (self.window_size + 1), 0)
166
+ slice_start_position = max((self.window_size + 1) - length, 0)
167
+ slice_end_position = slice_start_position + 2 * length - 1
168
+ if pad_length > 0:
169
+ padded_relative_embeddings = F.pad(
170
+ relative_embeddings,
171
+ [0, 0, pad_length, pad_length, 0, 0],
172
+ )
173
+ else:
174
+ padded_relative_embeddings = relative_embeddings
175
+ used_relative_embeddings = padded_relative_embeddings[
176
+ :, slice_start_position:slice_end_position
177
+ ]
178
+ return used_relative_embeddings
179
+
180
+ def _relative_position_to_absolute_position(self, x):
181
+ """
182
+ x: [b, h, l, 2*l-1]
183
+ ret: [b, h, l, l]
184
+ """
185
+ batch, heads, length, _ = x.size()
186
+ # Concat columns of pad to shift from relative to absolute indexing.
187
+ x = F.pad(
188
+ x,
189
+ [0, 1, 0, 0, 0, 0, 0, 0],
190
+ )
191
+
192
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
193
+ x_flat = x.view([batch, heads, length * 2 * length])
194
+ x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0])
195
+
196
+ # Reshape and slice out the padded elements.
197
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
198
+ :, :, :length, length - 1 :
199
+ ]
200
+ return x_final
201
+
202
+ def _absolute_position_to_relative_position(self, x):
203
+ """
204
+ x: [b, h, l, l]
205
+ ret: [b, h, l, 2*l-1]
206
+ """
207
+ batch, heads, length, _ = x.size()
208
+ # pad along column
209
+ x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
210
+ x_flat = x.view([batch, heads, (length**2) + (length * (length - 1))])
211
+ # add 0's in the beginning that will skew the elements after reshape
212
+ x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
213
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
214
+ return x_final
215
+
216
+ def _attention_bias_proximal(self, length: int):
217
+ """Bias for self-attention to encourage attention to close positions.
218
+ Args:
219
+ length: an integer scalar.
220
+ Returns:
221
+ a Tensor with shape [1, 1, length, length]
222
+ """
223
+ r = torch.arange(length, dtype=torch.float32)
224
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
225
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
226
+
227
+
228
+ class FFN(nn.Module):
229
+ """
230
+ Feed-Forward Network
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ in_channels: int,
236
+ out_channels: int,
237
+ filter_channels: int,
238
+ kernel_size: int,
239
+ p_dropout: float = 0.0,
240
+ activation: Optional[str] = None,
241
+ causal: bool = False,
242
+ ):
243
+ super(FFN, self).__init__()
244
+ self.in_channels = in_channels
245
+ self.out_channels = out_channels
246
+ self.filter_channels = filter_channels
247
+ self.kernel_size = kernel_size
248
+ self.p_dropout = p_dropout
249
+ self.activation = activation
250
+ self.causal = causal
251
+ self.is_activation = True if activation == "gelu" else False
252
+
253
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
254
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
255
+ self.drop = nn.Dropout(p_dropout)
256
+
257
+ def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
258
+ return super().__call__(x, x_mask)
259
+
260
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
261
+ x = self.conv_1(self._padding(x, x_mask))
262
+ if self.is_activation:
263
+ x = x * torch.sigmoid(1.702 * x)
264
+ else:
265
+ x = torch.relu(x)
266
+ x = self.drop(x)
267
+
268
+ x = self.conv_2(self._padding(x, x_mask))
269
+ return x * x_mask
270
+
271
+ def _padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
272
+ if self.causal:
273
+ return self._causal_padding(x * x_mask)
274
+ return self._same_padding(x * x_mask)
275
+
276
+ def _causal_padding(self, x):
277
+ if self.kernel_size == 1:
278
+ return x
279
+ pad_l: int = self.kernel_size - 1
280
+ pad_r: int = 0
281
+ # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
282
+ x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0])
283
+ return x
284
+
285
+ def _same_padding(self, x):
286
+ if self.kernel_size == 1:
287
+ return x
288
+ pad_l: int = (self.kernel_size - 1) // 2
289
+ pad_r: int = self.kernel_size // 2
290
+ # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
291
+ x = F.pad(x, [pad_l, pad_r, 0, 0, 0, 0])
292
+ return x
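
MultiHeadAttention works on conv-style [batch, channels, time] tensors, and when window_size is set the relative-position terms require query and context to have the same length, i.e. self-attention. A quick shape check under hypothetical sizes (imports assume this repository's environment):

import torch
from rvc.layers.attentions import MultiHeadAttention

attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2, window_size=10)
x = torch.randn(1, 192, 50)   # [batch, channels, time]
y = attn(x, x)                # query and context are the same tensor (self-attention)
print(y.shape)                # torch.Size([1, 192, 50])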
rvc/layers/discriminators.py ADDED
@@ -0,0 +1,172 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d, Conv2d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import spectral_norm, weight_norm
8
+
9
+ from .residuals import LRELU_SLOPE
10
+ from .utils import get_padding
11
+
12
+
13
+ class MultiPeriodDiscriminator(torch.nn.Module):
14
+ """
15
+ version: 'v1' or 'v2'
16
+ """
17
+
18
+ def __init__(
19
+ self, version: str, use_spectral_norm: bool = False, has_xpu: bool = False
20
+ ):
21
+ super(MultiPeriodDiscriminator, self).__init__()
22
+ periods = (
23
+ (2, 3, 5, 7, 11, 17) if version == "v1" else (2, 3, 5, 7, 11, 17, 23, 37)
24
+ )
25
+
26
+ self.discriminators = nn.ModuleList(
27
+ [
28
+ DiscriminatorS(use_spectral_norm=use_spectral_norm),
29
+ *(
30
+ DiscriminatorP(
31
+ i, use_spectral_norm=use_spectral_norm, has_xpu=has_xpu
32
+ )
33
+ for i in periods
34
+ ),
35
+ ]
36
+ )
37
+
38
+ def __call__(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
39
+ List[torch.Tensor],
40
+ List[torch.Tensor],
41
+ List[List[torch.Tensor]],
42
+ List[List[torch.Tensor]],
43
+ ]:
44
+ return super().__call__(y, y_hat)
45
+
46
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
47
+ List[torch.Tensor],
48
+ List[torch.Tensor],
49
+ List[List[torch.Tensor]],
50
+ List[List[torch.Tensor]],
51
+ ]:
52
+ y_d_rs = []
53
+ y_d_gs = []
54
+ fmap_rs = []
55
+ fmap_gs = []
56
+
57
+ for d in self.discriminators:
58
+ y_d_r, fmap_r = d(y)
59
+ y_d_g, fmap_g = d(y_hat)
60
+ y_d_rs.append(y_d_r)
61
+ y_d_gs.append(y_d_g)
62
+ fmap_rs.append(fmap_r)
63
+ fmap_gs.append(fmap_g)
64
+
65
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
66
+
67
+
68
+ class DiscriminatorS(torch.nn.Module):
69
+ def __init__(self, use_spectral_norm: bool = False):
70
+ super(DiscriminatorS, self).__init__()
71
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
72
+
73
+ self.convs = nn.ModuleList(
74
+ [
75
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
76
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
77
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
78
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
79
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
80
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
81
+ ]
82
+ )
83
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
84
+
85
+ def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
86
+ return super().__call__(x)
87
+
88
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
89
+ fmap = []
90
+
91
+ for l in self.convs:
92
+ x = l(x)
93
+ x = F.leaky_relu(x, LRELU_SLOPE)
94
+ fmap.append(x)
95
+
96
+ x = self.conv_post(x)
97
+ fmap.append(x)
98
+ x = torch.flatten(x, 1, -1)
99
+
100
+ return x, fmap
101
+
102
+
103
+ class DiscriminatorP(torch.nn.Module):
104
+ def __init__(
105
+ self,
106
+ period: int,
107
+ kernel_size: int = 5,
108
+ stride: int = 3,
109
+ use_spectral_norm: bool = False,
110
+ has_xpu: bool = False,
111
+ ):
112
+ super(DiscriminatorP, self).__init__()
113
+ self.period = period
114
+ self.has_xpu = has_xpu
115
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
116
+ sequence = (1, 32, 128, 512, 1024)
117
+ convs_padding = (get_padding(kernel_size, 1), 0)
118
+
119
+ self.convs = nn.ModuleList()
120
+ for i in range(len(sequence) - 1):
121
+ self.convs.append(
122
+ norm_f(
123
+ Conv2d(
124
+ sequence[i],
125
+ sequence[i + 1],
126
+ (kernel_size, 1),
127
+ (stride, 1),
128
+ padding=convs_padding,
129
+ )
130
+ )
131
+ )
132
+ self.convs.append(
133
+ norm_f(
134
+ Conv2d(
135
+ 1024,
136
+ 1024,
137
+ (kernel_size, 1),
138
+ 1,
139
+ padding=convs_padding,
140
+ )
141
+ )
142
+ )
143
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
144
+
145
+ def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
146
+ return super().__call__(x)
147
+
148
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
149
+ fmap = []
150
+
151
+ # 1d to 2d
152
+ b, c, t = x.shape
153
+ if t % self.period != 0: # pad first
154
+ n_pad = self.period - (t % self.period)
155
+ if self.has_xpu and x.dtype == torch.bfloat16:
156
+ x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
157
+ dtype=torch.bfloat16
158
+ )
159
+ else:
160
+ x = F.pad(x, (0, n_pad), "reflect")
161
+ t = t + n_pad
162
+ x = x.view(b, c, t // self.period, self.period)
163
+
164
+ for l in self.convs:
165
+ x = l(x)
166
+ x = F.leaky_relu(x, LRELU_SLOPE)
167
+ fmap.append(x)
168
+ x = self.conv_post(x)
169
+ fmap.append(x)
170
+ x = torch.flatten(x, 1, -1)
171
+
172
+ return x, fmap
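
Both discriminator types consume mono waveforms shaped [batch, 1, samples]; MultiPeriodDiscriminator fans the real and generated signals out to DiscriminatorS plus one DiscriminatorP per period and returns the per-discriminator logits and feature maps. A smoke-test sketch with hypothetical shapes (imports assume this repository's environment):

import torch
from rvc.layers.discriminators import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator(version="v1")     # DiscriminatorS + periods (2, 3, 5, 7, 11, 17)
real = torch.randn(1, 1, 8192)                   # [batch, 1, samples]
fake = torch.randn(1, 1, 8192)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(real, fake)
print(len(y_d_rs), y_d_rs[0].shape)              # 7 discriminators, flattened logits each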
rvc/layers/encoders.py ADDED
@@ -0,0 +1,221 @@
1
+ import math
2
+ from typing import Tuple, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from .attentions import MultiHeadAttention, FFN
8
+ from .norms import LayerNorm, WN
9
+ from .utils import sequence_mask
10
+
11
+
12
+ class Encoder(nn.Module):
13
+ def __init__(
14
+ self,
15
+ hidden_channels: int,
16
+ filter_channels: int,
17
+ n_heads: int,
18
+ n_layers: int,
19
+ kernel_size: int = 1,
20
+ p_dropout: float = 0.0,
21
+ window_size: int = 10,
22
+ ):
23
+ super(Encoder, self).__init__()
24
+
25
+ self.hidden_channels = hidden_channels
26
+ self.filter_channels = filter_channels
27
+ self.n_heads = n_heads
28
+ self.n_layers = n_layers
29
+ self.kernel_size = kernel_size
30
+ self.p_dropout = p_dropout
31
+ self.window_size = window_size
32
+
33
+ self.drop = nn.Dropout(p_dropout)
34
+ self.attn_layers = nn.ModuleList()
35
+ self.norm_layers_1 = nn.ModuleList()
36
+ self.ffn_layers = nn.ModuleList()
37
+ self.norm_layers_2 = nn.ModuleList()
38
+
39
+ for _ in range(self.n_layers):
40
+ self.attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ window_size=window_size,
47
+ )
48
+ )
49
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
50
+ self.ffn_layers.append(
51
+ FFN(
52
+ hidden_channels,
53
+ hidden_channels,
54
+ filter_channels,
55
+ kernel_size,
56
+ p_dropout=p_dropout,
57
+ )
58
+ )
59
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
60
+
61
+ def __call__(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
62
+ return super().__call__(x, x_mask)
63
+
64
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
65
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
66
+ x = x * x_mask
67
+ for attn, norm1, ffn, norm2 in zip(
68
+ self.attn_layers,
69
+ self.norm_layers_1,
70
+ self.ffn_layers,
71
+ self.norm_layers_2,
72
+ ):
73
+ y = attn(x, x, attn_mask)
74
+ y = self.drop(y)
75
+ x = norm1(x + y)
76
+
77
+ y = ffn(x, x_mask)
78
+ y = self.drop(y)
79
+ x = norm2(x + y)
80
+ x = x * x_mask
81
+ return x
82
+
83
+
84
+ class TextEncoder(nn.Module):
85
+ def __init__(
86
+ self,
87
+ in_channels: int,
88
+ out_channels: int,
89
+ hidden_channels: int,
90
+ filter_channels: int,
91
+ n_heads: int,
92
+ n_layers: int,
93
+ kernel_size: int,
94
+ p_dropout: float,
95
+ f0: bool = True,
96
+ ):
97
+ super(TextEncoder, self).__init__()
98
+
99
+ self.out_channels = out_channels
100
+ self.hidden_channels = hidden_channels
101
+ self.filter_channels = filter_channels
102
+ self.n_heads = n_heads
103
+ self.n_layers = n_layers
104
+ self.kernel_size = kernel_size
105
+ self.p_dropout = float(p_dropout)
106
+
107
+ self.emb_phone = nn.Linear(in_channels, hidden_channels)
108
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
109
+ if f0:
110
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
111
+ self.encoder = Encoder(
112
+ hidden_channels,
113
+ filter_channels,
114
+ n_heads,
115
+ n_layers,
116
+ kernel_size,
117
+ float(p_dropout),
118
+ )
119
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
120
+
121
+ def __call__(
122
+ self,
123
+ phone: torch.Tensor,
124
+ pitch: torch.Tensor,
125
+ lengths: torch.Tensor,
126
+ skip_head: Optional[int] = None,
127
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128
+ return super().__call__(
129
+ phone,
130
+ pitch,
131
+ lengths,
132
+ skip_head=skip_head,
133
+ )
134
+
135
+ def forward(
136
+ self,
137
+ phone: torch.Tensor,
138
+ pitch: torch.Tensor,
139
+ lengths: torch.Tensor,
140
+ skip_head: Optional[int] = None,
141
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
142
+ x = self.emb_phone(phone)
143
+ if pitch is not None:
144
+ x += self.emb_pitch(pitch)
145
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
146
+ x = self.lrelu(x)
147
+ x = torch.transpose(x, 1, -1) # [b, h, t]
148
+ x_mask = torch.unsqueeze(
149
+ sequence_mask(lengths, x.size(2)),
150
+ 1,
151
+ ).to(x.dtype)
152
+ x = self.encoder(x * x_mask, x_mask)
153
+ if skip_head is not None:
154
+ head = int(skip_head)
155
+ x = x[:, :, head:]
156
+ x_mask = x_mask[:, :, head:]
157
+ stats: torch.Tensor = self.proj(x) * x_mask
158
+ m, logs = torch.split(stats, self.out_channels, dim=1)
159
+ return m, logs, x_mask
160
+
161
+
162
+ class PosteriorEncoder(nn.Module):
163
+ def __init__(
164
+ self,
165
+ in_channels: int,
166
+ out_channels: int,
167
+ hidden_channels: int,
168
+ kernel_size: int,
169
+ dilation_rate: int,
170
+ n_layers: int,
171
+ gin_channels=0,
172
+ ):
173
+ super(PosteriorEncoder, self).__init__()
174
+ self.in_channels = in_channels
175
+ self.out_channels = out_channels
176
+ self.hidden_channels = hidden_channels
177
+ self.kernel_size = kernel_size
178
+ self.dilation_rate = dilation_rate
179
+ self.n_layers = n_layers
180
+ self.gin_channels = gin_channels
181
+
182
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
183
+ self.enc = WN(
184
+ hidden_channels,
185
+ kernel_size,
186
+ dilation_rate,
187
+ n_layers,
188
+ gin_channels=gin_channels,
189
+ )
190
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
191
+
192
+ def __call__(
193
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
194
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
195
+ return super().__call__(x, x_lengths, g=g)
196
+
197
+ def forward(
198
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
199
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
200
+ x_mask = torch.unsqueeze(
201
+ sequence_mask(x_lengths, x.size(2)),
202
+ 1,
203
+ ).to(x.dtype)
204
+ x = self.pre(x) * x_mask
205
+ x = self.enc(x, x_mask, g=g)
206
+ stats = self.proj(x) * x_mask
207
+ m, logs = torch.split(stats, self.out_channels, dim=1)
208
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
209
+ return z, m, logs, x_mask
210
+
211
+ def remove_weight_norm(self):
212
+ self.enc.remove_weight_norm()
213
+
214
+ def __prepare_scriptable__(self):
215
+ for hook in self.enc._forward_pre_hooks.values():
216
+ if (
217
+ hook.__module__ == "torch.nn.utils.weight_norm"
218
+ and hook.__class__.__name__ == "WeightNorm"
219
+ ):
220
+ torch.nn.utils.remove_weight_norm(self.enc)
221
+ return self
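
TextEncoder takes frame-level content features [batch, frames, in_channels] plus coarse pitch indices (when f0 is enabled) and returns the prior mean, log-variance and frame mask, all channels-first. A shape sketch with hypothetical hyper-parameters (imports assume this repository's environment):

import torch
from rvc.layers.encoders import TextEncoder

enc = TextEncoder(
    in_channels=256, out_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6,
    kernel_size=3, p_dropout=0.0,
)
phone = torch.randn(1, 100, 256)           # content features, [batch, frames, in_channels]
pitch = torch.randint(0, 256, (1, 100))    # coarse pitch indices in [0, 255]
lengths = torch.tensor([100])
m, logs, x_mask = enc(phone, pitch, lengths)
print(m.shape, logs.shape, x_mask.shape)   # [1, 192, 100], [1, 192, 100], [1, 1, 100]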
rvc/layers/generators.py ADDED
@@ -0,0 +1,228 @@
1
+ from typing import Optional, List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d, ConvTranspose1d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, weight_norm
8
+
9
+ from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
10
+ from .utils import call_weight_data_normal_if_Conv
11
+
12
+
13
+ class Generator(torch.nn.Module):
14
+ def __init__(
15
+ self,
16
+ initial_channel: int,
17
+ resblock: str,
18
+ resblock_kernel_sizes: List[int],
19
+ resblock_dilation_sizes: List[List[int]],
20
+ upsample_rates: List[int],
21
+ upsample_initial_channel: int,
22
+ upsample_kernel_sizes: List[int],
23
+ gin_channels: int = 0,
24
+ ):
25
+ super(Generator, self).__init__()
26
+ self.num_kernels = len(resblock_kernel_sizes)
27
+ self.num_upsamples = len(upsample_rates)
28
+
29
+ self.conv_pre = Conv1d(
30
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
31
+ )
32
+
33
+ self.ups = nn.ModuleList()
34
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
35
+ self.ups.append(
36
+ weight_norm(
37
+ ConvTranspose1d(
38
+ upsample_initial_channel // (2**i),
39
+ upsample_initial_channel // (2 ** (i + 1)),
40
+ k,
41
+ u,
42
+ padding=(k - u) // 2,
43
+ )
44
+ )
45
+ )
46
+
47
+ self.resblocks = nn.ModuleList()
48
+ resblock_module = ResBlock1 if resblock == "1" else ResBlock2
49
+ for i in range(len(self.ups)):
50
+ ch = upsample_initial_channel // (2 ** (i + 1))
51
+ for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
52
+ self.resblocks.append(resblock_module(ch, k, d))
53
+
54
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
55
+ self.ups.apply(call_weight_data_normal_if_Conv)
56
+
57
+ if gin_channels != 0:
58
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
59
+
60
+ def __call__(
61
+ self,
62
+ x: torch.Tensor,
63
+ g: Optional[torch.Tensor] = None,
64
+ n_res: Optional[int] = None,
65
+ ) -> torch.Tensor:
66
+ return super().__call__(x, g=g, n_res=n_res)
67
+
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ g: Optional[torch.Tensor] = None,
72
+ n_res: Optional[int] = None,
73
+ ):
74
+ if n_res is not None:
75
+ n = int(n_res)
76
+ if n != x.shape[-1]:
77
+ x = F.interpolate(x, size=n, mode="linear")
78
+
79
+ x = self.conv_pre(x)
80
+ if g is not None:
81
+ x = x + self.cond(g)
82
+
83
+ for i in range(self.num_upsamples):
84
+ x = F.leaky_relu(x, LRELU_SLOPE)
85
+ x = self.ups[i](x)
86
+ n = i * self.num_kernels
87
+ xs = self.resblocks[n](x)
88
+ for j in range(1, self.num_kernels):
89
+ xs += self.resblocks[n + j](x)
90
+ x = xs / self.num_kernels
91
+
92
+ x = F.leaky_relu(x)
93
+ x = self.conv_post(x)
94
+ x = torch.tanh(x)
95
+
96
+ return x
97
+
98
+ def __prepare_scriptable__(self):
99
+ for l in self.ups:
100
+ for hook in l._forward_pre_hooks.values():
101
+ # The hook we want to remove is an instance of WeightNorm class, so
102
+ # normally we would do `if isinstance(...)` but this class is not accessible
103
+ # because of shadowing, so we check the module name directly.
104
+ # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
105
+ if (
106
+ hook.__module__ == "torch.nn.utils.weight_norm"
107
+ and hook.__class__.__name__ == "WeightNorm"
108
+ ):
109
+ torch.nn.utils.remove_weight_norm(l)
110
+
111
+ for l in self.resblocks:
112
+ for hook in l._forward_pre_hooks.values():
113
+ if (
114
+ hook.__module__ == "torch.nn.utils.weight_norm"
115
+ and hook.__class__.__name__ == "WeightNorm"
116
+ ):
117
+ torch.nn.utils.remove_weight_norm(l)
118
+ return self
119
+
120
+ def remove_weight_norm(self):
121
+ for l in self.ups:
122
+ remove_weight_norm(l)
123
+ for l in self.resblocks:
124
+ l.remove_weight_norm()
125
+
126
+
127
+ class SineGenerator(torch.nn.Module):
128
+ """Definition of sine generator
129
+ SineGenerator(samp_rate, harmonic_num = 0,
130
+ sine_amp = 0.1, noise_std = 0.003,
131
+ voiced_threshold = 0,
132
+ flag_for_pulse=False)
133
+ samp_rate: sampling rate in Hz
134
+ harmonic_num: number of harmonic overtones (default 0)
135
+ sine_amp: amplitude of the sine waveform (default 0.1)
136
+ noise_std: std of Gaussian noise (default 0.003)
137
+ voiced_threshold: F0 threshold for U/V classification (default 0)
138
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
139
+ Note: when flag_for_pulse is True, the first time step of a voiced
140
+ segment is always sin(torch.pi) or cos(0)
141
+ """
142
+
143
+ def __init__(
144
+ self,
145
+ samp_rate: int,
146
+ harmonic_num: int = 0,
147
+ sine_amp: float = 0.1,
148
+ noise_std: float = 0.003,
149
+ voiced_threshold: int = 0,
150
+ ):
151
+ super(SineGenerator, self).__init__()
152
+ self.sine_amp = sine_amp
153
+ self.noise_std = noise_std
154
+ self.harmonic_num = harmonic_num
155
+ self.dim = harmonic_num + 1
156
+ self.sampling_rate = samp_rate
157
+ self.voiced_threshold = voiced_threshold
158
+
159
+ def __call__(
160
+ self, f0: torch.Tensor, upp: int
161
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
162
+ return super().__call__(f0, upp)
163
+
164
+ def forward(
165
+ self, f0: torch.Tensor, upp: int
166
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
167
+ """sine_tensor, uv = forward(f0)
168
+ input F0: tensor(batchsize=1, length, dim=1)
169
+ f0 for unvoiced steps should be 0
170
+ output sine_tensor: tensor(batchsize=1, length, dim)
171
+ output uv: tensor(batchsize=1, length, 1)
172
+ """
173
+ with torch.no_grad():
174
+ f0 = f0[:, None].transpose(1, 2)
175
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
176
+ # fundamental component
177
+ f0_buf[:, :, 0] = f0[:, :, 0]
178
+ for idx in range(self.harmonic_num):
179
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
180
+ idx + 2
181
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
182
+ rad_values = (
183
+ f0_buf / self.sampling_rate
184
+ ) % 1 # % 1 means the multiplication over n_har cannot be optimized in post-processing
185
+ rand_ini = torch.rand(
186
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
187
+ )
188
+ rand_ini[:, 0] = 0
189
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
190
+ tmp_over_one = torch.cumsum(
191
+ rad_values, 1
192
+ ) # % 1 removed: applying % 1 here would prevent the later cumsum from being optimized
193
+ tmp_over_one *= upp
194
+ tmp_over_one: torch.Tensor = F.interpolate(
195
+ tmp_over_one.transpose(2, 1),
196
+ scale_factor=float(upp),
197
+ mode="linear",
198
+ align_corners=True,
199
+ ).transpose(2, 1)
200
+ rad_values: torch.Tensor = F.interpolate(
201
+ rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
202
+ ).transpose(
203
+ 2, 1
204
+ )
205
+ tmp_over_one %= 1
206
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
207
+ cumsum_shift = torch.zeros_like(rad_values)
208
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
209
+ sine_waves = torch.sin(
210
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
211
+ )
212
+ sine_waves = sine_waves * self.sine_amp
213
+ uv = self._f02uv(f0)
214
+ uv: torch.Tensor = F.interpolate(
215
+ uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
216
+ ).transpose(2, 1)
217
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
218
+ noise = noise_amp * torch.randn_like(sine_waves)
219
+ sine_waves = sine_waves * uv + noise
220
+ return sine_waves, uv, noise
221
+
222
+ def _f02uv(self, f0):
223
+ # generate uv signal
224
+ uv = torch.ones_like(f0)
225
+ uv = uv * (f0 > self.voiced_threshold)
226
+ if uv.device.type == "privateuseone": # for DirectML
227
+ uv = uv.float()
228
+ return uv
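
An illustrative sketch of driving SineGenerator on its own (the sampling rate, hop length and F0 value below are assumptions for the example): a frame-rate F0 contour of shape (batch, frames) is upsampled by `upp` samples per frame into a waveform-rate harmonic source plus its voiced/unvoiced mask and noise.

import torch
# harmonic_num=0 keeps only the fundamental sine
gen = SineGenerator(samp_rate=40000, harmonic_num=0)
f0 = torch.full((1, 100), 220.0)      # 100 frames of a 220 Hz tone
sine, uv, noise = gen(f0, upp=400)    # upp: samples per frame (hop length)
print(sine.shape)                     # torch.Size([1, 40000, 1])
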
rvc/layers/norms.py ADDED
@@ -0,0 +1,148 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from .utils import activate_add_tanh_sigmoid_multiply
8
+
9
+
10
+ class LayerNorm(nn.Module):
11
+ def __init__(self, channels: int, eps: float = 1e-5):
12
+ super(LayerNorm, self).__init__()
13
+ self.channels = channels
14
+ self.eps = eps
15
+
16
+ self.gamma = nn.Parameter(torch.ones(channels))
17
+ self.beta = nn.Parameter(torch.zeros(channels))
18
+
19
+ def forward(self, x: torch.Tensor):
20
+ x = x.transpose(1, -1)
21
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
22
+ return x.transpose(1, -1)
23
+
24
+
25
+ class WN(torch.nn.Module):
26
+ def __init__(
27
+ self,
28
+ hidden_channels: int,
29
+ kernel_size: int,
30
+ dilation_rate: int,
31
+ n_layers: int,
32
+ gin_channels: int = 0,
33
+ p_dropout: int = 0,
34
+ ):
35
+ super(WN, self).__init__()
36
+ assert kernel_size % 2 == 1
37
+ self.hidden_channels = hidden_channels
38
+ self.kernel_size = (kernel_size,)
39
+ self.dilation_rate = dilation_rate
40
+ self.n_layers = n_layers
41
+ self.gin_channels = gin_channels
42
+ self.p_dropout = float(p_dropout)
43
+
44
+ self.in_layers = torch.nn.ModuleList()
45
+ self.res_skip_layers = torch.nn.ModuleList()
46
+ self.drop = nn.Dropout(float(p_dropout))
47
+
48
+ if gin_channels != 0:
49
+ cond_layer = torch.nn.Conv1d(
50
+ gin_channels, 2 * hidden_channels * n_layers, 1
51
+ )
52
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
53
+
54
+ for i in range(n_layers):
55
+ dilation = dilation_rate**i
56
+ padding = int((kernel_size * dilation - dilation) / 2)
57
+ in_layer = torch.nn.Conv1d(
58
+ hidden_channels,
59
+ 2 * hidden_channels,
60
+ kernel_size,
61
+ dilation=dilation,
62
+ padding=padding,
63
+ )
64
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
65
+ self.in_layers.append(in_layer)
66
+
67
+ # last one is not necessary
68
+ if i < n_layers - 1:
69
+ res_skip_channels = 2 * hidden_channels
70
+ else:
71
+ res_skip_channels = hidden_channels
72
+
73
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
74
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
75
+ self.res_skip_layers.append(res_skip_layer)
76
+
77
+ def __call__(
78
+ self,
79
+ x: torch.Tensor,
80
+ x_mask: torch.Tensor,
81
+ g: Optional[torch.Tensor] = None,
82
+ ) -> torch.Tensor:
83
+ return super().__call__(x, x_mask, g=g)
84
+
85
+ def forward(
86
+ self,
87
+ x: torch.Tensor,
88
+ x_mask: torch.Tensor,
89
+ g: Optional[torch.Tensor] = None,
90
+ ) -> torch.Tensor:
91
+ output = torch.zeros_like(x)
92
+
93
+ if g is not None:
94
+ g = self.cond_layer(g)
95
+
96
+ for i, (in_layer, res_skip_layer) in enumerate(
97
+ zip(self.in_layers, self.res_skip_layers)
98
+ ):
99
+ x_in: torch.Tensor = in_layer(x)
100
+ if g is not None:
101
+ cond_offset = i * 2 * self.hidden_channels
102
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
103
+ else:
104
+ g_l = torch.zeros_like(x_in)
105
+
106
+ acts = activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
107
+ acts: torch.Tensor = self.drop(acts)
108
+
109
+ res_skip_acts: torch.Tensor = res_skip_layer(acts)
110
+ if i < self.n_layers - 1:
111
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
112
+ x = (x + res_acts) * x_mask
113
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
114
+ else:
115
+ output = output + res_skip_acts
116
+ return output * x_mask
117
+
118
+ def remove_weight_norm(self):
119
+ if self.gin_channels != 0:
120
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
121
+ for l in self.in_layers:
122
+ torch.nn.utils.remove_weight_norm(l)
123
+ for l in self.res_skip_layers:
124
+ torch.nn.utils.remove_weight_norm(l)
125
+
126
+ def __prepare_scriptable__(self):
127
+ if self.gin_channels != 0:
128
+ for hook in self.cond_layer._forward_pre_hooks.values():
129
+ if (
130
+ hook.__module__ == "torch.nn.utils.weight_norm"
131
+ and hook.__class__.__name__ == "WeightNorm"
132
+ ):
133
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
134
+ for l in self.in_layers:
135
+ for hook in l._forward_pre_hooks.values():
136
+ if (
137
+ hook.__module__ == "torch.nn.utils.weight_norm"
138
+ and hook.__class__.__name__ == "WeightNorm"
139
+ ):
140
+ torch.nn.utils.remove_weight_norm(l)
141
+ for l in self.res_skip_layers:
142
+ for hook in l._forward_pre_hooks.values():
143
+ if (
144
+ hook.__module__ == "torch.nn.utils.weight_norm"
145
+ and hook.__class__.__name__ == "WeightNorm"
146
+ ):
147
+ torch.nn.utils.remove_weight_norm(l)
148
+ return self
rvc/layers/nsf.py ADDED
@@ -0,0 +1,216 @@
1
+ from typing import Optional, List
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, weight_norm
9
+
10
+ from .generators import SineGenerator
11
+ from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
12
+ from .utils import call_weight_data_normal_if_Conv
13
+
14
+
15
+ class SourceModuleHnNSF(torch.nn.Module):
16
+ """SourceModule for hn-nsf
17
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
18
+ add_noise_std=0.003, voiced_threshod=0)
19
+ sampling_rate: sampling_rate in Hz
20
+ harmonic_num: number of harmonic above F0 (default: 0)
21
+ sine_amp: amplitude of sine source signal (default: 0.1)
22
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
23
+ note that amplitude of noise in unvoiced is decided
24
+ by sine_amp
25
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
26
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
27
+ F0_sampled (batchsize, length, 1)
28
+ Sine_source (batchsize, length, 1)
29
+ noise_source (batchsize, length, 1)
30
+ uv (batchsize, length, 1)
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ sampling_rate: int,
36
+ harmonic_num: int = 0,
37
+ sine_amp: float = 0.1,
38
+ add_noise_std: float = 0.003,
39
+ voiced_threshod: int = 0,
40
+ ):
41
+ super(SourceModuleHnNSF, self).__init__()
42
+
43
+ self.sine_amp = sine_amp
44
+ self.noise_std = add_noise_std
45
+ # to produce sine waveforms
46
+ self.l_sin_gen = SineGenerator(
47
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
48
+ )
49
+ # to merge source harmonics into a single excitation
50
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
51
+ self.l_tanh = torch.nn.Tanh()
52
+
53
+ def __call__(self, x: torch.Tensor, upp: int = 1) -> torch.Tensor:
54
+ return super().__call__(x, upp=upp)
55
+
56
+ def forward(self, x: torch.Tensor, upp: int = 1) -> torch.Tensor:
57
+ sine_wavs, _, _ = self.l_sin_gen(x, upp)
58
+ sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
59
+ sine_merge: torch.Tensor = self.l_tanh(self.l_linear(sine_wavs))
60
+ return sine_merge # , None, None # noise, uv
61
+
62
+
63
+ class NSFGenerator(torch.nn.Module):
64
+ def __init__(
65
+ self,
66
+ initial_channel: int,
67
+ resblock: str,
68
+ resblock_kernel_sizes: List[int],
69
+ resblock_dilation_sizes: List[List[int]],
70
+ upsample_rates: List[int],
71
+ upsample_initial_channel: int,
72
+ upsample_kernel_sizes: List[int],
73
+ gin_channels: int,
74
+ sr: int,
75
+ ):
76
+ super(NSFGenerator, self).__init__()
77
+ self.num_kernels = len(resblock_kernel_sizes)
78
+ self.num_upsamples = len(upsample_rates)
79
+
80
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
81
+ self.m_source = SourceModuleHnNSF(sampling_rate=sr, harmonic_num=0)
82
+ self.noise_convs = nn.ModuleList()
83
+ self.conv_pre = Conv1d(
84
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
85
+ )
86
+ resblock = ResBlock1 if resblock == "1" else ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
91
+ self.ups.append(
92
+ weight_norm(
93
+ ConvTranspose1d(
94
+ upsample_initial_channel // (2**i),
95
+ upsample_initial_channel // (2 ** (i + 1)),
96
+ k,
97
+ u,
98
+ padding=(k - u) // 2,
99
+ )
100
+ )
101
+ )
102
+ if i + 1 < len(upsample_rates):
103
+ stride_f0 = math.prod(upsample_rates[i + 1 :])
104
+ self.noise_convs.append(
105
+ Conv1d(
106
+ 1,
107
+ c_cur,
108
+ kernel_size=stride_f0 * 2,
109
+ stride=stride_f0,
110
+ padding=stride_f0 // 2,
111
+ )
112
+ )
113
+ else:
114
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
115
+
116
+ self.resblocks = nn.ModuleList()
117
+ for i in range(len(self.ups)):
118
+ ch: int = upsample_initial_channel // (2 ** (i + 1))
119
+ for j, (k, d) in enumerate(
120
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
121
+ ):
122
+ self.resblocks.append(resblock(ch, k, d))
123
+
124
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
125
+ self.ups.apply(call_weight_data_normal_if_Conv)
126
+
127
+ if gin_channels != 0:
128
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
129
+
130
+ self.upp = math.prod(upsample_rates)
131
+
132
+ self.lrelu_slope = LRELU_SLOPE
133
+
134
+ def __call__(
135
+ self,
136
+ x: torch.Tensor,
137
+ f0: torch.Tensor,
138
+ g: Optional[torch.Tensor] = None,
139
+ n_res: Optional[int] = None,
140
+ ) -> torch.Tensor:
141
+ return super().__call__(x, f0, g=g, n_res=n_res)
142
+
143
+ def forward(
144
+ self,
145
+ x: torch.Tensor,
146
+ f0: torch.Tensor,
147
+ g: Optional[torch.Tensor] = None,
148
+ n_res: Optional[int] = None,
149
+ ) -> torch.Tensor:
150
+ har_source = self.m_source(f0, self.upp)
151
+ har_source = har_source.transpose(1, 2)
152
+
153
+ if n_res is not None:
154
+ n_res = int(n_res)
155
+ if n_res * self.upp != har_source.shape[-1]:
156
+ har_source = F.interpolate(
157
+ har_source, size=n_res * self.upp, mode="linear"
158
+ )
159
+ if n_res != x.shape[-1]:
160
+ x = F.interpolate(x, size=n_res, mode="linear")
161
+
162
+ x = self.conv_pre(x)
163
+ if g is not None:
164
+ x = x + self.cond(g)
165
+ # torch.jit.script() does not support direct indexing of torch modules
166
+ # That's why I wrote this
167
+ for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
168
+ if i < self.num_upsamples:
169
+ x = F.leaky_relu(x, self.lrelu_slope)
170
+ x = ups(x)
171
+ x_source = noise_convs(har_source)
172
+ x = x + x_source
173
+ xs: Optional[torch.Tensor] = None
174
+ l = [i * self.num_kernels + j for j in range(self.num_kernels)]
175
+ for j, resblock in enumerate(self.resblocks):
176
+ if j in l:
177
+ if xs is None:
178
+ xs = resblock(x)
179
+ else:
180
+ xs += resblock(x)
181
+ # This assertion cannot be ignored! \
182
+ # If ignored, it will cause torch.jit.script() compilation errors
183
+ assert isinstance(xs, torch.Tensor)
184
+ x = xs / self.num_kernels
185
+ x = F.leaky_relu(x)
186
+ x = self.conv_post(x)
187
+ x = torch.tanh(x)
188
+
189
+ return x
190
+
191
+ def remove_weight_norm(self):
192
+ for l in self.ups:
193
+ remove_weight_norm(l)
194
+ for l in self.resblocks:
195
+ l.remove_weight_norm()
196
+
197
+ def __prepare_scriptable__(self):
198
+ for l in self.ups:
199
+ for hook in l._forward_pre_hooks.values():
200
+ # The hook we want to remove is an instance of WeightNorm class, so
201
+ # normally we would do `if isinstance(...)` but this class is not accessible
202
+ # because of shadowing, so we check the module name directly.
203
+ # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
204
+ if (
205
+ hook.__module__ == "torch.nn.utils.weight_norm"
206
+ and hook.__class__.__name__ == "WeightNorm"
207
+ ):
208
+ torch.nn.utils.remove_weight_norm(l)
209
+ for l in self.resblocks:
210
+ for hook in l._forward_pre_hooks.values():
211
+ if (
212
+ hook.__module__ == "torch.nn.utils.weight_norm"
213
+ and hook.__class__.__name__ == "WeightNorm"
214
+ ):
215
+ torch.nn.utils.remove_weight_norm(l)
216
+ return self
rvc/layers/residuals.py ADDED
@@ -0,0 +1,353 @@
1
+ from typing import Optional, List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, weight_norm
8
+
9
+ from .norms import WN
10
+ from .utils import (
11
+ get_padding,
12
+ call_weight_data_normal_if_Conv,
13
+ )
14
+
15
+ LRELU_SLOPE = 0.1
16
+
17
+
18
+ class ResBlock1(torch.nn.Module):
19
+ def __init__(
20
+ self,
21
+ channels: int,
22
+ kernel_size: int = 3,
23
+ dilation: List[int] = (1, 3, 5),
24
+ ):
25
+ super(ResBlock1, self).__init__()
26
+
27
+ self.convs1 = nn.ModuleList()
28
+ for d in dilation:
29
+ self.convs1.append(
30
+ weight_norm(
31
+ Conv1d(
32
+ channels,
33
+ channels,
34
+ kernel_size,
35
+ 1,
36
+ dilation=d,
37
+ padding=get_padding(kernel_size, d),
38
+ )
39
+ ),
40
+ )
41
+ self.convs1.apply(call_weight_data_normal_if_Conv)
42
+
43
+ self.convs2 = nn.ModuleList()
44
+ for _ in dilation:
45
+ self.convs2.append(
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=1,
53
+ padding=get_padding(kernel_size, 1),
54
+ )
55
+ ),
56
+ )
57
+ self.convs2.apply(call_weight_data_normal_if_Conv)
58
+ self.lrelu_slope = LRELU_SLOPE
59
+
60
+ def __call__(
61
+ self,
62
+ x: torch.Tensor,
63
+ x_mask: Optional[torch.Tensor] = None,
64
+ ) -> torch.Tensor:
65
+ return super().__call__(x, x_mask=x_mask)
66
+
67
+ def forward(
68
+ self,
69
+ x: torch.Tensor,
70
+ x_mask: Optional[torch.Tensor] = None,
71
+ ) -> torch.Tensor:
72
+ for c1, c2 in zip(self.convs1, self.convs2):
73
+ xt = F.leaky_relu(x, self.lrelu_slope)
74
+ if x_mask is not None:
75
+ xt = xt * x_mask
76
+ xt = c1(xt)
77
+ xt = F.leaky_relu(xt, self.lrelu_slope)
78
+ if x_mask is not None:
79
+ xt = xt * x_mask
80
+ xt = c2(xt)
81
+ x = xt + x
82
+ if x_mask is not None:
83
+ x = x * x_mask
84
+ return x
85
+
86
+ def remove_weight_norm(self):
87
+ for l in self.convs1:
88
+ remove_weight_norm(l)
89
+ for l in self.convs2:
90
+ remove_weight_norm(l)
91
+
92
+ def __prepare_scriptable__(self):
93
+ for l in self.convs1:
94
+ for hook in l._forward_pre_hooks.values():
95
+ if (
96
+ hook.__module__ == "torch.nn.utils.weight_norm"
97
+ and hook.__class__.__name__ == "WeightNorm"
98
+ ):
99
+ torch.nn.utils.remove_weight_norm(l)
100
+ for l in self.convs2:
101
+ for hook in l._forward_pre_hooks.values():
102
+ if (
103
+ hook.__module__ == "torch.nn.utils.weight_norm"
104
+ and hook.__class__.__name__ == "WeightNorm"
105
+ ):
106
+ torch.nn.utils.remove_weight_norm(l)
107
+ return self
108
+
109
+
110
+ class ResBlock2(torch.nn.Module):
111
+ """
112
+ This module is currently unused
113
+ because all configs specify "resblock": "1"
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ channels: int,
119
+ kernel_size=3,
120
+ dilation: List[int] = (1, 3),
121
+ ):
122
+ super(ResBlock2, self).__init__()
123
+ self.convs = nn.ModuleList()
124
+ for d in dilation:
125
+ self.convs.append(
126
+ weight_norm(
127
+ Conv1d(
128
+ channels,
129
+ channels,
130
+ kernel_size,
131
+ 1,
132
+ dilation=d,
133
+ padding=get_padding(kernel_size, d),
134
+ )
135
+ ),
136
+ )
137
+ self.convs.apply(call_weight_data_normal_if_Conv)
138
+ self.lrelu_slope = LRELU_SLOPE
139
+
140
+ def __call__(
141
+ self,
142
+ x: torch.Tensor,
143
+ x_mask: Optional[torch.Tensor] = None,
144
+ ) -> torch.Tensor:
145
+ return super().__call__(x, x_mask=x_mask)
146
+
147
+ def forward(
148
+ self,
149
+ x: torch.Tensor,
150
+ x_mask: Optional[torch.Tensor] = None,
151
+ ) -> torch.Tensor:
152
+ for c in self.convs:
153
+ xt = F.leaky_relu(x, self.lrelu_slope)
154
+ if x_mask is not None:
155
+ xt = xt * x_mask
156
+ xt = c(xt)
157
+ x = xt + x
158
+ if x_mask is not None:
159
+ x = x * x_mask
160
+ return x
161
+
162
+ def remove_weight_norm(self):
163
+ for l in self.convs:
164
+ remove_weight_norm(l)
165
+
166
+ def __prepare_scriptable__(self):
167
+ for l in self.convs:
168
+ for hook in l._forward_pre_hooks.values():
169
+ if (
170
+ hook.__module__ == "torch.nn.utils.weight_norm"
171
+ and hook.__class__.__name__ == "WeightNorm"
172
+ ):
173
+ torch.nn.utils.remove_weight_norm(l)
174
+ return self
175
+
176
+
177
+ class ResidualCouplingLayer(nn.Module):
178
+ def __init__(
179
+ self,
180
+ channels: int,
181
+ hidden_channels: int,
182
+ kernel_size: int,
183
+ dilation_rate: int,
184
+ n_layers: int,
185
+ p_dropout: int = 0,
186
+ gin_channels: int = 0,
187
+ mean_only: bool = False,
188
+ ):
189
+ assert channels % 2 == 0, "channels should be divisible by 2"
190
+ super(ResidualCouplingLayer, self).__init__()
191
+ self.channels = channels
192
+ self.hidden_channels = hidden_channels
193
+ self.kernel_size = kernel_size
194
+ self.dilation_rate = dilation_rate
195
+ self.n_layers = n_layers
196
+ self.half_channels = channels // 2
197
+ self.mean_only = mean_only
198
+
199
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
200
+ self.enc = WN(
201
+ hidden_channels,
202
+ kernel_size,
203
+ dilation_rate,
204
+ n_layers,
205
+ p_dropout=float(p_dropout),
206
+ gin_channels=gin_channels,
207
+ )
208
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
209
+ self.post.weight.data.zero_()
210
+ self.post.bias.data.zero_()
211
+
212
+ def __call__(
213
+ self,
214
+ x: torch.Tensor,
215
+ x_mask: torch.Tensor,
216
+ g: Optional[torch.Tensor] = None,
217
+ reverse: bool = False,
218
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
219
+ return super().__call__(x, x_mask, g=g, reverse=reverse)
220
+
221
+ def forward(
222
+ self,
223
+ x: torch.Tensor,
224
+ x_mask: torch.Tensor,
225
+ g: Optional[torch.Tensor] = None,
226
+ reverse: bool = False,
227
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
228
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
229
+ h = self.pre(x0) * x_mask
230
+ h = self.enc(h, x_mask, g=g)
231
+ stats = self.post(h) * x_mask
232
+ if not self.mean_only:
233
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
234
+ else:
235
+ m = stats
236
+ logs = torch.zeros_like(m)
237
+
238
+ if not reverse:
239
+ x1 = m + x1 * torch.exp(logs) * x_mask
240
+ x = torch.cat([x0, x1], 1)
241
+ logdet = torch.sum(logs, [1, 2])
242
+ return x, logdet
243
+
244
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
245
+ x = torch.cat([x0, x1], 1)
246
+ return x, torch.zeros([1])
247
+
248
+ def remove_weight_norm(self):
249
+ self.enc.remove_weight_norm()
250
+
251
+ def __prepare_scriptable__(self):
252
+ for hook in self.enc._forward_pre_hooks.values():
253
+ if (
254
+ hook.__module__ == "torch.nn.utils.weight_norm"
255
+ and hook.__class__.__name__ == "WeightNorm"
256
+ ):
257
+ torch.nn.utils.remove_weight_norm(self.enc)
258
+ return self
259
+
260
+
261
+ class ResidualCouplingBlock(nn.Module):
262
+ class Flip(nn.Module):
263
+ """
264
+ torch.jit.script() Compiled functions
265
+ can't take variable number of arguments or
266
+ use keyword-only arguments with defaults
267
+ """
268
+
269
+ def forward(
270
+ self,
271
+ x: torch.Tensor,
272
+ x_mask: torch.Tensor,
273
+ g: Optional[torch.Tensor] = None,
274
+ reverse: bool = False,
275
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
276
+ x = torch.flip(x, [1])
277
+ if not reverse:
278
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
279
+ return x, logdet
280
+ else:
281
+ return x, torch.zeros([1], device=x.device)
282
+
283
+ def __init__(
284
+ self,
285
+ channels: int,
286
+ hidden_channels: int,
287
+ kernel_size: int,
288
+ dilation_rate: int,
289
+ n_layers: int,
290
+ n_flows: int = 4,
291
+ gin_channels: int = 0,
292
+ ):
293
+ super(ResidualCouplingBlock, self).__init__()
294
+ self.channels = channels
295
+ self.hidden_channels = hidden_channels
296
+ self.kernel_size = kernel_size
297
+ self.dilation_rate = dilation_rate
298
+ self.n_layers = n_layers
299
+ self.n_flows = n_flows
300
+ self.gin_channels = gin_channels
301
+
302
+ self.flows = nn.ModuleList()
303
+ for _ in range(n_flows):
304
+ self.flows.append(
305
+ ResidualCouplingLayer(
306
+ channels,
307
+ hidden_channels,
308
+ kernel_size,
309
+ dilation_rate,
310
+ n_layers,
311
+ gin_channels=gin_channels,
312
+ mean_only=True,
313
+ )
314
+ )
315
+ self.flows.append(self.Flip())
316
+
317
+ def __call__(
318
+ self,
319
+ x: torch.Tensor,
320
+ x_mask: torch.Tensor,
321
+ g: Optional[torch.Tensor] = None,
322
+ reverse: bool = False,
323
+ ) -> torch.Tensor:
324
+ return super().__call__(x, x_mask, g=g, reverse=reverse)
325
+
326
+ def forward(
327
+ self,
328
+ x: torch.Tensor,
329
+ x_mask: torch.Tensor,
330
+ g: Optional[torch.Tensor] = None,
331
+ reverse: bool = False,
332
+ ) -> torch.Tensor:
333
+ if not reverse:
334
+ for flow in self.flows:
335
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
336
+ else:
337
+ for flow in reversed(self.flows):
338
+ x, _ = flow.forward(x, x_mask, g=g, reverse=reverse)
339
+ return x
340
+
341
+ def remove_weight_norm(self):
342
+ for i in range(self.n_flows):
343
+ self.flows[i * 2].remove_weight_norm()
344
+
345
+ def __prepare_scriptable__(self):
346
+ for i in range(self.n_flows):
347
+ for hook in self.flows[i * 2]._forward_pre_hooks.values():
348
+ if (
349
+ hook.__module__ == "torch.nn.utils.weight_norm"
350
+ and hook.__class__.__name__ == "WeightNorm"
351
+ ):
352
+ torch.nn.utils.remove_weight_norm(self.flows[i * 2])
353
+ return self
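
A quick sketch of the flow's invertibility, with illustrative sizes (192 channels and 100 frames are assumptions): running the block forward and then in reverse should reproduce the input up to floating-point error.

import torch
flow = ResidualCouplingBlock(
    channels=192, hidden_channels=192, kernel_size=5,
    dilation_rate=1, n_layers=3, gin_channels=0,
)
z = torch.randn(1, 192, 100)
mask = torch.ones(1, 1, 100)
z_p = flow(z, mask)                         # forward direction
z_rec = flow(z_p, mask, reverse=True)       # reverse direction undoes the flow
print(torch.allclose(z, z_rec, atol=1e-4))  # expected: True
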
rvc/layers/synthesizers.py ADDED
@@ -0,0 +1,396 @@
1
+ from typing import Optional, List, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ from .encoders import TextEncoder, PosteriorEncoder
8
+ from .generators import Generator
9
+ from .nsf import NSFGenerator
10
+ from .residuals import ResidualCouplingBlock
11
+ from .utils import (
12
+ slice_on_last_dim,
13
+ rand_slice_segments_on_last_dim,
14
+ )
15
+
16
+
17
+ class SynthesizerTrnMsNSFsid(nn.Module):
18
+ def __init__(
19
+ self,
20
+ spec_channels: int,
21
+ segment_size: int,
22
+ inter_channels: int,
23
+ hidden_channels: int,
24
+ filter_channels: int,
25
+ n_heads: int,
26
+ n_layers: int,
27
+ kernel_size: int,
28
+ p_dropout: int,
29
+ resblock: str,
30
+ resblock_kernel_sizes: List[int],
31
+ resblock_dilation_sizes: List[List[int]],
32
+ upsample_rates: List[int],
33
+ upsample_initial_channel: int,
34
+ upsample_kernel_sizes: List[int],
35
+ spk_embed_dim: int,
36
+ gin_channels: int,
37
+ sr: Optional[Union[str, int]],
38
+ encoder_dim: int,
39
+ use_f0: bool,
40
+ ):
41
+ super().__init__()
42
+ if isinstance(sr, str):
43
+ sr = {
44
+ "32k": 32000,
45
+ "40k": 40000,
46
+ "48k": 48000,
47
+ }[sr]
48
+ self.spec_channels = spec_channels
49
+ self.inter_channels = inter_channels
50
+ self.hidden_channels = hidden_channels
51
+ self.filter_channels = filter_channels
52
+ self.n_heads = n_heads
53
+ self.n_layers = n_layers
54
+ self.kernel_size = kernel_size
55
+ self.p_dropout = float(p_dropout)
56
+ self.resblock = resblock
57
+ self.resblock_kernel_sizes = resblock_kernel_sizes
58
+ self.resblock_dilation_sizes = resblock_dilation_sizes
59
+ self.upsample_rates = upsample_rates
60
+ self.upsample_initial_channel = upsample_initial_channel
61
+ self.upsample_kernel_sizes = upsample_kernel_sizes
62
+ self.segment_size = segment_size
63
+ self.gin_channels = gin_channels
64
+ self.spk_embed_dim = spk_embed_dim
65
+
66
+ self.enc_p = TextEncoder(
67
+ encoder_dim,
68
+ inter_channels,
69
+ hidden_channels,
70
+ filter_channels,
71
+ n_heads,
72
+ n_layers,
73
+ kernel_size,
74
+ float(p_dropout),
75
+ f0=use_f0,
76
+ )
77
+ if use_f0:
78
+ self.dec = NSFGenerator(
79
+ inter_channels,
80
+ resblock,
81
+ resblock_kernel_sizes,
82
+ resblock_dilation_sizes,
83
+ upsample_rates,
84
+ upsample_initial_channel,
85
+ upsample_kernel_sizes,
86
+ gin_channels=gin_channels,
87
+ sr=sr,
88
+ )
89
+ else:
90
+ self.dec = Generator(
91
+ inter_channels,
92
+ resblock,
93
+ resblock_kernel_sizes,
94
+ resblock_dilation_sizes,
95
+ upsample_rates,
96
+ upsample_initial_channel,
97
+ upsample_kernel_sizes,
98
+ gin_channels=gin_channels,
99
+ )
100
+ self.enc_q = PosteriorEncoder(
101
+ spec_channels,
102
+ inter_channels,
103
+ hidden_channels,
104
+ 5,
105
+ 1,
106
+ 16,
107
+ gin_channels=gin_channels,
108
+ )
109
+ self.flow = ResidualCouplingBlock(
110
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
111
+ )
112
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
113
+
114
+ def remove_weight_norm(self):
115
+ self.dec.remove_weight_norm()
116
+ self.flow.remove_weight_norm()
117
+ if hasattr(self, "enc_q"):
118
+ self.enc_q.remove_weight_norm()
119
+
120
+ def __prepare_scriptable__(self):
121
+ for hook in self.dec._forward_pre_hooks.values():
122
+ # The hook we want to remove is an instance of WeightNorm class, so
123
+ # normally we would do `if isinstance(...)` but this class is not accessible
124
+ # because of shadowing, so we check the module name directly.
125
+ # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
126
+ if (
127
+ hook.__module__ == "torch.nn.utils.weight_norm"
128
+ and hook.__class__.__name__ == "WeightNorm"
129
+ ):
130
+ torch.nn.utils.remove_weight_norm(self.dec)
131
+ for hook in self.flow._forward_pre_hooks.values():
132
+ if (
133
+ hook.__module__ == "torch.nn.utils.weight_norm"
134
+ and hook.__class__.__name__ == "WeightNorm"
135
+ ):
136
+ torch.nn.utils.remove_weight_norm(self.flow)
137
+ if hasattr(self, "enc_q"):
138
+ for hook in self.enc_q._forward_pre_hooks.values():
139
+ if (
140
+ hook.__module__ == "torch.nn.utils.weight_norm"
141
+ and hook.__class__.__name__ == "WeightNorm"
142
+ ):
143
+ torch.nn.utils.remove_weight_norm(self.enc_q)
144
+ return self
145
+
146
+ @torch.jit.ignore
147
+ def forward(
148
+ self,
149
+ phone: torch.Tensor,
150
+ phone_lengths: torch.Tensor,
151
+ y: torch.Tensor,
152
+ y_lengths: torch.Tensor,
153
+ ds: Optional[torch.Tensor] = None,
154
+ pitch: Optional[torch.Tensor] = None,
155
+ pitchf: Optional[torch.Tensor] = None,
156
+ ): # ds is the speaker id, shape [bs, 1]
157
+ # print(1,pitch.shape)#[bs,t]
158
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the trailing 1 is the time axis, broadcast later
159
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
160
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
161
+ z_p = self.flow(z, y_mask, g=g)
162
+ z_slice, ids_slice = rand_slice_segments_on_last_dim(
163
+ z, y_lengths, self.segment_size
164
+ )
165
+ if pitchf is not None:
166
+ pitchf = slice_on_last_dim(pitchf, ids_slice, self.segment_size)
167
+ o = self.dec(z_slice, pitchf, g=g)
168
+ else:
169
+ o = self.dec(z_slice, g=g)
170
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
171
+
172
+ @torch.jit.export
173
+ def infer(
174
+ self,
175
+ phone: torch.Tensor,
176
+ phone_lengths: torch.Tensor,
177
+ sid: torch.Tensor,
178
+ pitch: Optional[torch.Tensor] = None,
179
+ pitchf: Optional[torch.Tensor] = None, # nsff0
180
+ skip_head: Optional[int] = None,
181
+ return_length: Optional[int] = None,
182
+ return_length2: Optional[int] = None,
183
+ ):
184
+ g = self.emb_g(sid).unsqueeze(-1)
185
+ if skip_head is not None and return_length is not None:
186
+ head = int(skip_head)
187
+ length = int(return_length)
188
+ flow_head = head - 24
189
+ if flow_head < 0:
190
+ flow_head = 0
191
+ dec_head = head - flow_head
192
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head)
193
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
194
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
195
+ z = z[:, :, dec_head : dec_head + length]
196
+ x_mask = x_mask[:, :, dec_head : dec_head + length]
197
+ if pitchf is not None:
198
+ pitchf = pitchf[:, head : head + length]
199
+ else:
200
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
201
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
202
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
203
+ del z_p, m_p, logs_p
204
+ if pitchf is not None:
205
+ o = self.dec(
206
+ z * x_mask,
207
+ pitchf,
208
+ g=g,
209
+ n_res=return_length2,
210
+ )
211
+ else:
212
+ o = self.dec(z * x_mask, g=g, n_res=return_length2)
213
+ del x_mask, z
214
+ return o # , x_mask, (z, z_p, m_p, logs_p)
215
+
216
+
217
+ class SynthesizerTrnMs256NSFsid(SynthesizerTrnMsNSFsid):
218
+ def __init__(
219
+ self,
220
+ spec_channels: int,
221
+ segment_size: int,
222
+ inter_channels: int,
223
+ hidden_channels: int,
224
+ filter_channels: int,
225
+ n_heads: int,
226
+ n_layers: int,
227
+ kernel_size: int,
228
+ p_dropout: int,
229
+ resblock: str,
230
+ resblock_kernel_sizes: List[int],
231
+ resblock_dilation_sizes: List[List[int]],
232
+ upsample_rates: List[int],
233
+ upsample_initial_channel: int,
234
+ upsample_kernel_sizes: List[int],
235
+ spk_embed_dim: int,
236
+ gin_channels: int,
237
+ sr: Union[str, int],
238
+ ):
239
+ super().__init__(
240
+ spec_channels,
241
+ segment_size,
242
+ inter_channels,
243
+ hidden_channels,
244
+ filter_channels,
245
+ n_heads,
246
+ n_layers,
247
+ kernel_size,
248
+ p_dropout,
249
+ resblock,
250
+ resblock_kernel_sizes,
251
+ resblock_dilation_sizes,
252
+ upsample_rates,
253
+ upsample_initial_channel,
254
+ upsample_kernel_sizes,
255
+ spk_embed_dim,
256
+ gin_channels,
257
+ sr,
258
+ 256,
259
+ True,
260
+ )
261
+
262
+
263
+ class SynthesizerTrnMs768NSFsid(SynthesizerTrnMsNSFsid):
264
+ def __init__(
265
+ self,
266
+ spec_channels: int,
267
+ segment_size: int,
268
+ inter_channels: int,
269
+ hidden_channels: int,
270
+ filter_channels: int,
271
+ n_heads: int,
272
+ n_layers: int,
273
+ kernel_size: int,
274
+ p_dropout: int,
275
+ resblock: str,
276
+ resblock_kernel_sizes: List[int],
277
+ resblock_dilation_sizes: List[List[int]],
278
+ upsample_rates: List[int],
279
+ upsample_initial_channel: int,
280
+ upsample_kernel_sizes: List[int],
281
+ spk_embed_dim: int,
282
+ gin_channels: int,
283
+ sr: Union[str, int],
284
+ ):
285
+ super().__init__(
286
+ spec_channels,
287
+ segment_size,
288
+ inter_channels,
289
+ hidden_channels,
290
+ filter_channels,
291
+ n_heads,
292
+ n_layers,
293
+ kernel_size,
294
+ p_dropout,
295
+ resblock,
296
+ resblock_kernel_sizes,
297
+ resblock_dilation_sizes,
298
+ upsample_rates,
299
+ upsample_initial_channel,
300
+ upsample_kernel_sizes,
301
+ spk_embed_dim,
302
+ gin_channels,
303
+ sr,
304
+ 768,
305
+ True,
306
+ )
307
+
308
+
309
+ class SynthesizerTrnMs256NSFsid_nono(SynthesizerTrnMsNSFsid):
310
+ def __init__(
311
+ self,
312
+ spec_channels: int,
313
+ segment_size: int,
314
+ inter_channels: int,
315
+ hidden_channels: int,
316
+ filter_channels: int,
317
+ n_heads: int,
318
+ n_layers: int,
319
+ kernel_size: int,
320
+ p_dropout: int,
321
+ resblock: str,
322
+ resblock_kernel_sizes: List[int],
323
+ resblock_dilation_sizes: List[List[int]],
324
+ upsample_rates: List[int],
325
+ upsample_initial_channel: int,
326
+ upsample_kernel_sizes: List[int],
327
+ spk_embed_dim: int,
328
+ gin_channels: int,
329
+ sr=None,
330
+ ):
331
+ super().__init__(
332
+ spec_channels,
333
+ segment_size,
334
+ inter_channels,
335
+ hidden_channels,
336
+ filter_channels,
337
+ n_heads,
338
+ n_layers,
339
+ kernel_size,
340
+ p_dropout,
341
+ resblock,
342
+ resblock_kernel_sizes,
343
+ resblock_dilation_sizes,
344
+ upsample_rates,
345
+ upsample_initial_channel,
346
+ upsample_kernel_sizes,
347
+ spk_embed_dim,
348
+ gin_channels,
349
+ 256,
350
+ False,
351
+ )
352
+
353
+
354
+ class SynthesizerTrnMs768NSFsid_nono(SynthesizerTrnMsNSFsid):
355
+ def __init__(
356
+ self,
357
+ spec_channels: int,
358
+ segment_size: int,
359
+ inter_channels: int,
360
+ hidden_channels: int,
361
+ filter_channels: int,
362
+ n_heads: int,
363
+ n_layers: int,
364
+ kernel_size: int,
365
+ p_dropout: int,
366
+ resblock: str,
367
+ resblock_kernel_sizes: List[int],
368
+ resblock_dilation_sizes: List[List[int]],
369
+ upsample_rates: List[int],
370
+ upsample_initial_channel: int,
371
+ upsample_kernel_sizes: List[int],
372
+ spk_embed_dim: int,
373
+ gin_channels: int,
374
+ sr=None,
375
+ ):
376
+ super().__init__(
377
+ spec_channels,
378
+ segment_size,
379
+ inter_channels,
380
+ hidden_channels,
381
+ filter_channels,
382
+ n_heads,
383
+ n_layers,
384
+ kernel_size,
385
+ p_dropout,
386
+ resblock,
387
+ resblock_kernel_sizes,
388
+ resblock_dilation_sizes,
389
+ upsample_rates,
390
+ upsample_initial_channel,
391
+ upsample_kernel_sizes,
392
+ spk_embed_dim,
393
+ gin_channels,
394
+ 768,
395
+ False,
396
+ )
rvc/layers/transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs: torch.Tensor,
14
+ unnormalized_widths: torch.Tensor,
15
+ unnormalized_heights: torch.Tensor,
16
+ unnormalized_derivatives: torch.Tensor,
17
+ inverse: bool = False,
18
+ tails: Optional[str] = None,
19
+ tail_bound: float = 1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs: torch.Tensor,
52
+ unnormalized_widths: torch.Tensor,
53
+ unnormalized_heights: torch.Tensor,
54
+ unnormalized_derivatives: torch.Tensor,
55
+ inverse: bool = False,
56
+ tails: str = "linear",
57
+ tail_bound: float = 1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs: torch.Tensor,
102
+ unnormalized_widths: torch.Tensor,
103
+ unnormalized_heights: torch.Tensor,
104
+ unnormalized_derivatives: torch.Tensor,
105
+ inverse: bool = False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
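
A small sketch of the spline transform's round trip, with assumed tensor shapes (2 sequences of 5 values, 10 bins): the "linear"-tailed variant expects num_bins widths and heights but only num_bins - 1 interior derivatives, and the inverse pass should recover the inputs.

import torch
num_bins = 10
x = torch.rand(2, 5) * 2 - 1              # inputs inside the tail bound [-1, 1]
w = torch.randn(2, 5, num_bins)           # unnormalized bin widths
h = torch.randn(2, 5, num_bins)           # unnormalized bin heights
d = torch.randn(2, 5, num_bins - 1)       # unnormalized interior derivatives
y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
)
x_rec, _ = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
)
print(torch.allclose(x, x_rec, atol=1e-5))  # expected: True, up to rounding
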
rvc/layers/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ from typing import List, Optional, Tuple, Iterator
2
+
3
+ import torch
4
+
5
+
6
+ def call_weight_data_normal_if_Conv(m: torch.nn.Module):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ mean = 0.0
10
+ std = 0.01
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size: int, dilation=1) -> int:
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def slice_on_last_dim(
19
+ x: torch.Tensor,
20
+ start_indices: List[int],
21
+ segment_size=4,
22
+ ) -> torch.Tensor:
23
+ new_shape = [*x.shape]
24
+ new_shape[-1] = segment_size
25
+ ret = torch.empty(new_shape, device=x.device)
26
+ for i in range(x.size(0)):
27
+ idx_str = start_indices[i]
28
+ idx_end = idx_str + segment_size
29
+ ret[i, ..., :] = x[i, ..., idx_str:idx_end]
30
+ return ret
31
+
32
+
33
+ def rand_slice_segments_on_last_dim(
34
+ x: torch.Tensor,
35
+ x_lengths: Optional[int] = None,
36
+ segment_size=4,
37
+ ) -> Tuple[torch.Tensor, List[int]]:
38
+ b, _, t = x.size()
39
+ if x_lengths is None:
40
+ x_lengths = t
41
+ ids_str_max = x_lengths - segment_size + 1
42
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
43
+ ret = slice_on_last_dim(x, ids_str, segment_size)
44
+ return ret, ids_str
45
+
46
+
47
+ @torch.jit.script
48
+ def activate_add_tanh_sigmoid_multiply(
49
+ input_a: torch.Tensor, input_b: torch.Tensor, n_channels: int
50
+ ) -> torch.Tensor:
51
+ in_act = input_a + input_b
52
+ t_act = torch.tanh(in_act[:, :n_channels, :])
53
+ s_act = torch.sigmoid(in_act[:, n_channels:, :])
54
+ acts = t_act * s_act
55
+ return acts
56
+
57
+
58
+ def sequence_mask(
59
+ length: torch.Tensor,
60
+ max_length: Optional[int] = None,
61
+ ) -> torch.BoolTensor:
62
+ if max_length is None:
63
+ max_length = int(length.max())
64
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
65
+ return x.unsqueeze(0) < length.unsqueeze(1)
66
+
67
+
68
+ def total_grad_norm(
69
+ parameters: Iterator[torch.nn.Parameter],
70
+ norm_type: float = 2.0,
71
+ ) -> float:
72
+ norm_type = float(norm_type)
73
+ total_norm = 0.0
74
+
75
+ for p in parameters:
76
+ if p.grad is None:
77
+ continue
78
+ param_norm = p.grad.data.norm(norm_type)
79
+ total_norm += float(param_norm.item()) ** norm_type
80
+ total_norm = total_norm ** (1.0 / norm_type)
81
+
82
+ return total_norm
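
A short sketch of the slicing and masking helpers with assumed shapes (batch of 4, 192 channels, 400 frames): rand_slice_segments_on_last_dim picks a random window per example whose start stays within that example's valid length.

import torch
x = torch.randn(4, 192, 400)                  # (batch, channels, frames)
lengths = torch.tensor([400, 350, 300, 250])  # valid frames per example
segments, starts = rand_slice_segments_on_last_dim(x, lengths, segment_size=32)
print(segments.shape)                         # torch.Size([4, 192, 32])
mask = sequence_mask(lengths)                 # (4, 400) boolean padding mask
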
rvc/onnx/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .infer import RVC
2
+ from .exporter import export_onnx
rvc/onnx/exporter.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+
3
+ from .synthesizer import SynthesizerTrnMsNSFsid
4
+
5
+
6
+ def export_onnx(from_cpkt_pth: str, to_onnx_pth: str) -> str:
7
+ cpt = torch.load(from_cpkt_pth, map_location="cpu")
8
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
9
+ vec_channels = 256 if cpt.get("version", "v1") == "v1" else 768
10
+
11
+ test_phone = torch.rand(1, 200, vec_channels) # hidden unit
12
+ test_phone_lengths = torch.tensor([200]).long() # hidden unit length (seemingly unused)
13
+ test_pitch = torch.randint(size=(1, 200), low=5, high=255) # fundamental frequency (in Hz)
14
+ test_pitchf = torch.rand(1, 200) # NSF fundamental frequency
15
+ test_ds = torch.LongTensor([0]) # speaker ID
16
+ test_rnd = torch.rand(1, 192, 200) # noise (adds a random factor)
17
+
18
+ device = "cpu" # 导出时设备(不影响使用模型)
19
+
20
+ net_g = SynthesizerTrnMsNSFsid(
21
+ *cpt["config"], encoder_dim=vec_channels
22
+ ) # export in fp32 (fp16 support in C++ would require manually rearranging memory, so fp16 is not used for now)
23
+ net_g.load_state_dict(cpt["weight"], strict=False)
24
+ input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
25
+ output_names = [
26
+ "audio",
27
+ ]
28
+ # net_g.construct_spkmixmap() # export with a multi-speaker mixing track
29
+ torch.onnx.export(
30
+ net_g,
31
+ (
32
+ test_phone.to(device),
33
+ test_phone_lengths.to(device),
34
+ test_pitch.to(device),
35
+ test_pitchf.to(device),
36
+ test_ds.to(device),
37
+ test_rnd.to(device),
38
+ ),
39
+ to_onnx_pth,
40
+ dynamic_axes={
41
+ "phone": [1],
42
+ "pitch": [1],
43
+ "pitchf": [1],
44
+ "rnd": [2],
45
+ },
46
+ do_constant_folding=False,
47
+ opset_version=17,
48
+ verbose=False,
49
+ input_names=input_names,
50
+ output_names=output_names,
51
+ )
52
+ return "Finished"
rvc/onnx/infer.py ADDED
@@ -0,0 +1,146 @@
1
+ import typing
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import onnxruntime
7
+
8
+ from rvc.f0 import (
9
+ PM,
10
+ Harvest,
11
+ Dio,
12
+ F0Predictor,
13
+ )
14
+
15
+
16
+ class Model:
17
+ def __init__(
18
+ self,
19
+ path: typing.Union[str, bytes, os.PathLike],
20
+ device: typing.Literal["cpu", "cuda", "dml"] = "cpu",
21
+ ):
22
+ if device == "cpu":
23
+ providers = ["CPUExecutionProvider"]
24
+ elif device == "cuda":
25
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
26
+ elif device == "dml":
27
+ providers = ["DmlExecutionProvider"]
28
+ else:
29
+ raise RuntimeError("Unsportted Device")
30
+ self.model = onnxruntime.InferenceSession(path, providers=providers)
31
+
32
+
33
+ class ContentVec(Model):
34
+ def __init__(
35
+ self,
36
+ vec_path: typing.Union[str, bytes, os.PathLike],
37
+ device: typing.Literal["cpu", "cuda", "dml"] = "cpu",
38
+ ):
39
+ super().__init__(vec_path, device)
40
+
41
+ def __call__(self, wav: np.ndarray[typing.Any, np.dtype]):
42
+ return self.forward(wav)
43
+
44
+ def forward(self, wav: np.ndarray[typing.Any, np.dtype]):
45
+ if wav.ndim == 2: # double channels
46
+ wav = wav.mean(-1)
47
+ assert wav.ndim == 1, wav.ndim
48
+ wav = np.expand_dims(np.expand_dims(wav, 0), 0)
49
+ onnx_input = {self.model.get_inputs()[0].name: wav}
50
+ logits = self.model.run(None, onnx_input)[0]
51
+ return logits.transpose(0, 2, 1)
52
+
53
+
54
+ predictors: typing.Dict[str, typing.Type[F0Predictor]] = {
55
+ "pm": PM,
56
+ "harvest": Harvest,
57
+ "dio": Dio,
58
+ }
59
+
60
+
61
+ def get_f0_predictor(
62
+ f0_method: str, hop_length: int, sampling_rate: int
63
+ ) -> F0Predictor:
64
+ return predictors[f0_method](hop_length=hop_length, sampling_rate=sampling_rate)
65
+
66
+
67
+ class RVC(Model):
68
+ def __init__(
69
+ self,
70
+ model_path: typing.Union[str, bytes, os.PathLike],
71
+ hop_len=512,
72
+ vec_path: typing.Union[str, bytes, os.PathLike] = "vec-768-layer-12.onnx",
73
+ device: typing.Literal["cpu", "cuda", "dml"] = "cpu",
74
+ ):
75
+ super().__init__(model_path, device)
76
+ self.vec_model = ContentVec(vec_path, device)
77
+ self.hop_len = hop_len
78
+
79
+ def infer(
80
+ self,
81
+ wav: np.ndarray[typing.Any, np.dtype],
82
+ wav_sr: int,
83
+ model_sr: int = 40000,
84
+ sid: int = 0,
85
+ f0_method="dio",
86
+ f0_up_key=0,
87
+ ) -> np.ndarray[typing.Any, np.dtype[np.int16]]:
88
+ f0_min = 50
89
+ f0_max = 1100
90
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
91
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
92
+ f0_predictor = get_f0_predictor(
93
+ f0_method,
94
+ self.hop_len,
95
+ model_sr,
96
+ )
97
+ org_length = len(wav)
98
+ if org_length / wav_sr > 50.0:
99
+ raise RuntimeError("wav max length exceeded")
100
+
101
+ hubert = self.vec_model(librosa.resample(wav, orig_sr=wav_sr, target_sr=16000))
102
+ hubert = np.repeat(hubert, 2, axis=2).transpose(0, 2, 1).astype(np.float32)
103
+ hubert_length = hubert.shape[1]
104
+
105
+ pitchf = f0_predictor.compute_f0(wav, hubert_length)
106
+ pitchf = pitchf * 2 ** (f0_up_key / 12)
107
+ pitch = pitchf.copy()
108
+ f0_mel = 1127 * np.log(1 + pitch / 700)
109
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
110
+ f0_mel_max - f0_mel_min
111
+ ) + 1
112
+ f0_mel[f0_mel <= 1] = 1
113
+ f0_mel[f0_mel > 255] = 255
114
+ pitch = np.rint(f0_mel).astype(np.int64)
115
+
116
+ pitchf = pitchf.reshape(1, len(pitchf)).astype(np.float32)
117
+ pitch = pitch.reshape(1, len(pitch))
118
+ ds = np.array([sid]).astype(np.int64)
119
+
120
+ rnd = np.random.randn(1, 192, hubert_length).astype(np.float32)
121
+ hubert_length = np.array([hubert_length]).astype(np.int64)
122
+
123
+ out_wav = self.forward(hubert, hubert_length, pitch, pitchf, ds, rnd).squeeze()
124
+
125
+ out_wav = np.pad(out_wav, (0, 2 * self.hop_len), "constant")
126
+
127
+ return out_wav[0:org_length]
128
+
129
+ def forward(
130
+ self,
131
+ hubert: np.ndarray[typing.Any, np.dtype[np.float32]],
132
+ hubert_length: int,
133
+ pitch: np.ndarray[typing.Any, np.dtype[np.int64]],
134
+ pitchf: np.ndarray[typing.Any, np.dtype[np.float32]],
135
+ ds: np.ndarray[typing.Any, np.dtype[np.int64]],
136
+ rnd: np.ndarray[typing.Any, np.dtype[np.float32]],
137
+ ) -> np.ndarray[typing.Any, np.dtype[np.int16]]:
138
+ onnx_input = {
139
+ self.model.get_inputs()[0].name: hubert,
140
+ self.model.get_inputs()[1].name: hubert_length,
141
+ self.model.get_inputs()[2].name: pitch,
142
+ self.model.get_inputs()[3].name: pitchf,
143
+ self.model.get_inputs()[4].name: ds,
144
+ self.model.get_inputs()[5].name: rnd,
145
+ }
146
+ return (self.model.run(None, onnx_input)[0] * 32767).astype(np.int16)
rvc/onnx/synthesizer.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+
5
+ from rvc.layers.synthesizers import SynthesizerTrnMsNSFsid as SynthesizerBase
6
+
7
+
8
+ class SynthesizerTrnMsNSFsid(SynthesizerBase):
9
+ def __init__(
10
+ self,
11
+ spec_channels: int,
12
+ segment_size: int,
13
+ inter_channels: int,
14
+ hidden_channels: int,
15
+ filter_channels: int,
16
+ n_heads: int,
17
+ n_layers: int,
18
+ kernel_size: int,
19
+ p_dropout: int,
20
+ resblock: str,
21
+ resblock_kernel_sizes: List[int],
22
+ resblock_dilation_sizes: List[List[int]],
23
+ upsample_rates: List[int],
24
+ upsample_initial_channel: int,
25
+ upsample_kernel_sizes: List[int],
26
+ spk_embed_dim: int,
27
+ gin_channels: int,
28
+ sr: Optional[Union[str, int]],
29
+ encoder_dim: int,
30
+ ):
31
+ super().__init__(
32
+ spec_channels,
33
+ segment_size,
34
+ inter_channels,
35
+ hidden_channels,
36
+ filter_channels,
37
+ n_heads,
38
+ n_layers,
39
+ kernel_size,
40
+ p_dropout,
41
+ resblock,
42
+ resblock_kernel_sizes,
43
+ resblock_dilation_sizes,
44
+ upsample_rates,
45
+ upsample_initial_channel,
46
+ upsample_kernel_sizes,
47
+ spk_embed_dim,
48
+ gin_channels,
49
+ sr,
50
+ encoder_dim,
51
+ True,
52
+ )
53
+ self.speaker_map = None
54
+
55
+ def remove_weight_norm(self):
56
+ self.dec.remove_weight_norm()
57
+ self.flow.remove_weight_norm()
58
+ self.enc_q.remove_weight_norm()
59
+
60
+ def construct_spkmixmap(self):
61
+ self.speaker_map = torch.zeros((self.n_speaker, 1, 1, self.gin_channels))
62
+ for i in range(self.n_speaker):
63
+ self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
64
+ self.speaker_map = self.speaker_map.unsqueeze(0)
65
+
66
+ def forward(self, phone, phone_lengths, pitch, nsff0, g, rnd, max_len=None):
67
+ if self.speaker_map is not None: # [N, S] * [S, B, 1, H]
68
+ g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
69
+ g = g * self.speaker_map # [N, S, B, 1, H]
70
+ g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
71
+ g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
72
+ else:
73
+ g = g.unsqueeze(0)
74
+ g = self.emb_g(g).transpose(1, 2)
75
+
76
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
77
+ z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
78
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
79
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
80
+ return o
rvc/synthesizer.py ADDED
@@ -0,0 +1,65 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+ from .layers.synthesizers import SynthesizerTrnMsNSFsid
6
+ from .jit import load_inputs, export_jit_model, save_pickle
7
+
8
+
9
+ def get_synthesizer(cpt: OrderedDict, device=torch.device("cpu")):
10
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
11
+ if_f0 = cpt.get("f0", 1)
12
+ version = cpt.get("version", "v1")
13
+ if version == "v1":
14
+ encoder_dim = 256
15
+ elif version == "v2":
16
+ encoder_dim = 768
17
+ net_g = SynthesizerTrnMsNSFsid(
18
+ *cpt["config"],
19
+ encoder_dim=encoder_dim,
20
+ use_f0=if_f0 == 1,
21
+ )
22
+ del net_g.enc_q
23
+ net_g.load_state_dict(cpt["weight"], strict=False)
24
+ net_g = net_g.float()
25
+ net_g.eval().to(device)
26
+ net_g.remove_weight_norm()
27
+ return net_g, cpt
28
+
29
+
30
+ def load_synthesizer(
31
+ pth_path: torch.serialization.FILE_LIKE, device=torch.device("cpu")
32
+ ):
33
+ return get_synthesizer(
34
+ torch.load(pth_path, map_location=torch.device("cpu")),
35
+ device,
36
+ )
37
+
38
+
39
+ def synthesizer_jit_export(
40
+ model_path: str,
41
+ mode: str = "script",
42
+ inputs_path: str = None,
43
+ save_path: str = None,
44
+ device=torch.device("cpu"),
45
+ is_half=False,
46
+ ):
47
+ if not save_path:
48
+ save_path = model_path[:-4] if model_path.endswith(".pth") else model_path
49
+ save_path += ".half.jit" if is_half else ".jit"
50
+ if "cuda" in str(device) and ":" not in str(device):
51
+ device = torch.device("cuda:0")
52
+ from rvc.synthesizer import load_synthesizer
53
+
54
+ model, cpt = load_synthesizer(model_path, device)
55
+ assert isinstance(cpt, dict)
56
+ model.forward = model.infer
57
+ inputs = None
58
+ if mode == "trace":
59
+ inputs = load_inputs(inputs_path, device, is_half)
60
+ ckpt = export_jit_model(model, mode, inputs, device, is_half)
61
+ cpt.pop("weight")
62
+ cpt["model"] = ckpt["model"]
63
+ cpt["device"] = device
64
+ save_pickle(cpt, save_path)
65
+ return cpt
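
A loading sketch; the checkpoint path is a placeholder, while the "config", "weight", "version" and "f0" fields are the ones read by get_synthesizer above.

import torch
from rvc.synthesizer import load_synthesizer

net_g, cpt = load_synthesizer("assets/weights/my_model.pth", torch.device("cpu"))
print(cpt.get("version", "v1"), bool(cpt.get("f0", 1)))  # e.g. v2 True
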