AdalAbilbekov committed
Commit ed2380c
1 Parent(s): ae8e1dd
data_preparation.py DELETED
@@ -1,108 +0,0 @@
- import kaldiio
- import os
- import librosa
- from tqdm import tqdm
- import glob
- import json
- from shutil import copyfile
- import pandas as pd
- import argparse
- from text import _clean_text, symbols
- from num2words import num2words
- import re
- from melspec import mel_spectrogram
- import torchaudio
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser()
-     parser.add_argument('-d', '--data', type=str, required=True, help='path to the emotional dataset')
-     args = parser.parse_args()
-     dataset_path = args.data
-     filelists_path = 'filelists/all_spks/'
-     feats_scp_file = filelists_path + 'feats.scp'
-     feats_ark_file = filelists_path + 'feats.ark'
-
-     # Collect train/eval wav lists for the three speakers
-     spks = ['1263201035', '805570882', '399172782']
-     train_files = []
-     eval_files = []
-     for spk in spks:
-         train_files += glob.glob(dataset_path + spk + "/train/*.wav")
-         eval_files += glob.glob(dataset_path + spk + "/eval/*.wav")
-
-     os.makedirs(filelists_path, exist_ok=True)
-
-     with open(filelists_path + 'train_utts.txt', 'w', encoding='utf-8') as f:
-         for wav_path in train_files:
-             wav_name = os.path.splitext(os.path.basename(wav_path))[0]
-             f.write(wav_name + '\n')
-     with open(filelists_path + 'eval_utts.txt', 'w', encoding='utf-8') as f:
-         for wav_path in eval_files:
-             wav_name = os.path.splitext(os.path.basename(wav_path))[0]
-             f.write(wav_name + '\n')
-
-     # Extract mel features; WriteHelper creates both feats.ark and feats.scp
-     # itself, so the scp file does not need to be opened separately.
-     with kaldiio.WriteHelper(f'ark,scp:{feats_ark_file},{feats_scp_file}') as writer:
-         for root, dirs, files in os.walk(dataset_path):
-             for file in tqdm(files):
-                 if file.endswith('.wav'):
-                     # Get the file name and relative path to the root folder
-                     wav_path = os.path.join(root, file)
-                     rel_path = os.path.relpath(wav_path, dataset_path)
-                     wav_name = os.path.splitext(os.path.basename(wav_path))[0]
-                     signal, rate = torchaudio.load(wav_path)
-                     spec = mel_spectrogram(signal, 1024, 80, 22050, 256,
-                                            1024, 0, 8000, center=False).squeeze()
-                     # Write the features to feats.ark and feats.scp
-                     writer[wav_name] = spec
-
-     # Map each utterance to integer speaker and emotion labels
-     emotions = [os.path.basename(x).split("_")[1] for x in glob.glob(dataset_path + '/**/**/*')]
-     emotions = sorted(set(emotions))
-
-     utt2spk = {}
-     utt2emo = {}
-     wavs = glob.glob(dataset_path + '**/**/*.wav')
-     for wav_path in tqdm(wavs):
-         wav_name = os.path.splitext(os.path.basename(wav_path))[0]
-         emotion = emotions.index(wav_name.split("_")[1])
-         if wav_path.split('/')[-3] == '1263201035':
-             spk = 0  # labels should start with 0
-         elif wav_path.split('/')[-3] == '805570882':
-             spk = 1
-         else:
-             spk = 2
-         utt2spk[wav_name] = str(spk)
-         utt2emo[wav_name] = str(emotion)
-     utt2spk = dict(sorted(utt2spk.items()))
-     utt2emo = dict(sorted(utt2emo.items()))
-
-     with open(filelists_path + 'utt2emo.json', 'w') as fp:
-         json.dump(utt2emo, fp, indent=4)
-     with open(filelists_path + 'utt2spk.json', 'w') as fp:
-         json.dump(utt2spk, fp, indent=4)
-
-     # Clean the transcripts: normalize text, expand digits to Kazakh words,
-     # and keep only characters present in the symbol set
-     txt_files = sorted(glob.glob(dataset_path + '/**/**/*.txt'))
-     count = 0
-     txt = []
-     basenames = []
-     utt2text = {}
-     flag = False
-     for txt_path in txt_files:
-         basename = os.path.basename(txt_path).replace('.txt', '')
-         with open(txt_path, 'r', encoding='utf-8') as f:
-             txt.append(_clean_text(f.read().strip("\n"), cleaner_names=["kazakh_cleaners"]).replace("'", ""))
-         basenames.append(basename)
-     output_string = [re.sub(r'(\d+)', lambda m: num2words(m.group(), lang='kz'), sentence) for sentence in txt]
-     cleaned_txt = []
-     for t in output_string:
-         cleaned_txt.append(''.join([s for s in t if s in symbols]))
-     utt2text = {basenames[i]: cleaned_txt[i] for i in range(len(cleaned_txt))}
-     utt2text = dict(sorted(utt2text.items()))
-
-     # Write the final 'utt text' file once, after cleaning
-     vocab = set()
-     with open(filelists_path + 'text', 'w', encoding='utf-8') as f:
-         for x, y in utt2text.items():
-             for c in y:
-                 vocab.add(c)
-             f.write(x + ' ' + y + '\n')
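The script above leaves behind feats.scp / feats.ark (Kaldi-style mel features), utt2spk.json and utt2emo.json (integer labels stored as strings), per-split utterance lists, and a cleaned text file. As a rough sketch of how a downstream loader might read these artifacts back (illustrative consumer code, not part of this commit):

import json
import kaldiio

filelists_path = 'filelists/all_spks/'

# Lazy dict-like view over feats.scp; arrays are read from feats.ark on access.
feats = kaldiio.load_scp(filelists_path + 'feats.scp')
with open(filelists_path + 'utt2spk.json') as f:
    utt2spk = json.load(f)
with open(filelists_path + 'utt2emo.json') as f:
    utt2emo = json.load(f)

with open(filelists_path + 'train_utts.txt', encoding='utf-8') as f:
    for utt in (line.strip() for line in f):
        mel = feats[utt]  # (80, T) numpy array written by data_preparation.py
        spk, emo = int(utt2spk[utt]), int(utt2emo[utt])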
inference_intensity_control.ipynb DELETED
File without changes
melspec.py DELETED
@@ -1,40 +0,0 @@
- import torch
- import torchaudio
- import librosa
-
- mel_basis = {}
- hann_window = {}
-
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-     return torch.log(torch.clamp(x, min=clip_val) * C)
-
- def spectral_normalize_torch(magnitudes):
-     output = dynamic_range_compression_torch(magnitudes)
-     return output
-
- def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
-     if torch.min(y) < -1.:
-         print('min value is ', torch.min(y))
-     if torch.max(y) > 1.:
-         print('max value is ', torch.max(y))
-
-     global mel_basis, hann_window
-     # Check the cache with the same key it is written with; testing the bare
-     # fmax value never matches, so the filterbank would be rebuilt on every call.
-     if str(fmax) + '_' + str(y.device) not in mel_basis:
-         mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
-         mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode='reflect')
-     y = y.squeeze(1)
-
-     # return_complex=False keeps the (real, imag) layout the power sum below expects
-     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
-                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
-
-     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
-
-     spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
-     spec = spectral_normalize_torch(spec)
-
-     return spec.numpy()
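For reference, data_preparation.py calls this function with n_fft=1024, 80 mel bins, a 22050 Hz sampling rate, hop 256, window 1024, and fmin/fmax of 0/8000 Hz. A minimal sanity-check sketch ('sample.wav' is a hypothetical 22.05 kHz mono file):

import torchaudio
from melspec import mel_spectrogram

signal, rate = torchaudio.load('sample.wav')  # signal shape: (1, num_samples)
assert rate == 22050, 'the settings below assume 22.05 kHz audio'
spec = mel_spectrogram(signal, 1024, 80, 22050, 256, 1024, 0, 8000, center=False)
print(spec.shape)  # (1, 80, T): batch, mel bins, frames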
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ scipy
+ torch
+ kaldiio
+ tqdm
+ einops
+ matplotlib
+ glob
+ pydub
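The list installs with pip install -r requirements.txt; note that glob is a Python standard-library module rather than a PyPI distribution, so that entry will likely need to be removed for the install to succeed.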
train_EMA.py DELETED
@@ -1,207 +0,0 @@
- import numpy as np
- from tqdm import tqdm
-
- from copy import deepcopy
- import torch
- from torch.utils.data import DataLoader
- from torch.utils.tensorboard import SummaryWriter
- import data_collate
- import data_loader
- from utils_data import plot_tensor, save_plot
- from model.utils import fix_len_compatibility
- from text.symbols import symbols
- import utils_data as utils
-
-
- class ModelEmaV2(torch.nn.Module):
-     def __init__(self, model, decay=0.9999, device=None):
-         super(ModelEmaV2, self).__init__()
-         self.model_state_dict = deepcopy(model.state_dict())
-         self.decay = decay
-         self.device = device  # perform EMA on a different device from the model if set
-
-     def _update(self, model, update_fn):
-         with torch.no_grad():
-             for ema_v, model_v in zip(self.model_state_dict.values(), model.state_dict().values()):
-                 if self.device is not None:
-                     model_v = model_v.to(device=self.device)
-                 ema_v.copy_(update_fn(ema_v, model_v))
-
-     def update(self, model):
-         self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
-
-     def set(self, model):
-         self._update(model, update_fn=lambda e, m: m)
-
-     def state_dict(self, destination=None, prefix='', keep_vars=False):
-         return self.model_state_dict
-
-
- if __name__ == "__main__":
-     hps = utils.get_hparams()
-     logger_text = utils.get_logger(hps.model_dir)
-     logger_text.info(hps)
-
-     out_size = fix_len_compatibility(2 * hps.data.sampling_rate // hps.data.hop_length)  # NOTE: 2 sec of mel-spec
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     torch.manual_seed(hps.train.seed)
-     np.random.seed(hps.train.seed)
-
-     print('Initializing logger...')
-     log_dir = hps.model_dir
-     logger = SummaryWriter(log_dir=log_dir)
-
-     train_dataset, collate, model = utils.get_correct_class(hps)
-     test_dataset, _, _ = utils.get_correct_class(hps, train=False)
-
-     print('Initializing data loaders...')
-     batch_collate = collate
-     loader = DataLoader(dataset=train_dataset, batch_size=hps.train.batch_size,
-                         collate_fn=batch_collate, drop_last=True,
-                         num_workers=4, shuffle=False)  # NOTE: if on a server, workers can be 4
-
-     print('Initializing model...')
-     model = model(**hps.model).to(device)
-     print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams / 1e6))
-     print('Number of decoder parameters: %.2fm' % (model.decoder.nparams / 1e6))
-     print('Total parameters: %.2fm' % (model.nparams / 1e6))
-
-     use_gt_dur = getattr(hps.train, "use_gt_dur", False)
-     if use_gt_dur:
-         print("++++++++++++++> Using ground truth duration for training")
-
-     print('Initializing optimizer...')
-     optimizer = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)
-
-     print('Logging test batch...')
-     test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
-     for i, item in enumerate(test_batch):
-         mel = item['mel']
-         logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
-                          global_step=0, dataformats='HWC')
-         save_plot(mel.squeeze(), f'{log_dir}/original_{i}.png')
-
-     try:
-         model, optimizer, learning_rate, epoch_logged = utils.load_checkpoint(
-             utils.latest_checkpoint_path(hps.model_dir, "grad_*.pt"), model, optimizer)
-         epoch_start = epoch_logged + 1
-         print(f"Loaded checkpoint from {epoch_logged} epoch, resuming training.")
-         global_step = epoch_logged * (len(train_dataset) / hps.train.batch_size)
-     except:
-         print("Cannot find trained checkpoint, beginning training from scratch")
-         epoch_start = 1
-         global_step = 0
-         learning_rate = hps.train.learning_rate
-
-     ema_model = ModelEmaV2(model, decay=0.9999)  # must be created after loading the model
-
-     print('Start training...')
-     used_items = set()
-     iteration = global_step
-     for epoch in range(epoch_start, hps.train.n_epochs + 1):
-         model.train()
-         dur_losses = []
-         prior_losses = []
-         diff_losses = []
-         with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
-             for batch_idx, batch in enumerate(progress_bar):
-                 model.zero_grad()
-                 x, x_lengths = batch['text_padded'].to(device), \
-                     batch['input_lengths'].to(device)
-                 y, y_lengths = batch['mel_padded'].to(device), \
-                     batch['output_lengths'].to(device)
-                 if hps.xvector:
-                     spk = batch['xvector'].to(device)
-                 else:
-                     spk = batch['spk_ids'].to(torch.long).to(device)
-                 emo = batch['emo_ids'].to(torch.long).to(device)
-
-                 dur_loss, prior_loss, diff_loss = model.compute_loss(x, x_lengths,
-                                                                      y, y_lengths,
-                                                                      spk=spk,
-                                                                      emo=emo,
-                                                                      out_size=out_size,
-                                                                      use_gt_dur=use_gt_dur,
-                                                                      durs=batch['dur_padded'].to(device) if use_gt_dur else None)
-                 loss = sum([dur_loss, prior_loss, diff_loss])
-                 loss.backward()
-
-                 enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
-                                                                max_norm=1)
-                 dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
-                                                                max_norm=1)
-                 optimizer.step()
-                 ema_model.update(model)
-
-                 logger.add_scalar('training/duration_loss', dur_loss.item(),
-                                   global_step=iteration)
-                 logger.add_scalar('training/prior_loss', prior_loss.item(),
-                                   global_step=iteration)
-                 logger.add_scalar('training/diffusion_loss', diff_loss.item(),
-                                   global_step=iteration)
-                 logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
-                                   global_step=iteration)
-                 logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
-                                   global_step=iteration)
-
-                 dur_losses.append(dur_loss.item())
-                 prior_losses.append(prior_loss.item())
-                 diff_losses.append(diff_loss.item())
-
-                 if batch_idx % 5 == 0:
-                     msg = f'Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, diff_loss: {diff_loss.item()}'
-                     progress_bar.set_description(msg)
-
-                 iteration += 1
-
-         log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, float(np.mean(dur_losses)))
-         log_msg += '| prior loss = %.3f ' % np.mean(prior_losses)
-         log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
-         with open(f'{log_dir}/train.log', 'a') as f:
-             f.write(log_msg)
-
-         if epoch % hps.train.save_every > 0:
-             continue
-
-         model.eval()
-         print('Synthesis...')
-
-         with torch.no_grad():
-             for i, item in enumerate(test_batch):
-                 if item['utt'] + "/truth" not in used_items:
-                     used_items.add(item['utt'] + "/truth")
-                 x = item['text'].to(torch.long).unsqueeze(0).to(device)
-                 if not hps.xvector:
-                     spk = item['spk_ids']
-                     spk = torch.LongTensor([spk]).to(device)
-                 else:
-                     spk = item["xvector"]
-                     spk = spk.unsqueeze(0).to(device)
-                 emo = item['emo_ids']
-                 emo = torch.LongTensor([emo]).to(device)
-
-                 x_lengths = torch.LongTensor([x.shape[-1]]).to(device)
-
-                 y_enc, y_dec, attn = model(x, x_lengths, spk=spk, emo=emo, n_timesteps=10)
-                 logger.add_image(f'image_{i}/generated_enc',
-                                  plot_tensor(y_enc.squeeze().cpu()),
-                                  global_step=iteration, dataformats='HWC')
-                 logger.add_image(f'image_{i}/generated_dec',
-                                  plot_tensor(y_dec.squeeze().cpu()),
-                                  global_step=iteration, dataformats='HWC')
-                 logger.add_image(f'image_{i}/alignment',
-                                  plot_tensor(attn.squeeze().cpu()),
-                                  global_step=iteration, dataformats='HWC')
-                 save_plot(y_enc.squeeze().cpu(),
-                           f'{log_dir}/generated_enc_{i}.png')
-                 save_plot(y_dec.squeeze().cpu(),
-                           f'{log_dir}/generated_dec_{i}.png')
-                 save_plot(attn.squeeze().cpu(),
-                           f'{log_dir}/alignment_{i}.png')
-
-         ckpt = model.state_dict()
-
-         utils.save_checkpoint(ema_model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/EMA_grad_{epoch}.pt")
-         utils.save_checkpoint(model, optimizer, learning_rate, epoch, checkpoint_path=f"{log_dir}/grad_{epoch}.pt")
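ModelEmaV2 above keeps an exponential moving average of the model weights, ema = decay * ema + (1 - decay) * weight, refreshed after every optimizer step; the smoothed weights are what the EMA_grad_{epoch}.pt checkpoints store. A minimal usage sketch with a toy stand-in model (illustrative only, not part of this commit):

import torch

model = torch.nn.Linear(4, 4)                # toy stand-in for the TTS model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ema_model = ModelEmaV2(model, decay=0.9999)  # create after any checkpoint load

for step in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.update(model)                  # EMA tracks the freshly updated weights

# For evaluation, load the smoothed weights back into a model instance.
model.load_state_dict(ema_model.state_dict())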