# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel

from optimizer.optimizers import Eve, ScaledAdam
from schedulers.scheduler import NoamScheduler, Eden

from models.tts.valle.valle_dataset import (
    VALLEDataset,
    VALLECollator,
    batch_by_size,
)
from models.base.base_sampler import VariableSampler
from models.tts.base import TTSTrainer
from models.tts.valle.valle import VALLE

import diffusers
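

# VALLETrainer wires the VALL-E model into the generic ``TTSTrainer`` loop.
# Training is split into stages selected by ``--train_stage`` (see
# ``add_arguments`` at the bottom of the class): 0 trains all modules, 1 the
# AR decoder, and 2 the NAR decoder. The loop itself lives in the inherited
# ``train_loop``; this class only supplies the model, data, optimizer,
# scheduler, and per-step hooks.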
class VALLETrainer(TTSTrainer):
    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)

    def _build_model(self):
        model = VALLE(self.cfg.model)
        return model

    def _build_dataset(self):
        return VALLEDataset, VALLECollator

    def _build_optimizer(self):
        # Unwrap DDP once so stage-specific parameter selection works in both
        # single- and multi-process runs (also avoids an undefined ``model``
        # when train_stage == 0 and ScaledAdam is used).
        if isinstance(self.model, DistributedDataParallel):
            model = self.model.module
        else:
            model = self.model

        if self.args.train_stage:
            model_parameters = model.stage_parameters(self.args.train_stage)
        else:
            model_parameters = model.parameters()

        if self.cfg.train.optimizer == "ScaledAdam":
            # ScaledAdam expects the parameter names alongside the parameters.
            parameters_names = []
            if self.args.train_stage != 0:
                parameters_names.append(
                    [
                        name_param_pair[0]
                        for name_param_pair in model.stage_named_parameters(
                            self.args.train_stage
                        )
                    ]
                )
            else:
                parameters_names.append(
                    [name_param_pair[0] for name_param_pair in model.named_parameters()]
                )
            optimizer = ScaledAdam(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=1000,
            )
        elif self.cfg.train.optimizer == "Eve":
            optimizer = Eve(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.98),
                target_rms=0.1,
            )
        elif self.cfg.train.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                weight_decay=1e-2,
                eps=1e-8,
            )
        elif self.cfg.train.optimizer == "Adam":
            optimizer = torch.optim.Adam(
                model_parameters,
                lr=self.cfg.train.base_lr,
                betas=(0.9, 0.95),
                eps=1e-8,
            )
        else:
            raise NotImplementedError()
        return optimizer
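
    # A minimal sketch of the config fields read above; the exact layout of the
    # project config is assumed here, and the values are illustrative only:
    #
    #   train:
    #     optimizer: "ScaledAdam"   # or "Eve", "AdamW", "Adam"
    #     base_lr: 0.05             # illustrative, not a recommendation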

    def _build_scheduler(self):
        if self.cfg.train.scheduler.lower() == "eden":
            scheduler = Eden(
                self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
            )
        elif self.cfg.train.scheduler.lower() == "noam":
            scheduler = NoamScheduler(
                self.cfg.train.base_lr,
                self.optimizer,
                self.cfg.model.decoder_dim,
                warmup_steps=self.cfg.train.warmup_steps,
            )
        elif self.cfg.train.scheduler.lower() == "cosine":
            from diffusers.optimization import get_cosine_schedule_with_warmup

            scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.cfg.train.warmup_steps
                * self.accelerator.num_processes,
                num_training_steps=self.cfg.train.total_training_steps
                * self.accelerator.num_processes,
            )
        else:
            raise NotImplementedError(f"{self.cfg.train.scheduler}")
        return scheduler
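
    # Note: Eden is advanced per batch through ``scheduler.step_batch(self.step)``
    # in ``_train_epoch`` below, while the Noam and cosine schedulers use the
    # plain ``scheduler.step()`` call.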

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses = self._train_step(batch)
                self.accelerator.backward(total_loss)
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.batch_count += 1

            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                for k in range(self.cfg.train.gradient_accumulation_step):
                    if isinstance(self.scheduler, Eden):
                        self.scheduler.step_batch(self.step)
                    else:
                        self.scheduler.step()

                epoch_sum_loss += total_loss.detach().cpu().item()

                if isinstance(train_losses, dict):
                    for key, value in train_losses.items():
                        if key not in epoch_losses.keys():
                            epoch_losses[key] = value
                        else:
                            epoch_losses[key] += value

                if isinstance(train_losses, dict):
                    for key, loss in train_losses.items():
                        self.accelerator.log(
                            {"Step/Train {}".format(key): "{:.6f}".format(loss)},
                            step=self.step,
                        )
                else:
                    # Fall back to the scalar loss when the model does not
                    # return a per-term loss dict.
                    self.accelerator.log(
                        {"Step/Train Loss": total_loss.detach().cpu().item()},
                        step=self.step,
                    )

                self.accelerator.log(
                    {"Step/lr": self.scheduler.get_last_lr()[0]},
                    step=self.step,
                )

                # print loss every log_epoch_step steps
                # if epoch_step % self.cfg.train.log_epoch_step == 0:
                #     for key, loss in train_losses.items():
                #         self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
                #         print("Step/Train {}: {:.6f}".format(key, loss))

                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_sum_loss, epoch_losses
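
    # ``_train_step`` normalizes the model loss by the total number of acoustic
    # frames in the batch, so utterances of different lengths contribute
    # comparably; each entry of the returned loss dict gets the same per-frame
    # normalization.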

    def _train_step(self, batch, is_training=True):
        text_tokens = batch["phone_seq"].to(self.device)
        text_tokens_lens = batch["phone_len"].to(self.device)
        assert text_tokens.ndim == 2

        audio_features = batch["acoustic_token"].to(self.device)
        audio_features_lens = batch["target_len"].to(self.device)
        assert audio_features.ndim == 3

        with torch.set_grad_enabled(is_training):
            loss, losses = self.model(
                x=text_tokens,
                x_lens=text_tokens_lens,
                y=audio_features,
                y_lens=audio_features_lens,
                train_stage=self.args.train_stage,
            )

        assert loss.requires_grad == is_training

        loss_dict = {}
        frames_sum = (audio_features_lens).sum()

        avg_loss = loss / frames_sum
        loss_dict["loss"] = avg_loss.detach().cpu().item()
        for l in losses:
            loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()

        return avg_loss, loss_dict

    def _valid_step(self, batch):
        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        total_loss, valid_losses = self._train_step(
            batch=batch,
            is_training=False,
        )
        assert total_loss.requires_grad is False

        total_loss = total_loss.detach().cpu().item()

        return total_loss, valid_losses, valid_stats

    def _build_dataloader(self):
        if not self.cfg.train.use_dynamic_batchsize:
            return super()._build_dataloader()

        if len(self.cfg.dataset) > 1:
            raise Exception("use_dynamic_batchsize only supports a single dataset now.")

        Dataset, Collator = self._build_dataset()
        train_dataset = Dataset(
            self.cfg, self.cfg.dataset[0], is_valid=False
        )  # TODO: support use_dynamic_batchsize for more than one dataset.
        train_collate = Collator(self.cfg)
        batch_sampler = batch_by_size(
            train_dataset.num_frame_indices,
            train_dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
            max_sentences=self.cfg.train.max_sentences
            * self.accelerator.num_processes,
            required_batch_size_multiple=self.accelerator.num_processes,
        )
        np.random.seed(1234)
        np.random.shuffle(batch_sampler)
        print(batch_sampler[:1])
        batches = [
            x[self.accelerator.local_process_index :: self.accelerator.num_processes]
            for x in batch_sampler
            if len(x) % self.accelerator.num_processes == 0
        ]
        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.cfg.train.dataloader.num_worker,
            batch_sampler=VariableSampler(
                batches, drop_last=False, use_random_sampler=True
            ),
            pin_memory=False,
        )
        self.accelerator.wait_for_everyone()

        valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
        valid_collate = Collator(self.cfg)
        batch_sampler = batch_by_size(
            valid_dataset.num_frame_indices,
            valid_dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
            max_sentences=self.cfg.train.max_sentences
            * self.accelerator.num_processes,
            required_batch_size_multiple=self.accelerator.num_processes,
        )
        batches = [
            x[self.accelerator.local_process_index :: self.accelerator.num_processes]
            for x in batch_sampler
            if len(x) % self.accelerator.num_processes == 0
        ]
        valid_loader = DataLoader(
            valid_dataset,
            collate_fn=valid_collate,
            num_workers=self.cfg.train.dataloader.num_worker,
            batch_sampler=VariableSampler(batches, drop_last=False),
            pin_memory=False,
        )
        self.accelerator.wait_for_everyone()

        return train_loader, valid_loader
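
    # With ``use_dynamic_batchsize``, ``batch_by_size`` packs utterance indices
    # up to ``max_tokens``/``max_sentences`` (scaled by the number of
    # processes), and each process then takes an interleaved slice of every
    # batch. These dataloaders are therefore not passed through
    # ``accelerator.prepare`` (see ``_accelerator_prepare`` below).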

    def _accelerator_prepare(self):
        if not self.cfg.train.use_dynamic_batchsize:
            (
                self.train_dataloader,
                self.valid_dataloader,
            ) = self.accelerator.prepare(
                self.train_dataloader,
                self.valid_dataloader,
            )

        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--train_stage",
            type=int,
            default=1,
            help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
        )
        parser.add_argument(
            "--ar_model_ckpt_dir",
            type=str,
            default=None,
            help="Checkpoint directory of the AR model from the first training stage.",
        )
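

# A minimal sketch (not the project's actual entry point) of how the arguments
# above could be wired into a launcher script; ``cfg`` is assumed to be loaded
# by the project's own config machinery:
#
#   parser = argparse.ArgumentParser()
#   VALLETrainer.add_arguments(parser)
#   args = parser.parse_args()          # e.g. --train_stage 1
#   trainer = VALLETrainer(args, cfg)   # cfg: project config object (assumed)
#   trainer.train_loop()                # inherited from TTSTrainer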