import numpy as np
from tqdm import tqdm
from copy import deepcopy

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import data_collate
import data_loader
from utils_data import plot_tensor, save_plot
from model.utils import fix_len_compatibility
from text.symbols import symbols
import utils_data as utils


class ModelEmaV2(torch.nn.Module):
    """Keeps an exponential moving average (EMA) of a model's weights."""

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        self.model_state_dict = deepcopy(model.state_dict())
        self.decay = decay
        self.device = device  # perform EMA on a different device from the model if set

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.model_state_dict.values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.model_state_dict
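

# A minimal sketch of how the EMA weights can be used at inference time
# (assumes a model class with the same hyperparameters; the names here are
# illustrative, not part of this script):
#
#   eval_model = ModelClass(**hps.model)
#   eval_model.load_state_dict(ema_model.state_dict())
#   eval_model.eval()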


if __name__ == "__main__":
    hps = utils.get_hparams()
    logger_text = utils.get_logger(hps.model_dir)
    logger_text.info(hps)

    out_size = fix_len_compatibility(2 * hps.data.sampling_rate // hps.data.hop_length)  # NOTE: 2 sec of mel-spectrogram
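    # For example, with sampling_rate=22050 and hop_length=256 this is
    # 2 * 22050 // 256 = 172 frames; fix_len_compatibility then rounds the
    # length up so the decoder's downsampling stack divides it evenly.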
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(hps.train.seed)
    np.random.seed(hps.train.seed)

    print('Initializing logger...')
    log_dir = hps.model_dir
    logger = SummaryWriter(log_dir=log_dir)
    train_dataset, collate, model = utils.get_correct_class(hps)
    test_dataset, _, _ = utils.get_correct_class(hps, train=False)

    print('Initializing data loaders...')
    batch_collate = collate
    loader = DataLoader(dataset=train_dataset, batch_size=hps.train.batch_size,
                        collate_fn=batch_collate, drop_last=True,
                        num_workers=4, shuffle=False)  # NOTE: on a server, num_workers can be 4
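    # NOTE: shuffle=False keeps the batch order fixed across epochs; set
    # shuffle=True for the usual randomized ordering.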

    print('Initializing model...')
    model = model(**hps.model).to(device)
    print('Number of encoder + duration predictor parameters: %.2fm' % (model.encoder.nparams / 1e6))
    print('Number of decoder parameters: %.2fm' % (model.decoder.nparams / 1e6))
    print('Total parameters: %.2fm' % (model.nparams / 1e6))

    use_gt_dur = getattr(hps.train, "use_gt_dur", False)
    if use_gt_dur:
        print("++++++++++++++> Using ground truth duration for training")

    print('Initializing optimizer...')
    optimizer = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)

    print('Logging test batch...')
    test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
    for i, item in enumerate(test_batch):
        mel = item['mel']
        logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
                         global_step=0, dataformats='HWC')
        save_plot(mel.squeeze(), f'{log_dir}/original_{i}.png')
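    # Ground-truth mels are logged once (global_step=0) so generated samples
    # can be compared against them in TensorBoard.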

    try:
        model, optimizer, learning_rate, epoch_logged = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "grad_*.pt"), model, optimizer)
        epoch_start = epoch_logged + 1
        print(f"Loaded checkpoint from epoch {epoch_logged}, resuming training.")
        # Approximate the global step from the number of completed epochs.
        global_step = epoch_logged * (len(train_dataset) // hps.train.batch_size)
    except Exception:
        print("Cannot find trained checkpoint, beginning training from scratch")
        epoch_start = 1
        global_step = 0
        learning_rate = hps.train.learning_rate
    ema_model = ModelEmaV2(model, decay=0.9999)  # NOTE: must be created after the checkpoint load so the EMA starts from the restored weights

    print('Start training...')
    used_items = set()
    iteration = global_step
    for epoch in range(epoch_start, hps.train.n_epochs + 1):
        model.train()
        dur_losses = []
        prior_losses = []
        diff_losses = []

        with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
            for batch_idx, batch in enumerate(progress_bar):
                model.zero_grad()
                x, x_lengths = batch['text_padded'].to(device), \
                    batch['input_lengths'].to(device)
                y, y_lengths = batch['mel_padded'].to(device), \
                    batch['output_lengths'].to(device)
                if hps.xvector:
                    spk = batch['xvector'].to(device)
                else:
                    spk = batch['spk_ids'].to(torch.long).to(device)
                emo = batch['emo_ids'].to(torch.long).to(device)
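
                # Three Grad-TTS-style objectives: duration-predictor loss,
                # encoder/prior loss, and the diffusion (score-matching) loss.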
                dur_loss, prior_loss, diff_loss = model.compute_loss(
                    x, x_lengths,
                    y, y_lengths,
                    spk=spk,
                    emo=emo,
                    out_size=out_size,
                    use_gt_dur=use_gt_dur,
                    durs=batch['dur_padded'].to(device) if use_gt_dur else None)
                loss = dur_loss + prior_loss + diff_loss
                loss.backward()

                # Clip encoder and decoder gradients separately; the returned
                # norms are logged to TensorBoard below.
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
                                                               max_norm=1)
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
                                                               max_norm=1)
                optimizer.step()
                ema_model.update(model)  # refresh the EMA copy after every optimizer step

                logger.add_scalar('training/duration_loss', dur_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/prior_loss', prior_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/diffusion_loss', diff_loss.item(),
                                  global_step=iteration)
                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
                                  global_step=iteration)
                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
                                  global_step=iteration)

                dur_losses.append(dur_loss.item())
                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())

                if batch_idx % 5 == 0:
                    msg = (f'Epoch: {epoch}, iteration: {iteration} | '
                           f'dur_loss: {dur_loss.item():.3f}, prior_loss: {prior_loss.item():.3f}, '
                           f'diff_loss: {diff_loss.item():.3f}')
                    progress_bar.set_description(msg)
                iteration += 1

        log_msg = 'Epoch %d: duration loss = %.3f ' % (epoch, float(np.mean(dur_losses)))
        log_msg += '| prior loss = %.3f ' % np.mean(prior_losses)
        log_msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
        with open(f'{log_dir}/train.log', 'a') as f:
            f.write(log_msg)

        # Only synthesize samples and save checkpoints every `save_every` epochs.
        if epoch % hps.train.save_every > 0:
            continue

        model.eval()
        print('Synthesis...')
        with torch.no_grad():
            for i, item in enumerate(test_batch):
                if item['utt'] + "/truth" not in used_items:
                    used_items.add(item['utt'] + "/truth")
                x = item['text'].to(torch.long).unsqueeze(0).to(device)
                if not hps.xvector:
                    spk = item['spk_ids']
                    spk = torch.LongTensor([spk]).to(device)
                else:
                    spk = item["xvector"]
                    spk = spk.unsqueeze(0).to(device)
                emo = item['emo_ids']
                emo = torch.LongTensor([emo]).to(device)
                x_lengths = torch.LongTensor([x.shape[-1]]).to(device)
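
                # 10 reverse-diffusion steps is a fast preview setting; more
                # steps generally yield cleaner mel-spectrograms.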
                y_enc, y_dec, attn = model(x, x_lengths, spk=spk, emo=emo, n_timesteps=10)
                logger.add_image(f'image_{i}/generated_enc',
                                 plot_tensor(y_enc.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                logger.add_image(f'image_{i}/generated_dec',
                                 plot_tensor(y_dec.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                logger.add_image(f'image_{i}/alignment',
                                 plot_tensor(attn.squeeze().cpu()),
                                 global_step=iteration, dataformats='HWC')
                save_plot(y_enc.squeeze().cpu(),
                          f'{log_dir}/generated_enc_{i}.png')
                save_plot(y_dec.squeeze().cpu(),
                          f'{log_dir}/generated_dec_{i}.png')
                save_plot(attn.squeeze().cpu(),
                          f'{log_dir}/alignment_{i}.png')

        # Save both the EMA weights and the raw model weights.
        utils.save_checkpoint(ema_model, optimizer, learning_rate, epoch,
                              checkpoint_path=f"{log_dir}/EMA_grad_{epoch}.pt")
        utils.save_checkpoint(model, optimizer, learning_rate, epoch,
                              checkpoint_path=f"{log_dir}/grad_{epoch}.pt")
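        # NOTE: `EMA_grad_<epoch>.pt` stores the averaged weights (ModelEmaV2
        # overrides state_dict() to return them) and usually samples more
        # smoothly than the raw `grad_<epoch>.pt` checkpoint.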