# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.nn.utils.rnn import pad_sequence


class VALLECollator:
    def __init__(self, cfg=None):
        self.cfg = cfg

    def __call__(self, batch):
        """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
        speech: [B, T]
        speech_len: [B]
        phone_ids: [B, T]
        phone_lens: [B]
        """
        assert len(batch) != 0, "batch is empty before None checking"
        batch = [b for b in batch if b is not None]
        assert len(batch) != 0, "batch is empty after None checking"
        packed_batch_features = {}

        # Function to handle tensor copying
        def process_tensor(data, dtype=torch.float32):
            if isinstance(data, torch.Tensor):
                return data.detach()
            else:
                return torch.tensor(data, dtype=dtype)

        # Process 'speech' data
        speeches = [process_tensor(b["speech"]) for b in batch]
        packed_batch_features["speech_len"] = torch.tensor(
            [len(s) for s in speeches], dtype=torch.long
        )
        packed_batch_features["speech"] = pad_sequence(
            speeches, batch_first=True, padding_value=0
        )

        # Right-padding 'phone' data
        phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch]
        packed_batch_features["phone_lens"] = torch.tensor(
            [len(phone) for phone in phones], dtype=torch.long
        )
        packed_batch_features["phone_ids"] = pad_sequence(
            phones, batch_first=True, padding_value=0
        )

        # # Process 'phone' data, with left padding
        # phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch]  # first reverse the whole sequence
        # packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long)
        # packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0)  # do the right padding
        # packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1)  # flip back to original order (left padding)

        return packed_batch_features
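

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal demonstration of the collator's padding behavior. The dummy batch
# below is hypothetical: it assumes each dataset item is a dict with a 1-D
# "speech" sequence and a 1-D "phone" id sequence, matching the keys read in
# __call__ above. In practice the collator would be passed to a DataLoader
# via its `collate_fn` argument.
if __name__ == "__main__":
    collator = VALLECollator()
    fake_batch = [
        {"speech": torch.randn(16000), "phone": [3, 7, 12, 5]},
        {"speech": torch.randn(24000), "phone": [9, 2, 4]},
        None,  # None entries are filtered out before collating
    ]
    features = collator(fake_batch)
    print(features["speech"].shape)     # torch.Size([2, 24000]) after right-padding
    print(features["speech_len"])       # tensor([16000, 24000])
    print(features["phone_ids"].shape)  # torch.Size([2, 4])
    print(features["phone_lens"])       # tensor([4, 3])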