# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import shutil
import time
from pathlib import Path

import torch
import torch.nn as nn
from tqdm import tqdm

from .base_trainer import BaseTrainer


def make_pad_mask(
    lengths: torch.Tensor, max_len: int = 0, left_pad: bool = False
) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
      left_pad:
        A boolean indicating whether to left-pad the mask.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
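
    With ``left_pad=True`` the same mask is flipped along the time axis so the
    padding sits on the left (illustrative, output reproduced by hand):

    >>> make_pad_mask(lengths, left_pad=True)
    tensor([[ True,  True,  True,  True, False],
            [ True,  True, False, False, False],
            [ True,  True,  True, False, False],
            [False, False, False, False, False]])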
""" | |
assert lengths.ndim == 1, lengths.ndim | |
max_len = max(max_len, lengths.max()) | |
n = lengths.size(0) | |
seq_range = torch.arange(0, max_len, device=lengths.device) | |
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) | |
mask = expaned_lengths >= lengths.unsqueeze(-1) | |
if left_pad: | |
mask = mask.flip(dims=[1]) | |
return mask | |
class ValleARTrainer(BaseTrainer):
    def __init__(self, args=None, cfg=None):
        super().__init__(args, cfg)
        if self.cfg.use_speechtokenizer:
            from models.codec.speechtokenizer.model import SpeechTokenizer

            config_path = "./ckpts/speechtokenizer_hubert_avg/config.json"
            ckpt_path = "./ckpts/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
            assert os.path.isfile(
                config_path
            ), f"codec model {config_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
            assert os.path.isfile(
                ckpt_path
            ), f"codec model {ckpt_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
            self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
                config_path, ckpt_path
            )
            self.codec_encoder.eval()
            self.codec_encoder.to(self.accelerator.device)
            print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
        else:
            from encodec import EncodecModel

            with self.accelerator.main_process_first():
                self.codec_encoder = EncodecModel.encodec_model_24khz()
                self.codec_encoder.set_target_bandwidth(6.0)
                self.codec_encoder.to(self.accelerator.device)
            self.codec_decoder = None
            print("Loaded EncodecModel")

        self.top1_accuracies = []
        self.top5_accuracies = []
        self.top10_accuracies = []

        if hasattr(self.cfg, "flatten_first_2_layers"):
            self.flatten_first_2_layers = self.cfg.flatten_first_2_layers
            print("flattened:", self.flatten_first_2_layers)
        else:
            self.flatten_first_2_layers = False

        if hasattr(self.cfg, "num_prediction_heads"):
            self.num_prediction_heads = self.cfg.num_prediction_heads
            print("num_prediction_heads:", self.num_prediction_heads)

    def _accelerator_prepare(self):
        # if self.accelerator.is_main_process:
        #     breakpoint()
        # self.accelerator.wait_for_everyone()
        (
            self.model,
            self.optimizer,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
        )

    def _build_criterion(self):
        pass  # loss is directly returned from model

    def _build_scheduler(self):
        from transformers import (
            get_cosine_schedule_with_warmup,
            get_constant_schedule_with_warmup,
        )

        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.cfg.train.scheduler.warmup_steps,
            num_training_steps=self.cfg.train.scheduler.total_steps,
        )

    def _build_model(self):
        if hasattr(self.cfg.model, "num_prediction_heads"):
            from .valle_ar_multihead import ValleAR
        else:
            from .valle_ar import ValleAR

        return ValleAR(**self.cfg.model)

    def _train_step(self, batch):
        """Encode speech with the codec and compute the AR training loss.

        Expects batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)

        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
            else:
                # Extract discrete codes from EncodecModel
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # torchaudio.save('a.wav', recovered_audio[0], 16000)
            # vq_id: [8, B, T//320]

            if self.flatten_first_2_layers:
                first_layer = vq_id[0]
                second_layer = vq_id[1]
                # interleave the first two codebook layers along time
                batch["speech"] = torch.stack(
                    [first_layer, second_layer], dim=-1
                ).flatten(-2, -1)
                batch["speech_len"] = batch["speech_len"] // 160
            elif hasattr(self.cfg.model, "num_prediction_heads"):
                batch["speech"] = vq_id[:2]  # first two layers
                batch["speech_len"] = (
                    batch["speech_len"] // 320
                )  # our codec downsamples 320x
            else:
                batch["speech"] = vq_id[0]  # use first layer only
                batch["speech_len"] = (
                    batch["speech_len"] // 320
                )  # our codec downsamples 320x
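
        # Frame-rate arithmetic (illustrative, assuming 16 kHz input as used with
        # SpeechTokenizer): the codec emits one code frame per 320 samples, i.e.
        # 50 frames/s, so a 3 s utterance of 48000 samples maps to speech_len
        # 48000 // 320 = 150 frames; when the first two codebooks are interleaved
        # (flatten_first_2_layers), each frame contributes two tokens, hence the
        # divisor of 160 above.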
        assert batch["speech_len"].max() <= batch["speech"].shape[-1]

        # make_pad_mask gives True at padded positions; invert it so the model
        # sees 1 for valid tokens and 0 for padding.
        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=batch["speech"].size(1)
        ).to(torch.long)

        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
        )
        loss = out.loss
        # if self.accelerator.is_main_process:
        #     print(loss)
        # if hasattr(out, 'top1_acc'):
        #     self.top1_accuracies.append(out.top1_acc)
        #     self.top5_accuracies.append(out.top5_acc)
        #     self.top10_accuracies.append(out.top10_acc)
        #     print(f'avgs: top1: {sum(self.top1_accuracies)/len(self.top1_accuracies)}, top5: {sum(self.top5_accuracies)/len(self.top5_accuracies)}, top10: {sum(self.top10_accuracies)/len(self.top10_accuracies)}')
        # breakpoint()
        return loss

    ########## add your own dataloader to the trainer ##########
    def _build_dataloader(self):
        from torch.utils.data import ConcatDataset, DataLoader

        if self.cfg.train.dataset.name == "emilia":
            from .emilia_dataset import EmiliaDataset as VALLEDataset

            train_dataset = VALLEDataset()
        elif self.cfg.train.dataset.name == "mls":
            from .mls_dataset import VALLEDataset

            train_dataset = VALLEDataset(self.cfg.dataset, resample_to_24k=False)
        elif self.cfg.train.dataset.name == "libritts":
            from .libritts_dataset import VALLEDataset

            train_dataset = VALLEDataset(self.cfg.dataset)
        else:
            raise ValueError(f"Unknown dataset name: {self.cfg.train.dataset.name}")

        from .valle_collator import VALLECollator
        import numpy as np

        print("length of train_dataset:", len(train_dataset))

        collator = VALLECollator()

        if self.cfg.train.dataset.use_dynamic_batchsize:
            if self.accelerator.is_main_process:
                self.logger.info("Use Dynamic Batchsize......")
            from .mls_dataset import batch_by_size

            batch_sampler = batch_by_size(
                train_dataset.num_frame_indices,
                train_dataset.get_num_frames,
                max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
                max_sentences=self.cfg.train.max_sentences
                * self.accelerator.num_processes,
                required_batch_size_multiple=self.accelerator.num_processes,
            )
            np.random.shuffle(batch_sampler)
            print(batch_sampler[0])
            # Keep only batches divisible by the world size, then give each
            # process an interleaved slice of every remaining batch.
            batches = [
                x[
                    self.accelerator.local_process_index :: self.accelerator.num_processes
                ]
                for x in batch_sampler
                if len(x) % self.accelerator.num_processes == 0
            ]
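            # Sharding example (illustrative): with num_processes = 4 and a
            # dynamic batch of indices [0, 1, 2, 3, 4, 5, 6, 7], process 0 keeps
            # [0, 4], process 1 keeps [1, 5], process 2 keeps [2, 6], and
            # process 3 keeps [3, 7], so every rank draws the same number of
            # samples from each retained batch.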
            from models.base.base_sampler import VariableSampler

            train_loader = DataLoader(
                train_dataset,
                collate_fn=collator,
                num_workers=self.cfg.train.dataloader.num_worker,
                batch_sampler=VariableSampler(
                    batches, drop_last=True, use_random_sampler=True
                ),
                pin_memory=self.cfg.train.dataloader.pin_memory,
                persistent_workers=self.cfg.train.dataloader.persistent_workers,
                prefetch_factor=4,
            )
            print(
                f"process {self.accelerator.local_process_index} has {len(batches)} batches"
            )
            self.accelerator.wait_for_everyone()
        else:
            sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.local_process_index,
                shuffle=True,
            )
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.dataloader.num_worker,
                pin_memory=self.cfg.train.dataloader.pin_memory,
                collate_fn=collator,
                sampler=sampler,
            )
            print(
                f"process {self.accelerator.local_process_index} has {len(train_loader)} batches"
            )

        return train_loader, None

    def _test_step(self, batch):
        """Encode speech with the codec, then sample and decode audio for inspection.

        Expects batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]
        """
        import torchaudio

        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)

        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
            else:
                # Extract discrete codes from EncodecModel
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # torchaudio.save('a.wav', recovered_audio[0], 16000)
            # vq_id: [8, B, T//320]
            # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # recovered_audio.shape: torch.Size([1, 1, 50200])

            if self.flatten_first_2_layers:
                first_layer = vq_id[0]
                second_layer = vq_id[1]
                # interleave the first two codebook layers along time
                batch["speech"] = torch.stack(
                    [first_layer, second_layer], dim=-1
                ).flatten(-2, -1)
                batch["speech_len"] = batch["speech_len"] // 160
            elif hasattr(self.cfg.model, "num_prediction_heads"):
                batch["speech"] = vq_id[:2]  # first two layers
                batch["speech_len"] = (
                    batch["speech_len"] // 320
                )  # our codec downsamples 320x
            else:
                batch["speech"] = vq_id[0]  # use first layer only
                batch["speech_len"] = (
                    batch["speech_len"] // 320
                )  # our codec downsamples 320x
            # save gt; the breakpoint()s are intentional stops for interactive
            # listening and inspection of the generated files
            breakpoint()
            recovered_audio = self.codec_encoder.decode(vq_id[:1, :1])
            # recovered_audio = self.codec_encoder.decode([(vq_id[:1].transpose(0, 1), None)])
            torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)

            out_vq_ids = self.model.sample_hf(
                batch["phone_ids"][:1, ...], batch["speech"][:1, :225], temperature=0.9
            )
            # out_vq_ids = torch.cat([batch['speech'][:1, :225], out_vq_ids[:1, ...]], dim=1)

            # reconstruct from tokens
            recovered_audio = self.codec_encoder.decode(out_vq_ids.unsqueeze(0))
            # recovered_audio = self.codec_encoder.decode([(out_vq_ids, None)])
            torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
            breakpoint()

    def _valid_epoch(self):
        r"""Validation epoch. Should return the average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage. Currently a stub that
        returns 0.0.
        """
        epoch_sum_loss = 0.0
        return epoch_sum_loss

    def _inference(self):
        pass

    def test_loop(self):
        self.model.eval()
        for batch in self.train_dataloader:
            self._test_step(batch)
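

# Illustrative usage sketch (assumes a config object exposing the fields
# referenced above, e.g. cfg.use_speechtokenizer, cfg.model,
# cfg.train.dataset.name, cfg.train.scheduler.warmup_steps / total_steps):
#
#     trainer = ValleARTrainer(args=args, cfg=cfg)
#     trainer.train_loop()  # training loop inherited from BaseTrainer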