Spaces:
Build error
Build error
import os | |
import torch | |
import torch.distributed as dist | |
from torch.utils.data import DistributedSampler | |
from tasks.base_task import BaseTask | |
from tasks.base_task import data_loader | |
from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler | |
from utils.hparams import hparams | |
class VocoderBaseTask(BaseTask):
    """Base task for vocoders: builds train/valid/test data loaders.

    Batch sizes are read from ``hparams`` and every split is served through
    a distributed-style sampler, even when ``torch.distributed`` has not
    been initialised (in which case a single replica with rank 0 is used).

    NOTE(review): this file imports ``data_loader`` from ``tasks.base_task``
    but the ``*_dataloader`` methods below are not decorated with it —
    confirm whether that is intentional.
    """

    def __init__(self):
        """Read batch-size settings from ``hparams`` and pick the dataset class."""
        super().__init__()
        self.max_sentences = hparams['max_sentences']
        valid_batch = hparams['max_valid_sentences']
        if valid_batch == -1:
            # -1 means "same as training": the resolved value is also written
            # back into the shared hparams before being stored on the task.
            valid_batch = self.max_sentences
            hparams['max_valid_sentences'] = valid_batch
        self.max_valid_sentences = valid_batch
        self.dataset_cls = VocoderDataset

    def train_dataloader(self):
        """Return the shuffled loader over the ``'train'`` split."""
        dataset = self.dataset_cls('train', shuffle=True)
        # hparams['endless_ds'] selects the endless sampler for training.
        return self.build_dataloader(
            dataset, True, self.max_sentences, hparams['endless_ds'])

    def val_dataloader(self):
        """Return the unshuffled loader over the ``'valid'`` split."""
        dataset = self.dataset_cls('valid', shuffle=False)
        return self.build_dataloader(dataset, False, self.max_valid_sentences)

    def test_dataloader(self):
        """Return the unshuffled loader over the ``'test'`` split."""
        dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a ``DataLoader`` driven by a distributed sampler.

        Args:
            dataset: Dataset to load from; must expose ``collater`` and
                ``num_workers`` attributes, which are forwarded to the loader.
            shuffle: Forwarded to the sampler (not to the ``DataLoader``).
            max_sentences: Used as the loader's ``batch_size``.
            endless: When truthy, ``EndlessDistributedSampler`` is used in
                place of ``DistributedSampler``.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
        """
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            replica_rank = dist.get_rank()
        else:
            # No process group: behave as a single replica.
            num_replicas = 1
            replica_rank = 0

        if endless:
            sampler_type = EndlessDistributedSampler
        else:
            sampler_type = DistributedSampler
        sampler = sampler_type(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=replica_rank,
            shuffle=shuffle,
        )

        # Ordering is delegated to the sampler, so the loader's own shuffle
        # flag stays off.
        loader_kwargs = dict(
            dataset=dataset,
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )
        return torch.utils.data.DataLoader(**loader_kwargs)

    def test_start(self):
        """Create the output directory for generated results and store its path.

        The directory lives under ``hparams['work_dir']`` and its name embeds
        the trainer's current ``global_step`` and ``hparams['gen_dir_name']``.
        """
        work_dir = hparams['work_dir']
        leaf = f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}'
        self.gen_dir = os.path.join(work_dir, leaf)
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_end(self, outputs):
        """Return an empty dict; ``outputs`` is ignored."""
        return dict()