Spaces:
Runtime error
Runtime error
SayaSS
commited on
Commit
•
c17721a
1
Parent(s):
c2dde5f
update
Browse files- Eng_docs.md +0 -109
- app.py +7 -12
- data_utils.py +0 -142
- flask_api.py +0 -56
- inference/__pycache__/__init__.cpython-38.pyc +0 -0
- inference/__pycache__/infer_tool.cpython-38.pyc +0 -0
- inference/__pycache__/slicer.cpython-38.pyc +0 -0
- inference/infer_tool.py +62 -22
- modules/__pycache__/__init__.cpython-38.pyc +0 -0
- modules/__pycache__/attentions.cpython-38.pyc +0 -0
- modules/__pycache__/commons.cpython-38.pyc +0 -0
- modules/__pycache__/modules.cpython-38.pyc +0 -0
- preprocess_flist_config.py +0 -67
- preprocess_hubert_f0.py +0 -62
- resample.py +0 -48
- spec_gen.py +0 -22
- train.py +0 -297
- utils.py +3 -9
Eng_docs.md
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
# SoftVC VITS Singing Voice Conversion
|
2 |
-
|
3 |
-
## Updates
|
4 |
-
> According to incomplete statistics, it seems that training with multiple speakers may lead to **worsened leaking of voice timbre**. It is not recommended to train models with more than 5 speakers. The current suggestion is to try to train models with only a single speaker if you want to achieve a voice timbre that is more similar to the target.
|
5 |
-
> Fixed the issue with unwanted staccato, improving audio quality by a decent amount.\
|
6 |
-
> The 2.0 version has been moved to the 2.0 branch.\
|
7 |
-
> Version 3.0 uses the code structure of FreeVC, which isn't compatible with older versions.\
|
8 |
-
> Compared to [DiffSVC](https://github.com/prophesier/diff-svc) , diffsvc performs much better when the training data is of extremely high quality, but this repository may perform better on datasets with lower quality. Additionally, this repository is much faster in terms of inference speed compared to diffsvc.
|
9 |
-
|
10 |
-
## Model Overview
|
11 |
-
A singing voice coversion (SVC) model, using the SoftVC encoder to extract features from the input audio, sent into VITS along with the F0 to replace the original input to acheive a voice conversion effect. Additionally, changing the vocoder to [NSF HiFiGAN](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) to fix the issue with unwanted staccato.
|
12 |
-
|
13 |
-
## Notice
|
14 |
-
+ The current branch is the 32kHz version, which requires less vram during inferencing, as well as faster inferencing speeds, and datasets for said branch take up less disk space. Thus the 32 kHz branch is recommended for use.
|
15 |
-
+ If you want to train 48 kHz variant models, switch to the [main branch](https://github.com/innnky/so-vits-svc/tree/main).
|
16 |
-
|
17 |
-
|
18 |
-
## Required models
|
19 |
-
+ soft vc hubert:[hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
|
20 |
-
+ Place under `hubert`.
|
21 |
-
+ Pretrained models [G_0.pth](https://huggingface.co/innnky/sovits_pretrained/resolve/main/G_0.pth) and [D_0.pth](https://huggingface.co/innnky/sovits_pretrained/resolve/main/D_0.pth)
|
22 |
-
+ Place under `logs/32k`.
|
23 |
-
+ Pretrained models are required, because from experiments, training from scratch can be rather unpredictable to say the least, and training with a pretrained model can greatly improve training speeds.
|
24 |
-
+ The pretrained model includes云灏, 即霜, 辉宇·星AI, 派蒙, and 绫地宁宁, covering the common ranges of both male and female voices, and so it can be seen as a rather universal pretrained model.
|
25 |
-
+ The pretrained model exludes the `optimizer speaker_embedding` section, rendering it only usable for pretraining and incapable of inferencing with.
|
26 |
-
```shell
|
27 |
-
# For simple downloading.
|
28 |
-
# hubert
|
29 |
-
wget -P hubert/ https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt
|
30 |
-
# G&D pretrained models
|
31 |
-
wget -P logs/32k/ https://huggingface.co/innnky/sovits_pretrained/resolve/main/G_0.pth
|
32 |
-
wget -P logs/32k/ https://huggingface.co/innnky/sovits_pretrained/resolve/main/D_0.pth
|
33 |
-
|
34 |
-
```
|
35 |
-
|
36 |
-
## Colab notebook script for dataset creation and training.
|
37 |
-
[colab training notebook](https://colab.research.google.com/drive/1rCUOOVG7-XQlVZuWRAj5IpGrMM8t07pE?usp=sharing)
|
38 |
-
|
39 |
-
## Dataset preparation
|
40 |
-
All that is required is that the data be put under the `dataset_raw` folder in the structure format provided below.
|
41 |
-
```shell
|
42 |
-
dataset_raw
|
43 |
-
├───speaker0
|
44 |
-
│ ├───xxx1-xxx1.wav
|
45 |
-
│ ├───...
|
46 |
-
│ └───Lxx-0xx8.wav
|
47 |
-
└───speaker1
|
48 |
-
├───xx2-0xxx2.wav
|
49 |
-
├───...
|
50 |
-
└───xxx7-xxx007.wav
|
51 |
-
```
|
52 |
-
|
53 |
-
## Data pre-processing.
|
54 |
-
1. Resample to 32khz
|
55 |
-
|
56 |
-
```shell
|
57 |
-
python resample.py
|
58 |
-
```
|
59 |
-
2. Automatically sort out training set, validation set, test set, and automatically generate configuration files.
|
60 |
-
```shell
|
61 |
-
python preprocess_flist_config.py
|
62 |
-
# Notice.
|
63 |
-
# The n_speakers value in the config will be set automatically according to the amount of speakers in the dataset.
|
64 |
-
# To reserve space for additionally added speakers in the dataset, the n_speakers value will be be set to twice the actual amount.
|
65 |
-
# If you want even more space for adding more data, you can edit the n_speakers value in the config after runing this step.
|
66 |
-
# This can not be changed after training starts.
|
67 |
-
```
|
68 |
-
3. Generate hubert and F0 features/
|
69 |
-
```shell
|
70 |
-
python preprocess_hubert_f0.py
|
71 |
-
```
|
72 |
-
After running the step above, the `dataset` folder will contain all the pre-processed data, you can delete the `dataset_raw` folder after that.
|
73 |
-
|
74 |
-
## Training.
|
75 |
-
```shell
|
76 |
-
python train.py -c configs/config.json -m 32k
|
77 |
-
```
|
78 |
-
|
79 |
-
## Inferencing.
|
80 |
-
|
81 |
-
Use [inference_main.py](inference_main.py)
|
82 |
-
+ Edit `model_path` to your newest checkpoint.
|
83 |
-
+ Place the input audio under the `raw` folder.
|
84 |
-
+ Change `clean_names` to the output file name.
|
85 |
-
+ Use `trans` to edit the pitch shifting amount (semitones).
|
86 |
-
+ Change `spk_list` to the speaker name.
|
87 |
-
|
88 |
-
## Onnx Exporting.
|
89 |
-
### **When exporting Onnx, please make sure you re-clone the whole repository!!!**
|
90 |
-
Use [onnx_export.py](onnx_export.py)
|
91 |
-
+ Create a new folder called `checkpoints`.
|
92 |
-
+ Create a project folder in `checkpoints` folder with the desired name for your project, let's use `myproject` as example. Folder structure looks like `./checkpoints/myproject`.
|
93 |
-
+ Rename your model to `model.pth`, rename your config file to `config.json` then move them into `myproject` folder.
|
94 |
-
+ Modify [onnx_export.py](onnx_export.py) where `path = "NyaruTaffy"`, change `NyaruTaffy` to your project name, here it will be `path = "myproject"`.
|
95 |
-
+ Run [onnx_export.py](onnx_export.py)
|
96 |
-
+ Once it finished, a `model.onnx` will be generated in `myproject` folder, that's the model you just exported.
|
97 |
-
+ Notice: if you want to export a 48K model, please follow the instruction below or use `model_onnx_48k.py` directly.
|
98 |
-
+ Open [model_onnx.py](model_onnx.py) and change `hps={"sampling_rate": 32000...}` to `hps={"sampling_rate": 48000}` in class `SynthesizerTrn`.
|
99 |
-
+ Open [nvSTFT](/vdecoder/hifigan/nvSTFT.py) and replace all `32000` with `48000`
|
100 |
-
### Onnx Model UI Support
|
101 |
-
+ [MoeSS](https://github.com/NaruseMioShirakana/MoeSS)
|
102 |
-
+ All training function and transformation are removed, only if they are all removed you are actually using Onnx.
|
103 |
-
|
104 |
-
## Gradio (WebUI)
|
105 |
-
Use [sovits_gradio.py](sovits_gradio.py) to run Gradio WebUI
|
106 |
-
+ Create a new folder called `checkpoints`.
|
107 |
-
+ Create a project folder in `checkpoints` folder with the desired name for your project, let's use `myproject` as example. Folder structure looks like `./checkpoints/myproject`.
|
108 |
-
+ Rename your model to `model.pth`, rename your config file to `config.json` then move them into `myproject` folder.
|
109 |
-
+ Run [sovits_gradio.py](sovits_gradio.py)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -31,20 +31,15 @@ def create_vc_fn(model, sid):
|
|
31 |
if input_audio is None:
|
32 |
return "You need to upload an audio", None
|
33 |
sampling_rate, audio = input_audio
|
34 |
-
# print(audio.shape,sampling_rate)
|
35 |
duration = audio.shape[0] / sampling_rate
|
36 |
-
if duration >
|
37 |
-
return "Please upload an audio file that is less than
|
38 |
audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
|
39 |
if len(audio.shape) > 1:
|
40 |
audio = librosa.to_mono(audio.transpose(1, 0))
|
41 |
-
if sampling_rate !=
|
42 |
-
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=
|
43 |
-
|
44 |
-
soundfile.write(out_wav_path, audio, 16000, format="wav")
|
45 |
-
out_audio, out_sr = model.infer(sid, vc_transform, out_wav_path,
|
46 |
-
auto_predict_f0=auto_f0,
|
47 |
-
)
|
48 |
return "Success", (44100, out_audio.cpu().numpy())
|
49 |
return vc_fn
|
50 |
|
@@ -64,11 +59,11 @@ if __name__ == '__main__':
|
|
64 |
models.append((name, cover, create_vc_fn(model, name)))
|
65 |
with gr.Blocks() as app:
|
66 |
gr.Markdown(
|
67 |
-
"# <center> Sovits
|
68 |
"## <center> The input audio should be clean and pure voice without background music.\n"
|
69 |
"![visitor badge](https://visitor-badge.glitch.me/badge?page_id=sayashi.Sovits-Umamusume)\n\n"
|
70 |
"[Open In Colab](https://colab.research.google.com/drive/1wfsBbMzmtLflOJeqc5ZnJiLY7L239hJW?usp=share_link)"
|
71 |
-
"
|
72 |
"[Original Repo](https://github.com/innnky/so-vits-svc/tree/4.0)"
|
73 |
)
|
74 |
with gr.Tabs():
|
|
|
31 |
if input_audio is None:
|
32 |
return "You need to upload an audio", None
|
33 |
sampling_rate, audio = input_audio
|
|
|
34 |
duration = audio.shape[0] / sampling_rate
|
35 |
+
if duration > 30 and limitation:
|
36 |
+
return "Please upload an audio file that is less than 30 seconds. If you need to generate a longer audio file, please use Colab.", None
|
37 |
audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
|
38 |
if len(audio.shape) > 1:
|
39 |
audio = librosa.to_mono(audio.transpose(1, 0))
|
40 |
+
if sampling_rate != 44100:
|
41 |
+
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=44100)
|
42 |
+
out_audio, out_sr = model.infer(sid, vc_transform, audio, auto_predict_f0=auto_f0)
|
|
|
|
|
|
|
|
|
43 |
return "Success", (44100, out_audio.cpu().numpy())
|
44 |
return vc_fn
|
45 |
|
|
|
59 |
models.append((name, cover, create_vc_fn(model, name)))
|
60 |
with gr.Blocks() as app:
|
61 |
gr.Markdown(
|
62 |
+
"# <center> Sovits Models\n"
|
63 |
"## <center> The input audio should be clean and pure voice without background music.\n"
|
64 |
"![visitor badge](https://visitor-badge.glitch.me/badge?page_id=sayashi.Sovits-Umamusume)\n\n"
|
65 |
"[Open In Colab](https://colab.research.google.com/drive/1wfsBbMzmtLflOJeqc5ZnJiLY7L239hJW?usp=share_link)"
|
66 |
+
" without queue and length limitation.\n\n"
|
67 |
"[Original Repo](https://github.com/innnky/so-vits-svc/tree/4.0)"
|
68 |
)
|
69 |
with gr.Tabs():
|
data_utils.py
DELETED
@@ -1,142 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
import torch.utils.data
|
7 |
-
|
8 |
-
import modules.commons as commons
|
9 |
-
import utils
|
10 |
-
from modules.mel_processing import spectrogram_torch, spec_to_mel_torch
|
11 |
-
from utils import load_wav_to_torch, load_filepaths_and_text
|
12 |
-
|
13 |
-
# import h5py
|
14 |
-
|
15 |
-
|
16 |
-
"""Multi speaker version"""
|
17 |
-
|
18 |
-
|
19 |
-
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
20 |
-
"""
|
21 |
-
1) loads audio, speaker_id, text pairs
|
22 |
-
2) normalizes text and converts them to sequences of integers
|
23 |
-
3) computes spectrograms from audio files.
|
24 |
-
"""
|
25 |
-
|
26 |
-
def __init__(self, audiopaths, hparams):
|
27 |
-
self.audiopaths = load_filepaths_and_text(audiopaths)
|
28 |
-
self.max_wav_value = hparams.data.max_wav_value
|
29 |
-
self.sampling_rate = hparams.data.sampling_rate
|
30 |
-
self.filter_length = hparams.data.filter_length
|
31 |
-
self.hop_length = hparams.data.hop_length
|
32 |
-
self.win_length = hparams.data.win_length
|
33 |
-
self.sampling_rate = hparams.data.sampling_rate
|
34 |
-
self.use_sr = hparams.train.use_sr
|
35 |
-
self.spec_len = hparams.train.max_speclen
|
36 |
-
self.spk_map = hparams.spk
|
37 |
-
|
38 |
-
random.seed(1234)
|
39 |
-
random.shuffle(self.audiopaths)
|
40 |
-
|
41 |
-
def get_audio(self, filename):
|
42 |
-
filename = filename.replace("\\", "/")
|
43 |
-
audio, sampling_rate = load_wav_to_torch(filename)
|
44 |
-
if sampling_rate != self.sampling_rate:
|
45 |
-
raise ValueError("{} SR doesn't match target {} SR".format(
|
46 |
-
sampling_rate, self.sampling_rate))
|
47 |
-
audio_norm = audio / self.max_wav_value
|
48 |
-
audio_norm = audio_norm.unsqueeze(0)
|
49 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
50 |
-
if os.path.exists(spec_filename):
|
51 |
-
spec = torch.load(spec_filename)
|
52 |
-
else:
|
53 |
-
spec = spectrogram_torch(audio_norm, self.filter_length,
|
54 |
-
self.sampling_rate, self.hop_length, self.win_length,
|
55 |
-
center=False)
|
56 |
-
spec = torch.squeeze(spec, 0)
|
57 |
-
torch.save(spec, spec_filename)
|
58 |
-
|
59 |
-
spk = filename.split("/")[-2]
|
60 |
-
spk = torch.LongTensor([self.spk_map[spk]])
|
61 |
-
|
62 |
-
f0 = np.load(filename + ".f0.npy")
|
63 |
-
f0, uv = utils.interpolate_f0(f0)
|
64 |
-
f0 = torch.FloatTensor(f0)
|
65 |
-
uv = torch.FloatTensor(uv)
|
66 |
-
|
67 |
-
c = torch.load(filename+ ".soft.pt")
|
68 |
-
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
|
69 |
-
|
70 |
-
|
71 |
-
lmin = min(c.size(-1), spec.size(-1))
|
72 |
-
assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename)
|
73 |
-
assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
|
74 |
-
spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
|
75 |
-
audio_norm = audio_norm[:, :lmin * self.hop_length]
|
76 |
-
if spec.shape[1] < 60:
|
77 |
-
print("skip too short audio:", filename)
|
78 |
-
return None
|
79 |
-
if spec.shape[1] > 800:
|
80 |
-
start = random.randint(0, spec.shape[1]-800)
|
81 |
-
end = start + 790
|
82 |
-
spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end]
|
83 |
-
audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length]
|
84 |
-
|
85 |
-
return c, f0, spec, audio_norm, spk, uv
|
86 |
-
|
87 |
-
def __getitem__(self, index):
|
88 |
-
return self.get_audio(self.audiopaths[index][0])
|
89 |
-
|
90 |
-
def __len__(self):
|
91 |
-
return len(self.audiopaths)
|
92 |
-
|
93 |
-
|
94 |
-
class TextAudioCollate:
|
95 |
-
|
96 |
-
def __call__(self, batch):
|
97 |
-
batch = [b for b in batch if b is not None]
|
98 |
-
|
99 |
-
input_lengths, ids_sorted_decreasing = torch.sort(
|
100 |
-
torch.LongTensor([x[0].shape[1] for x in batch]),
|
101 |
-
dim=0, descending=True)
|
102 |
-
|
103 |
-
max_c_len = max([x[0].size(1) for x in batch])
|
104 |
-
max_wav_len = max([x[3].size(1) for x in batch])
|
105 |
-
|
106 |
-
lengths = torch.LongTensor(len(batch))
|
107 |
-
|
108 |
-
c_padded = torch.FloatTensor(len(batch), batch[0][0].shape[0], max_c_len)
|
109 |
-
f0_padded = torch.FloatTensor(len(batch), max_c_len)
|
110 |
-
spec_padded = torch.FloatTensor(len(batch), batch[0][2].shape[0], max_c_len)
|
111 |
-
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
112 |
-
spkids = torch.LongTensor(len(batch), 1)
|
113 |
-
uv_padded = torch.FloatTensor(len(batch), max_c_len)
|
114 |
-
|
115 |
-
c_padded.zero_()
|
116 |
-
spec_padded.zero_()
|
117 |
-
f0_padded.zero_()
|
118 |
-
wav_padded.zero_()
|
119 |
-
uv_padded.zero_()
|
120 |
-
|
121 |
-
for i in range(len(ids_sorted_decreasing)):
|
122 |
-
row = batch[ids_sorted_decreasing[i]]
|
123 |
-
|
124 |
-
c = row[0]
|
125 |
-
c_padded[i, :, :c.size(1)] = c
|
126 |
-
lengths[i] = c.size(1)
|
127 |
-
|
128 |
-
f0 = row[1]
|
129 |
-
f0_padded[i, :f0.size(0)] = f0
|
130 |
-
|
131 |
-
spec = row[2]
|
132 |
-
spec_padded[i, :, :spec.size(1)] = spec
|
133 |
-
|
134 |
-
wav = row[3]
|
135 |
-
wav_padded[i, :, :wav.size(1)] = wav
|
136 |
-
|
137 |
-
spkids[i, 0] = row[4]
|
138 |
-
|
139 |
-
uv = row[5]
|
140 |
-
uv_padded[i, :uv.size(0)] = uv
|
141 |
-
|
142 |
-
return c_padded, f0_padded, spec_padded, wav_padded, spkids, lengths, uv_padded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flask_api.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
import io
|
2 |
-
import logging
|
3 |
-
|
4 |
-
import soundfile
|
5 |
-
import torch
|
6 |
-
import torchaudio
|
7 |
-
from flask import Flask, request, send_file
|
8 |
-
from flask_cors import CORS
|
9 |
-
|
10 |
-
from inference.infer_tool import Svc, RealTimeVC
|
11 |
-
|
12 |
-
app = Flask(__name__)
|
13 |
-
|
14 |
-
CORS(app)
|
15 |
-
|
16 |
-
logging.getLogger('numba').setLevel(logging.WARNING)
|
17 |
-
|
18 |
-
|
19 |
-
@app.route("/voiceChangeModel", methods=["POST"])
|
20 |
-
def voice_change_model():
|
21 |
-
request_form = request.form
|
22 |
-
wave_file = request.files.get("sample", None)
|
23 |
-
# 变调信息
|
24 |
-
f_pitch_change = float(request_form.get("fPitchChange", 0))
|
25 |
-
# DAW所需的采样率
|
26 |
-
daw_sample = int(float(request_form.get("sampleRate", 0)))
|
27 |
-
speaker_id = int(float(request_form.get("sSpeakId", 0)))
|
28 |
-
# http获得wav文件并转换
|
29 |
-
input_wav_path = io.BytesIO(wave_file.read())
|
30 |
-
|
31 |
-
# 模型推理
|
32 |
-
if raw_infer:
|
33 |
-
out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path)
|
34 |
-
tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample)
|
35 |
-
else:
|
36 |
-
out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path)
|
37 |
-
tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample)
|
38 |
-
# 返回音频
|
39 |
-
out_wav_path = io.BytesIO()
|
40 |
-
soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav")
|
41 |
-
out_wav_path.seek(0)
|
42 |
-
return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
|
43 |
-
|
44 |
-
|
45 |
-
if __name__ == '__main__':
|
46 |
-
# 启用则为直接切片合成,False为交叉淡化方式
|
47 |
-
# vst插件调整0.3-0.5s切片时间可以降低延迟,直接切片方法会有连接处爆音、交叉淡化会有轻微重叠声音
|
48 |
-
# 自行选择能接受的方法,或将vst最大切片时间调整为1s,此处设为Ture,延迟大音质稳定一些
|
49 |
-
raw_infer = True
|
50 |
-
# 每个模型和config是唯一对应的
|
51 |
-
model_name = "logs/32k/G_174000-Copy1.pth"
|
52 |
-
config_name = "configs/config.json"
|
53 |
-
svc_model = Svc(model_name, config_name)
|
54 |
-
svc = RealTimeVC()
|
55 |
-
# 此处与vst插件对应,不建议更改
|
56 |
-
app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/inference/__pycache__/__init__.cpython-38.pyc and b/inference/__pycache__/__init__.cpython-38.pyc differ
|
|
inference/__pycache__/infer_tool.cpython-38.pyc
CHANGED
Binary files a/inference/__pycache__/infer_tool.cpython-38.pyc and b/inference/__pycache__/infer_tool.cpython-38.pyc differ
|
|
inference/__pycache__/slicer.cpython-38.pyc
CHANGED
Binary files a/inference/__pycache__/slicer.cpython-38.pyc and b/inference/__pycache__/slicer.cpython-38.pyc differ
|
|
inference/infer_tool.py
CHANGED
@@ -92,6 +92,21 @@ def mkdir(paths: list):
|
|
92 |
if not os.path.exists(path):
|
93 |
os.mkdir(path)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
class Svc(object):
|
97 |
def __init__(self, net_g_path, config_path,
|
@@ -127,10 +142,7 @@ class Svc(object):
|
|
127 |
|
128 |
|
129 |
|
130 |
-
def get_unit_f0(self,
|
131 |
-
|
132 |
-
wav, sr = librosa.load(in_path, sr=self.target_sample)
|
133 |
-
|
134 |
f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size)
|
135 |
f0, uv = utils.interpolate_f0(f0)
|
136 |
f0 = torch.FloatTensor(f0)
|
@@ -139,26 +151,29 @@ class Svc(object):
|
|
139 |
f0 = f0.unsqueeze(0).to(self.dev)
|
140 |
uv = uv.unsqueeze(0).to(self.dev)
|
141 |
|
142 |
-
wav16k = librosa.resample(wav, orig_sr=
|
143 |
wav16k = torch.from_numpy(wav16k).to(self.dev)
|
144 |
c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k)
|
145 |
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
146 |
|
147 |
if cluster_infer_ratio !=0:
|
148 |
-
cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.numpy().T, speaker).T
|
149 |
-
cluster_c = torch.FloatTensor(cluster_c)
|
150 |
c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
|
151 |
|
152 |
c = c.unsqueeze(0)
|
153 |
return c, f0, uv
|
154 |
|
155 |
-
def infer(self, speaker, tran,
|
156 |
cluster_infer_ratio=0,
|
157 |
auto_predict_f0=False,
|
158 |
noice_scale=0.4):
|
159 |
-
speaker_id = self.spk2id
|
|
|
|
|
|
|
160 |
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
|
161 |
-
c, f0, uv = self.get_unit_f0(
|
162 |
if "half" in self.net_g_path and torch.cuda.is_available():
|
163 |
c = c.half()
|
164 |
with torch.no_grad():
|
@@ -167,39 +182,64 @@ class Svc(object):
|
|
167 |
use_time = time.time() - start
|
168 |
print("vits use time:{}".format(use_time))
|
169 |
return audio, audio.shape[-1]
|
|
|
|
|
|
|
|
|
170 |
|
171 |
-
def slice_inference(self,raw_audio_path, spk, tran, slice_db,cluster_infer_ratio, auto_predict_f0,noice_scale, pad_seconds=0.5):
|
172 |
wav_path = raw_audio_path
|
173 |
chunks = slicer.cut(wav_path, db_thresh=slice_db)
|
174 |
audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
audio = []
|
177 |
for (slice_tag, data) in audio_data:
|
178 |
print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
|
179 |
# padd
|
180 |
-
pad_len = int(audio_sr * pad_seconds)
|
181 |
-
data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
|
182 |
length = int(np.ceil(len(data) / audio_sr * self.target_sample))
|
183 |
-
raw_path = io.BytesIO()
|
184 |
-
soundfile.write(raw_path, data, audio_sr, format="wav")
|
185 |
-
raw_path.seek(0)
|
186 |
if slice_tag:
|
187 |
print('jump empty segment')
|
188 |
_audio = np.zeros(length)
|
|
|
|
|
|
|
|
|
189 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
out_audio, out_sr = self.infer(spk, tran, raw_path,
|
191 |
cluster_infer_ratio=cluster_infer_ratio,
|
192 |
auto_predict_f0=auto_predict_f0,
|
193 |
noice_scale=noice_scale
|
194 |
)
|
195 |
_audio = out_audio.cpu().numpy()
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
return np.array(audio)
|
201 |
|
202 |
-
|
203 |
class RealTimeVC:
|
204 |
def __init__(self):
|
205 |
self.last_chunk = None
|
|
|
92 |
if not os.path.exists(path):
|
93 |
os.mkdir(path)
|
94 |
|
95 |
+
def pad_array(arr, target_length):
|
96 |
+
current_length = arr.shape[0]
|
97 |
+
if current_length >= target_length:
|
98 |
+
return arr
|
99 |
+
else:
|
100 |
+
pad_width = target_length - current_length
|
101 |
+
pad_left = pad_width // 2
|
102 |
+
pad_right = pad_width - pad_left
|
103 |
+
padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
|
104 |
+
return padded_arr
|
105 |
+
|
106 |
+
def split_list_by_n(list_collection, n, pre=0):
|
107 |
+
for i in range(0, len(list_collection), n):
|
108 |
+
yield list_collection[i-pre if i-pre>=0 else i: i + n]
|
109 |
+
|
110 |
|
111 |
class Svc(object):
|
112 |
def __init__(self, net_g_path, config_path,
|
|
|
142 |
|
143 |
|
144 |
|
145 |
+
def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker):
|
|
|
|
|
|
|
146 |
f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size)
|
147 |
f0, uv = utils.interpolate_f0(f0)
|
148 |
f0 = torch.FloatTensor(f0)
|
|
|
151 |
f0 = f0.unsqueeze(0).to(self.dev)
|
152 |
uv = uv.unsqueeze(0).to(self.dev)
|
153 |
|
154 |
+
wav16k = librosa.resample(wav, orig_sr=44100, target_sr=16000)
|
155 |
wav16k = torch.from_numpy(wav16k).to(self.dev)
|
156 |
c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k)
|
157 |
c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
|
158 |
|
159 |
if cluster_infer_ratio !=0:
|
160 |
+
cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, speaker).T
|
161 |
+
cluster_c = torch.FloatTensor(cluster_c).to(self.dev)
|
162 |
c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
|
163 |
|
164 |
c = c.unsqueeze(0)
|
165 |
return c, f0, uv
|
166 |
|
167 |
+
def infer(self, speaker, tran, raw_wav,
|
168 |
cluster_infer_ratio=0,
|
169 |
auto_predict_f0=False,
|
170 |
noice_scale=0.4):
|
171 |
+
speaker_id = self.spk2id.__dict__.get(speaker)
|
172 |
+
if not speaker_id and type(speaker) is int:
|
173 |
+
if len(self.spk2id.__dict__) >= speaker:
|
174 |
+
speaker_id = speaker
|
175 |
sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
|
176 |
+
c, f0, uv = self.get_unit_f0(raw_wav, tran, cluster_infer_ratio, speaker)
|
177 |
if "half" in self.net_g_path and torch.cuda.is_available():
|
178 |
c = c.half()
|
179 |
with torch.no_grad():
|
|
|
182 |
use_time = time.time() - start
|
183 |
print("vits use time:{}".format(use_time))
|
184 |
return audio, audio.shape[-1]
|
185 |
+
|
186 |
+
def clear_empty(self):
|
187 |
+
# 清理显存
|
188 |
+
torch.cuda.empty_cache()
|
189 |
|
190 |
+
def slice_inference(self,raw_audio_path, spk, tran, slice_db,cluster_infer_ratio, auto_predict_f0,noice_scale, pad_seconds=0.5, clip_seconds=0,lg_num=0,lgr_num =0.75):
|
191 |
wav_path = raw_audio_path
|
192 |
chunks = slicer.cut(wav_path, db_thresh=slice_db)
|
193 |
audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
|
194 |
+
per_size = int(clip_seconds*audio_sr)
|
195 |
+
lg_size = int(lg_num*audio_sr)
|
196 |
+
lg_size_r = int(lg_size*lgr_num)
|
197 |
+
lg_size_c_l = (lg_size-lg_size_r)//2
|
198 |
+
lg_size_c_r = lg_size-lg_size_r-lg_size_c_l
|
199 |
+
lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0
|
200 |
+
|
201 |
audio = []
|
202 |
for (slice_tag, data) in audio_data:
|
203 |
print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
|
204 |
# padd
|
|
|
|
|
205 |
length = int(np.ceil(len(data) / audio_sr * self.target_sample))
|
|
|
|
|
|
|
206 |
if slice_tag:
|
207 |
print('jump empty segment')
|
208 |
_audio = np.zeros(length)
|
209 |
+
audio.extend(list(pad_array(_audio, length)))
|
210 |
+
continue
|
211 |
+
if per_size != 0:
|
212 |
+
datas = split_list_by_n(data, per_size,lg_size)
|
213 |
else:
|
214 |
+
datas = [data]
|
215 |
+
for k,dat in enumerate(datas):
|
216 |
+
per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
|
217 |
+
if clip_seconds!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
|
218 |
+
# padd
|
219 |
+
pad_len = int(audio_sr * pad_seconds)
|
220 |
+
dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
|
221 |
+
raw_path = io.BytesIO()
|
222 |
+
soundfile.write(raw_path, dat, audio_sr, format="wav")
|
223 |
+
raw_path.seek(0)
|
224 |
out_audio, out_sr = self.infer(spk, tran, raw_path,
|
225 |
cluster_infer_ratio=cluster_infer_ratio,
|
226 |
auto_predict_f0=auto_predict_f0,
|
227 |
noice_scale=noice_scale
|
228 |
)
|
229 |
_audio = out_audio.cpu().numpy()
|
230 |
+
pad_len = int(self.target_sample * pad_seconds)
|
231 |
+
_audio = _audio[pad_len:-pad_len]
|
232 |
+
_audio = pad_array(_audio, per_length)
|
233 |
+
if lg_size!=0 and k!=0:
|
234 |
+
lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
|
235 |
+
lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
|
236 |
+
lg_pre = lg1*(1-lg)+lg2*lg
|
237 |
+
audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
|
238 |
+
audio.extend(lg_pre)
|
239 |
+
_audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
|
240 |
+
audio.extend(list(_audio))
|
241 |
return np.array(audio)
|
242 |
|
|
|
243 |
class RealTimeVC:
|
244 |
def __init__(self):
|
245 |
self.last_chunk = None
|
modules/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/modules/__pycache__/__init__.cpython-38.pyc and b/modules/__pycache__/__init__.cpython-38.pyc differ
|
|
modules/__pycache__/attentions.cpython-38.pyc
CHANGED
Binary files a/modules/__pycache__/attentions.cpython-38.pyc and b/modules/__pycache__/attentions.cpython-38.pyc differ
|
|
modules/__pycache__/commons.cpython-38.pyc
CHANGED
Binary files a/modules/__pycache__/commons.cpython-38.pyc and b/modules/__pycache__/commons.cpython-38.pyc differ
|
|
modules/__pycache__/modules.cpython-38.pyc
CHANGED
Binary files a/modules/__pycache__/modules.cpython-38.pyc and b/modules/__pycache__/modules.cpython-38.pyc differ
|
|
preprocess_flist_config.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import argparse
|
3 |
-
import re
|
4 |
-
|
5 |
-
from tqdm import tqdm
|
6 |
-
from random import shuffle
|
7 |
-
import json
|
8 |
-
|
9 |
-
config_template = json.load(open("configs/config.json"))
|
10 |
-
|
11 |
-
pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
|
12 |
-
|
13 |
-
if __name__ == "__main__":
|
14 |
-
parser = argparse.ArgumentParser()
|
15 |
-
parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
|
16 |
-
parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
|
17 |
-
parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list")
|
18 |
-
parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
|
19 |
-
args = parser.parse_args()
|
20 |
-
|
21 |
-
train = []
|
22 |
-
val = []
|
23 |
-
test = []
|
24 |
-
idx = 0
|
25 |
-
spk_dict = {}
|
26 |
-
spk_id = 0
|
27 |
-
for speaker in tqdm(os.listdir(args.source_dir)):
|
28 |
-
spk_dict[speaker] = spk_id
|
29 |
-
spk_id += 1
|
30 |
-
wavs = ["/".join([args.source_dir, speaker, i]) for i in os.listdir(os.path.join(args.source_dir, speaker))]
|
31 |
-
for wavpath in wavs:
|
32 |
-
if not pattern.match(wavpath):
|
33 |
-
print(f"warning:文件名{wavpath}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
|
34 |
-
if len(wavs) < 10:
|
35 |
-
print(f"warning:{speaker}数据集数量小于10条,请补充数据")
|
36 |
-
wavs = [i for i in wavs if i.endswith("wav")]
|
37 |
-
shuffle(wavs)
|
38 |
-
train += wavs[2:-2]
|
39 |
-
val += wavs[:2]
|
40 |
-
test += wavs[-2:]
|
41 |
-
|
42 |
-
shuffle(train)
|
43 |
-
shuffle(val)
|
44 |
-
shuffle(test)
|
45 |
-
|
46 |
-
print("Writing", args.train_list)
|
47 |
-
with open(args.train_list, "w") as f:
|
48 |
-
for fname in tqdm(train):
|
49 |
-
wavpath = fname
|
50 |
-
f.write(wavpath + "\n")
|
51 |
-
|
52 |
-
print("Writing", args.val_list)
|
53 |
-
with open(args.val_list, "w") as f:
|
54 |
-
for fname in tqdm(val):
|
55 |
-
wavpath = fname
|
56 |
-
f.write(wavpath + "\n")
|
57 |
-
|
58 |
-
print("Writing", args.test_list)
|
59 |
-
with open(args.test_list, "w") as f:
|
60 |
-
for fname in tqdm(test):
|
61 |
-
wavpath = fname
|
62 |
-
f.write(wavpath + "\n")
|
63 |
-
|
64 |
-
config_template["spk"] = spk_dict
|
65 |
-
print("Writing configs/config.json")
|
66 |
-
with open("configs/config.json", "w") as f:
|
67 |
-
json.dump(config_template, f, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
preprocess_hubert_f0.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import multiprocessing
|
3 |
-
import os
|
4 |
-
import argparse
|
5 |
-
from random import shuffle
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from glob import glob
|
9 |
-
from tqdm import tqdm
|
10 |
-
|
11 |
-
import utils
|
12 |
-
import logging
|
13 |
-
logging.getLogger('numba').setLevel(logging.WARNING)
|
14 |
-
import librosa
|
15 |
-
import numpy as np
|
16 |
-
|
17 |
-
hps = utils.get_hparams_from_file("configs/config.json")
|
18 |
-
sampling_rate = hps.data.sampling_rate
|
19 |
-
hop_length = hps.data.hop_length
|
20 |
-
|
21 |
-
|
22 |
-
def process_one(filename, hmodel):
|
23 |
-
# print(filename)
|
24 |
-
wav, sr = librosa.load(filename, sr=sampling_rate)
|
25 |
-
soft_path = filename + ".soft.pt"
|
26 |
-
if not os.path.exists(soft_path):
|
27 |
-
devive = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
-
wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
|
29 |
-
wav16k = torch.from_numpy(wav16k).to(devive)
|
30 |
-
c = utils.get_hubert_content(hmodel, wav_16k_tensor=wav16k)
|
31 |
-
torch.save(c.cpu(), soft_path)
|
32 |
-
f0_path = filename + ".f0.npy"
|
33 |
-
if not os.path.exists(f0_path):
|
34 |
-
f0 = utils.compute_f0_dio(wav, sampling_rate=sampling_rate, hop_length=hop_length)
|
35 |
-
np.save(f0_path, f0)
|
36 |
-
|
37 |
-
|
38 |
-
def process_batch(filenames):
|
39 |
-
print("Loading hubert for content...")
|
40 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
-
hmodel = utils.get_hubert_model().to(device)
|
42 |
-
print("Loaded hubert.")
|
43 |
-
for filename in tqdm(filenames):
|
44 |
-
process_one(filename, hmodel)
|
45 |
-
|
46 |
-
|
47 |
-
if __name__ == "__main__":
|
48 |
-
parser = argparse.ArgumentParser()
|
49 |
-
parser.add_argument("--in_dir", type=str, default="dataset/44k", help="path to input dir")
|
50 |
-
|
51 |
-
args = parser.parse_args()
|
52 |
-
filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True) # [:10]
|
53 |
-
shuffle(filenames)
|
54 |
-
multiprocessing.set_start_method('spawn')
|
55 |
-
|
56 |
-
num_processes = 1
|
57 |
-
chunk_size = int(math.ceil(len(filenames) / num_processes))
|
58 |
-
chunks = [filenames[i:i + chunk_size] for i in range(0, len(filenames), chunk_size)]
|
59 |
-
print([len(c) for c in chunks])
|
60 |
-
processes = [multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks]
|
61 |
-
for p in processes:
|
62 |
-
p.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
resample.py
DELETED
@@ -1,48 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import argparse
|
3 |
-
import librosa
|
4 |
-
import numpy as np
|
5 |
-
from multiprocessing import Pool, cpu_count
|
6 |
-
from scipy.io import wavfile
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
|
10 |
-
def process(item):
|
11 |
-
spkdir, wav_name, args = item
|
12 |
-
# speaker 's5', 'p280', 'p315' are excluded,
|
13 |
-
speaker = spkdir.replace("\\", "/").split("/")[-1]
|
14 |
-
wav_path = os.path.join(args.in_dir, speaker, wav_name)
|
15 |
-
if os.path.exists(wav_path) and '.wav' in wav_path:
|
16 |
-
os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
|
17 |
-
wav, sr = librosa.load(wav_path, None)
|
18 |
-
wav, _ = librosa.effects.trim(wav, top_db=20)
|
19 |
-
peak = np.abs(wav).max()
|
20 |
-
if peak > 1.0:
|
21 |
-
wav = 0.98 * wav / peak
|
22 |
-
wav2 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr2)
|
23 |
-
wav2 /= max(wav2.max(), -wav2.min())
|
24 |
-
save_name = wav_name
|
25 |
-
save_path2 = os.path.join(args.out_dir2, speaker, save_name)
|
26 |
-
wavfile.write(
|
27 |
-
save_path2,
|
28 |
-
args.sr2,
|
29 |
-
(wav2 * np.iinfo(np.int16).max).astype(np.int16)
|
30 |
-
)
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
if __name__ == "__main__":
|
35 |
-
parser = argparse.ArgumentParser()
|
36 |
-
parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
|
37 |
-
parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
|
38 |
-
parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
|
39 |
-
args = parser.parse_args()
|
40 |
-
processs = cpu_count()-2 if cpu_count() >4 else 1
|
41 |
-
pool = Pool(processes=processs)
|
42 |
-
|
43 |
-
for speaker in os.listdir(args.in_dir):
|
44 |
-
spk_dir = os.path.join(args.in_dir, speaker)
|
45 |
-
if os.path.isdir(spk_dir):
|
46 |
-
print(spk_dir)
|
47 |
-
for _ in tqdm(pool.imap_unordered(process, [(spk_dir, i, args) for i in os.listdir(spk_dir) if i.endswith("wav")])):
|
48 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spec_gen.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
from data_utils import TextAudioSpeakerLoader
|
2 |
-
import json
|
3 |
-
from tqdm import tqdm
|
4 |
-
|
5 |
-
from utils import HParams
|
6 |
-
|
7 |
-
config_path = 'configs/config.json'
|
8 |
-
with open(config_path, "r") as f:
|
9 |
-
data = f.read()
|
10 |
-
config = json.loads(data)
|
11 |
-
hps = HParams(**config)
|
12 |
-
|
13 |
-
train_dataset = TextAudioSpeakerLoader("filelists/train.txt", hps)
|
14 |
-
test_dataset = TextAudioSpeakerLoader("filelists/test.txt", hps)
|
15 |
-
eval_dataset = TextAudioSpeakerLoader("filelists/val.txt", hps)
|
16 |
-
|
17 |
-
for _ in tqdm(train_dataset):
|
18 |
-
pass
|
19 |
-
for _ in tqdm(eval_dataset):
|
20 |
-
pass
|
21 |
-
for _ in tqdm(test_dataset):
|
22 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,297 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
3 |
-
import os
|
4 |
-
import json
|
5 |
-
import argparse
|
6 |
-
import itertools
|
7 |
-
import math
|
8 |
-
import torch
|
9 |
-
from torch import nn, optim
|
10 |
-
from torch.nn import functional as F
|
11 |
-
from torch.utils.data import DataLoader
|
12 |
-
from torch.utils.tensorboard import SummaryWriter
|
13 |
-
import torch.multiprocessing as mp
|
14 |
-
import torch.distributed as dist
|
15 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
16 |
-
from torch.cuda.amp import autocast, GradScaler
|
17 |
-
|
18 |
-
import modules.commons as commons
|
19 |
-
import utils
|
20 |
-
from data_utils import TextAudioSpeakerLoader, TextAudioCollate
|
21 |
-
from models import (
|
22 |
-
SynthesizerTrn,
|
23 |
-
MultiPeriodDiscriminator,
|
24 |
-
)
|
25 |
-
from modules.losses import (
|
26 |
-
kl_loss,
|
27 |
-
generator_loss, discriminator_loss, feature_loss
|
28 |
-
)
|
29 |
-
|
30 |
-
from modules.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
31 |
-
|
32 |
-
torch.backends.cudnn.benchmark = True
|
33 |
-
global_step = 0
|
34 |
-
|
35 |
-
|
36 |
-
# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
|
37 |
-
|
38 |
-
|
39 |
-
def main():
|
40 |
-
"""Assume Single Node Multi GPUs Training Only"""
|
41 |
-
assert torch.cuda.is_available(), "CPU training is not allowed."
|
42 |
-
hps = utils.get_hparams()
|
43 |
-
|
44 |
-
n_gpus = torch.cuda.device_count()
|
45 |
-
os.environ['MASTER_ADDR'] = 'localhost'
|
46 |
-
os.environ['MASTER_PORT'] = hps.train.port
|
47 |
-
|
48 |
-
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
|
49 |
-
|
50 |
-
|
51 |
-
def run(rank, n_gpus, hps):
|
52 |
-
global global_step
|
53 |
-
if rank == 0:
|
54 |
-
logger = utils.get_logger(hps.model_dir)
|
55 |
-
logger.info(hps)
|
56 |
-
utils.check_git_hash(hps.model_dir)
|
57 |
-
writer = SummaryWriter(log_dir=hps.model_dir)
|
58 |
-
writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
|
59 |
-
|
60 |
-
# for pytorch on win, backend use gloo
|
61 |
-
dist.init_process_group(backend= 'gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, rank=rank)
|
62 |
-
torch.manual_seed(hps.train.seed)
|
63 |
-
torch.cuda.set_device(rank)
|
64 |
-
collate_fn = TextAudioCollate()
|
65 |
-
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
|
66 |
-
train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
|
67 |
-
batch_size=hps.train.batch_size,collate_fn=collate_fn)
|
68 |
-
if rank == 0:
|
69 |
-
eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
|
70 |
-
eval_loader = DataLoader(eval_dataset, num_workers=1, shuffle=False,
|
71 |
-
batch_size=1, pin_memory=False,
|
72 |
-
drop_last=False, collate_fn=collate_fn)
|
73 |
-
|
74 |
-
net_g = SynthesizerTrn(
|
75 |
-
hps.data.filter_length // 2 + 1,
|
76 |
-
hps.train.segment_size // hps.data.hop_length,
|
77 |
-
**hps.model).cuda(rank)
|
78 |
-
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
79 |
-
optim_g = torch.optim.AdamW(
|
80 |
-
net_g.parameters(),
|
81 |
-
hps.train.learning_rate,
|
82 |
-
betas=hps.train.betas,
|
83 |
-
eps=hps.train.eps)
|
84 |
-
optim_d = torch.optim.AdamW(
|
85 |
-
net_d.parameters(),
|
86 |
-
hps.train.learning_rate,
|
87 |
-
betas=hps.train.betas,
|
88 |
-
eps=hps.train.eps)
|
89 |
-
net_g = DDP(net_g, device_ids=[rank]) # , find_unused_parameters=True)
|
90 |
-
net_d = DDP(net_d, device_ids=[rank])
|
91 |
-
|
92 |
-
skip_optimizer = True
|
93 |
-
try:
|
94 |
-
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
|
95 |
-
optim_g, skip_optimizer)
|
96 |
-
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
|
97 |
-
optim_d, skip_optimizer)
|
98 |
-
global_step = (epoch_str - 1) * len(train_loader)
|
99 |
-
except:
|
100 |
-
print("load old checkpoint failed...")
|
101 |
-
epoch_str = 1
|
102 |
-
global_step = 0
|
103 |
-
if skip_optimizer:
|
104 |
-
epoch_str = 1
|
105 |
-
global_step = 0
|
106 |
-
|
107 |
-
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
108 |
-
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
109 |
-
|
110 |
-
scaler = GradScaler(enabled=hps.train.fp16_run)
|
111 |
-
|
112 |
-
for epoch in range(epoch_str, hps.train.epochs + 1):
|
113 |
-
if rank == 0:
|
114 |
-
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
115 |
-
[train_loader, eval_loader], logger, [writer, writer_eval])
|
116 |
-
else:
|
117 |
-
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
|
118 |
-
[train_loader, None], None, None)
|
119 |
-
scheduler_g.step()
|
120 |
-
scheduler_d.step()
|
121 |
-
|
122 |
-
|
123 |
-
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
124 |
-
net_g, net_d = nets
|
125 |
-
optim_g, optim_d = optims
|
126 |
-
scheduler_g, scheduler_d = schedulers
|
127 |
-
train_loader, eval_loader = loaders
|
128 |
-
if writers is not None:
|
129 |
-
writer, writer_eval = writers
|
130 |
-
|
131 |
-
# train_loader.batch_sampler.set_epoch(epoch)
|
132 |
-
global global_step
|
133 |
-
|
134 |
-
net_g.train()
|
135 |
-
net_d.train()
|
136 |
-
for batch_idx, items in enumerate(train_loader):
|
137 |
-
c, f0, spec, y, spk, lengths, uv = items
|
138 |
-
g = spk.cuda(rank, non_blocking=True)
|
139 |
-
spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
|
140 |
-
c = c.cuda(rank, non_blocking=True)
|
141 |
-
f0 = f0.cuda(rank, non_blocking=True)
|
142 |
-
uv = uv.cuda(rank, non_blocking=True)
|
143 |
-
lengths = lengths.cuda(rank, non_blocking=True)
|
144 |
-
mel = spec_to_mel_torch(
|
145 |
-
spec,
|
146 |
-
hps.data.filter_length,
|
147 |
-
hps.data.n_mel_channels,
|
148 |
-
hps.data.sampling_rate,
|
149 |
-
hps.data.mel_fmin,
|
150 |
-
hps.data.mel_fmax)
|
151 |
-
|
152 |
-
with autocast(enabled=hps.train.fp16_run):
|
153 |
-
y_hat, ids_slice, z_mask, \
|
154 |
-
(z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0 = net_g(c, f0, uv, spec, g=g, c_lengths=lengths,
|
155 |
-
spec_lengths=lengths)
|
156 |
-
|
157 |
-
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
158 |
-
y_hat_mel = mel_spectrogram_torch(
|
159 |
-
y_hat.squeeze(1),
|
160 |
-
hps.data.filter_length,
|
161 |
-
hps.data.n_mel_channels,
|
162 |
-
hps.data.sampling_rate,
|
163 |
-
hps.data.hop_length,
|
164 |
-
hps.data.win_length,
|
165 |
-
hps.data.mel_fmin,
|
166 |
-
hps.data.mel_fmax
|
167 |
-
)
|
168 |
-
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
169 |
-
|
170 |
-
# Discriminator
|
171 |
-
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
172 |
-
|
173 |
-
with autocast(enabled=False):
|
174 |
-
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
175 |
-
loss_disc_all = loss_disc
|
176 |
-
|
177 |
-
optim_d.zero_grad()
|
178 |
-
scaler.scale(loss_disc_all).backward()
|
179 |
-
scaler.unscale_(optim_d)
|
180 |
-
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
181 |
-
scaler.step(optim_d)
|
182 |
-
|
183 |
-
with autocast(enabled=hps.train.fp16_run):
|
184 |
-
# Generator
|
185 |
-
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
186 |
-
with autocast(enabled=False):
|
187 |
-
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
188 |
-
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
189 |
-
loss_fm = feature_loss(fmap_r, fmap_g)
|
190 |
-
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
191 |
-
loss_lf0 = F.mse_loss(pred_lf0, lf0)
|
192 |
-
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_lf0
|
193 |
-
optim_g.zero_grad()
|
194 |
-
scaler.scale(loss_gen_all).backward()
|
195 |
-
scaler.unscale_(optim_g)
|
196 |
-
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
197 |
-
scaler.step(optim_g)
|
198 |
-
scaler.update()
|
199 |
-
|
200 |
-
if rank == 0:
|
201 |
-
if global_step % hps.train.log_interval == 0:
|
202 |
-
lr = optim_g.param_groups[0]['lr']
|
203 |
-
losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
|
204 |
-
logger.info('Train Epoch: {} [{:.0f}%]'.format(
|
205 |
-
epoch,
|
206 |
-
100. * batch_idx / len(train_loader)))
|
207 |
-
logger.info([x.item() for x in losses] + [global_step, lr])
|
208 |
-
|
209 |
-
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
|
210 |
-
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
|
211 |
-
scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl,
|
212 |
-
"loss/g/lf0": loss_lf0})
|
213 |
-
|
214 |
-
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
215 |
-
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
216 |
-
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
217 |
-
image_dict = {
|
218 |
-
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
|
219 |
-
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
|
220 |
-
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
|
221 |
-
"all/lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
|
222 |
-
pred_lf0[0, 0, :].detach().cpu().numpy()),
|
223 |
-
"all/norm_lf0": utils.plot_data_to_numpy(lf0[0, 0, :].cpu().numpy(),
|
224 |
-
norm_lf0[0, 0, :].detach().cpu().numpy())
|
225 |
-
}
|
226 |
-
|
227 |
-
utils.summarize(
|
228 |
-
writer=writer,
|
229 |
-
global_step=global_step,
|
230 |
-
images=image_dict,
|
231 |
-
scalars=scalar_dict
|
232 |
-
)
|
233 |
-
|
234 |
-
if global_step % hps.train.eval_interval == 0:
|
235 |
-
evaluate(hps, net_g, eval_loader, writer_eval)
|
236 |
-
utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
|
237 |
-
os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), hps.train.eval_interval, global_step)
|
238 |
-
utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
|
239 |
-
os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), hps.train.eval_interval, global_step)
|
240 |
-
global_step += 1
|
241 |
-
|
242 |
-
if rank == 0:
|
243 |
-
logger.info('====> Epoch: {}'.format(epoch))
|
244 |
-
|
245 |
-
|
246 |
-
def evaluate(hps, generator, eval_loader, writer_eval):
|
247 |
-
generator.eval()
|
248 |
-
image_dict = {}
|
249 |
-
audio_dict = {}
|
250 |
-
with torch.no_grad():
|
251 |
-
for batch_idx, items in enumerate(eval_loader):
|
252 |
-
c, f0, spec, y, spk, _, uv = items
|
253 |
-
g = spk[:1].cuda(0)
|
254 |
-
spec, y = spec[:1].cuda(0), y[:1].cuda(0)
|
255 |
-
c = c[:1].cuda(0)
|
256 |
-
f0 = f0[:1].cuda(0)
|
257 |
-
uv= uv[:1].cuda(0)
|
258 |
-
mel = spec_to_mel_torch(
|
259 |
-
spec,
|
260 |
-
hps.data.filter_length,
|
261 |
-
hps.data.n_mel_channels,
|
262 |
-
hps.data.sampling_rate,
|
263 |
-
hps.data.mel_fmin,
|
264 |
-
hps.data.mel_fmax)
|
265 |
-
y_hat = generator.module.infer(c, f0, uv, g=g)
|
266 |
-
|
267 |
-
y_hat_mel = mel_spectrogram_torch(
|
268 |
-
y_hat.squeeze(1).float(),
|
269 |
-
hps.data.filter_length,
|
270 |
-
hps.data.n_mel_channels,
|
271 |
-
hps.data.sampling_rate,
|
272 |
-
hps.data.hop_length,
|
273 |
-
hps.data.win_length,
|
274 |
-
hps.data.mel_fmin,
|
275 |
-
hps.data.mel_fmax
|
276 |
-
)
|
277 |
-
|
278 |
-
audio_dict.update({
|
279 |
-
f"gen/audio_{batch_idx}": y_hat[0],
|
280 |
-
f"gt/audio_{batch_idx}": y[0]
|
281 |
-
})
|
282 |
-
image_dict.update({
|
283 |
-
f"gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()),
|
284 |
-
"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())
|
285 |
-
})
|
286 |
-
utils.summarize(
|
287 |
-
writer=writer_eval,
|
288 |
-
global_step=global_step,
|
289 |
-
images=image_dict,
|
290 |
-
audios=audio_dict,
|
291 |
-
audio_sampling_rate=hps.data.sampling_rate
|
292 |
-
)
|
293 |
-
generator.train()
|
294 |
-
|
295 |
-
|
296 |
-
if __name__ == "__main__":
|
297 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
@@ -222,7 +222,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|
222 |
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
223 |
iteration = checkpoint_dict['iteration']
|
224 |
learning_rate = checkpoint_dict['learning_rate']
|
225 |
-
if optimizer is not None and not skip_optimizer:
|
226 |
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
227 |
saved_state_dict = checkpoint_dict['model']
|
228 |
if hasattr(model, 'module'):
|
@@ -250,7 +250,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
|
|
250 |
return model, optimizer, learning_rate, iteration
|
251 |
|
252 |
|
253 |
-
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path
|
254 |
logger.info("Saving model and optimizer state at iteration {} to {}".format(
|
255 |
iteration, checkpoint_path))
|
256 |
if hasattr(model, 'module'):
|
@@ -261,14 +261,8 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path,
|
|
261 |
'iteration': iteration,
|
262 |
'optimizer': optimizer.state_dict(),
|
263 |
'learning_rate': learning_rate}, checkpoint_path)
|
264 |
-
if current_step >= val_steps * 3:
|
265 |
-
to_del_ckptname = checkpoint_path.replace(str(current_step), str(current_step - val_steps * 3))
|
266 |
-
if os.path.exists(to_del_ckptname):
|
267 |
-
os.remove(to_del_ckptname)
|
268 |
-
print("Removing ", to_del_ckptname)
|
269 |
|
270 |
-
|
271 |
-
def clean_checkpoints(path_to_models='logs/48k/', n_ckpts_to_keep=2, sort_by_time=True):
|
272 |
"""Freeing up space by deleting saved ckpts
|
273 |
|
274 |
Arguments:
|
|
|
222 |
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
223 |
iteration = checkpoint_dict['iteration']
|
224 |
learning_rate = checkpoint_dict['learning_rate']
|
225 |
+
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
|
226 |
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
227 |
saved_state_dict = checkpoint_dict['model']
|
228 |
if hasattr(model, 'module'):
|
|
|
250 |
return model, optimizer, learning_rate, iteration
|
251 |
|
252 |
|
253 |
+
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
254 |
logger.info("Saving model and optimizer state at iteration {} to {}".format(
|
255 |
iteration, checkpoint_path))
|
256 |
if hasattr(model, 'module'):
|
|
|
261 |
'iteration': iteration,
|
262 |
'optimizer': optimizer.state_dict(),
|
263 |
'learning_rate': learning_rate}, checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
264 |
|
265 |
+
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
|
|
|
266 |
"""Freeing up space by deleting saved ckpts
|
267 |
|
268 |
Arguments:
|