# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import shutil
import time
from pathlib import Path

import torch
import torch.nn as nn
from tqdm import tqdm

from .base_trainer import BaseTrainer


def make_pad_mask(
    lengths: torch.Tensor, max_len: int = 0, left_pad=False
) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The minimum width of the mask; the actual width is
        ``max(max_len, lengths.max())``.
      left_pad:
        A boolean indicating whether to left pad the mask.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    n = lengths.size(0)
    if n == 0:
        # ``lengths.max()`` raises on an empty tensor; return an empty mask.
        return torch.zeros((0, max_len), dtype=torch.bool, device=lengths.device)
    max_len = max(max_len, int(lengths.max()))
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
    mask = expanded_lengths >= lengths.unsqueeze(-1)

    if left_pad:
        mask = mask.flip(dims=[1])

    return mask


class ValleARTrainer(BaseTrainer):
    """Trainer for the ValleAR model.

    Speech waveforms are tokenized on the fly with a frozen codec
    (SpeechTokenizer or EnCodec), and the model is trained to predict those
    codec tokens from phone ids. The model returns its own loss.
    """

    def __init__(self, args=None, cfg=None):
        """Initialise the base trainer and load the frozen codec.

        When ``cfg.use_speechtokenizer`` is set, a SpeechTokenizer checkpoint is
        loaded from ``./ckpts/speechtokenizer_hubert_avg``. Otherwise the 24 kHz
        EnCodec model is built with a target bandwidth of 6.0.
        """
        super().__init__(args, cfg)
        if self.cfg.use_speechtokenizer:
            from models.codec.speechtokenizer.model import SpeechTokenizer

            config_path = "./ckpts/speechtokenizer_hubert_avg/config.json"
            ckpt_path = "./ckpts/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
            assert os.path.isfile(
                config_path
            ), f"codec model {config_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
            assert os.path.isfile(
                ckpt_path
            ), f"codec model {ckpt_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
            self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
                config_path, ckpt_path
            )
            self.codec_encoder.eval()
            self.codec_encoder.to(self.accelerator.device)
            print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
        else:
            from encodec import EncodecModel

            # Let the main process build (and possibly download) the model
            # first so other processes do not race on the same files.
            with self.accelerator.main_process_first():
                self.codec_encoder = EncodecModel.encodec_model_24khz()
            self.codec_encoder.set_target_bandwidth(6.0)
            self.codec_encoder.to(self.accelerator.device)
            self.codec_decoder = None
            print("Loaded EncodecModel")

        self.top1_accuracies = []
        self.top5_accuracies = []
        self.top10_accuracies = []

        if hasattr(self.cfg, "flatten_first_2_layers"):
            self.flatten_first_2_layers = self.cfg.flatten_first_2_layers
            print("flattened:", self.flatten_first_2_layers)
        else:
            self.flatten_first_2_layers = False

        if hasattr(self.cfg, "num_prediction_heads"):
            self.num_prediction_heads = self.cfg.num_prediction_heads
            print("num_prediction_heads:", self.num_prediction_heads)

    def _accelerator_prepare(self):
        """Wrap the model and optimizer with the accelerator."""
        (
            self.model,
            self.optimizer,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
        )

    def _build_criterion(self):
        pass  # loss is directly returned from model

    def _build_scheduler(self):
        """Return a cosine LR schedule with linear warmup.

        Warmup and total step counts come from ``cfg.train.scheduler``.
        """
        from transformers import get_cosine_schedule_with_warmup

        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.cfg.train.scheduler.warmup_steps,
            num_training_steps=self.cfg.train.scheduler.total_steps,
        )

    def _build_model(self):
        """Instantiate ``ValleAR`` from ``cfg.model``.

        The multi-head variant (``valle_ar_multihead``) is used when
        ``cfg.model`` defines ``num_prediction_heads``.
        """
        if hasattr(self.cfg.model, "num_prediction_heads"):
            from .valle_ar_multihead import ValleAR
        else:
            from .valle_ar import ValleAR
        return ValleAR(**self.cfg.model)

    def _batch_to_device(self, batch):
        """Move every tensor value of ``batch`` to the accelerator device, in place."""
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)

    def _encode_speech(self, speech):
        """Tokenize a batch of waveforms with the frozen codec.

        Args:
            speech: waveform tensor; unsqueezed to add a channel axis at dim 1
                before encoding (presumably ``[B, T]`` -> ``[B, 1, T]``).

        Returns:
            Codec token ids. Per the original comments the layout is
            ``(n_q, B, T_frames)``. For EnCodec, the per-chunk code tensors
            returned by ``encode`` are concatenated along the last axis and
            axes 0/1 are swapped to reach the same layout.
        """
        if self.cfg.use_speechtokenizer:
            # Extract discrete codes from SpeechTokenizer
            return self.codec_encoder.encode(speech.unsqueeze(1))  # [B,1,T] -> (n_q, B, T)
        encoded_frames = self.codec_encoder.encode(speech.unsqueeze(1))
        return torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).transpose(
            0, 1
        )

    def _apply_target_layout(self, batch, vq_id):
        """Replace ``batch["speech"]`` / ``batch["speech_len"]`` with codec targets.

        Layouts, selected by the config:
          * ``flatten_first_2_layers``: tokens of codebooks 0 and 1 are
            interleaved frame by frame (two tokens per frame, hence // 160).
          * ``cfg.model.num_prediction_heads``: the first two codebooks,
            ``vq_id[:2]``.
          * otherwise: the first codebook only, ``vq_id[0]``.
        """
        if self.flatten_first_2_layers:
            first_layer = vq_id[0]
            second_layer = vq_id[1]
            # flatten the first two layers
            batch["speech"] = torch.stack(
                [first_layer, second_layer], dim=-1
            ).flatten(-2, -1)
            batch["speech_len"] = batch["speech_len"] // 160
        elif hasattr(self.cfg.model, "num_prediction_heads"):
            batch["speech"] = vq_id[:2]  # first two layers
            batch["speech_len"] = (
                batch["speech_len"] // 320
            )  # our codec downsamples 320x
        else:
            batch["speech"] = vq_id[0]  # use first layer
            batch["speech_len"] = (
                batch["speech_len"] // 320
            )  # our codec downsamples 320x

    def _train_step(self, batch):
        """Run one training step and return the model's loss.

        ``batch`` is expected to contain:
            speech: [B, T] raw waveform
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]

        ``batch["speech"]`` and ``batch["speech_len"]`` are overwritten in place
        with codec tokens and token lengths.
        """
        self._batch_to_device(batch)

        with torch.no_grad():
            vq_id = self._encode_speech(batch["speech"])
            self._apply_target_layout(batch, vq_id)

        # Time is the LAST axis for every layout. In the multi-head layout
        # ``batch["speech"]`` is ``vq_id[:2]`` (codebook axis first), so
        # ``size(1)`` would be the batch axis and give a mask of the wrong width.
        target_len = batch["speech"].shape[-1]
        assert batch["speech_len"].max() <= target_len

        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=target_len
        ).to(torch.long)

        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
        )
        return out.loss

    ##########add your own dataloader to the trainer#############
    def _build_dataloader(self):
        """Build the training dataloader.

        Returns:
            ``(train_loader, None)``; no validation loader is built.

        Raises:
            ValueError: if ``cfg.train.dataset.name`` is not ``"emilia"``,
                ``"mls"`` or ``"libritts"``.
        """
        from torch.utils.data import DataLoader

        dataset_name = self.cfg.train.dataset.name
        if dataset_name == "emilia":
            from .emilia_dataset import EmiliaDataset as VALLEDataset

            train_dataset = VALLEDataset()
        elif dataset_name == "mls":
            from .mls_dataset import VALLEDataset

            train_dataset = VALLEDataset(self.cfg.dataset, resample_to_24k=False)
        elif dataset_name == "libritts":
            from .libritts_dataset import VALLEDataset

            train_dataset = VALLEDataset(self.cfg.dataset)
        else:
            raise ValueError(
                f"Unknown dataset name {dataset_name!r}; "
                "expected one of 'emilia', 'mls', 'libritts'"
            )

        from .valle_collator import VALLECollator
        import numpy as np

        print("length of train_dataset:", len(train_dataset))

        collator = VALLECollator()
        # Use the GLOBAL rank: ``local_process_index`` restarts at 0 on every
        # node, so with a global ``num_processes`` it would give ranks on
        # different nodes identical shards. Both are equal on a single node.
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes
        num_workers = self.cfg.train.dataloader.num_worker

        if self.cfg.train.dataset.use_dynamic_batchsize:
            if self.accelerator.is_main_process:
                self.logger.info("Use Dynamic Batchsize......")
            from .mls_dataset import batch_by_size
            from models.base.base_sampler import VariableSampler

            batch_sampler = batch_by_size(
                train_dataset.num_frame_indices,
                train_dataset.get_num_frames,
                max_tokens=self.cfg.train.max_tokens * world_size,
                max_sentences=self.cfg.train.max_sentences * world_size,
                required_batch_size_multiple=world_size,
            )
            np.random.shuffle(batch_sampler)
            if len(batch_sampler) > 0:
                print(batch_sampler[0])
            # Each global batch is split across processes by striding; batches
            # whose size is not a multiple of world_size are dropped so every
            # process sees the same number of batches.
            batches = [
                x[rank::world_size]
                for x in batch_sampler
                if len(x) % world_size == 0
            ]

            loader_kwargs = {}
            if num_workers > 0:
                # DataLoader rejects a custom prefetch_factor when num_workers == 0.
                loader_kwargs["prefetch_factor"] = 4
            train_loader = DataLoader(
                train_dataset,
                collate_fn=collator,
                num_workers=num_workers,
                batch_sampler=VariableSampler(
                    batches, drop_last=True, use_random_sampler=True
                ),
                pin_memory=self.cfg.train.dataloader.pin_memory,
                persistent_workers=self.cfg.train.dataloader.persistent_workers,
                **loader_kwargs,
            )
            print(
                f"process {self.accelerator.local_process_index} has {len(batches)} batches"
            )
            self.accelerator.wait_for_everyone()
        else:
            sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
            )
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.cfg.train.batch_size,
                num_workers=num_workers,
                pin_memory=self.cfg.train.dataloader.pin_memory,
                collate_fn=collator,
                sampler=sampler,
            )
            print(
                f"process {self.accelerator.local_process_index} has {len(train_loader)} batches"
            )

        return train_loader, None

    def _test_step(self, batch):
        """Interactive debugging step on one batch.

        ``batch`` has the same keys as in ``_train_step``. The method writes
        ``gt.wav``, the codec decoding of ``vq_id[:1, :1]``. It also writes
        ``a.wav``, the decoding of tokens from ``self.model.sample_hf`` called on
        ``batch["phone_ids"][:1]`` and ``batch["speech"][:1, :225]``. It stops in
        ``breakpoint()`` before and after so the results can be inspected; set
        ``PYTHONBREAKPOINT=0`` to run it non-interactively.

        NOTE(review): ``decode(vq_id[:1, :1])`` and the 16000 Hz save rate look
        tailored to the SpeechTokenizer path. The original kept an EnCodec-style
        ``decode([(codes, None)])`` call commented out, and EnCodec here is the
        24 kHz model — confirm before using with EnCodec.
        """
        import torchaudio

        self._batch_to_device(batch)

        with torch.no_grad():
            vq_id = self._encode_speech(batch["speech"])
            self._apply_target_layout(batch, vq_id)

            # save gt
            breakpoint()
            recovered_audio = self.codec_encoder.decode(vq_id[:1, :1])
            torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)

            out_vq_ids = self.model.sample_hf(
                batch["phone_ids"][:1, ...], batch["speech"][:1, :225], temperature=0.9
            )

            # reconstruct from tokens
            recovered_audio = self.codec_encoder.decode(out_vq_ids.unsqueeze(0))
            torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
            breakpoint()
            print()

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Currently a stub: no validation is run and ``0.0`` is returned.
        """
        epoch_sum_loss = 0.0
        return epoch_sum_loss

    def _inference(self):
        pass

    def test_loop(self):
        """Put the model in eval mode and run ``_test_step`` over the training loader."""
        self.model.eval()
        for batch in self.train_dataloader:
            self._test_step(batch)