Spaces:
Running
Running
File size: 2,302 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.nn.utils.rnn import pad_sequence
class VALLECollator:
def __init__(self, cfg=None):
self.cfg = cfg
def __call__(self, batch):
"""Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
speech: [B, T]
speech_len: [B]
phone_ids: [B, T]
phone_lens: [B]
"""
assert len(batch) != 0, "batch is empty before None checking"
batch = [b for b in batch if b is not None]
assert len(batch) != 0, "batch is empty after None checking"
packed_batch_features = {}
# Function to handle tensor copying
def process_tensor(data, dtype=torch.float32):
if isinstance(data, torch.Tensor):
return data.detach()
else:
return torch.tensor(data, dtype=dtype)
# Process 'speech' data
speeches = [process_tensor(b["speech"]) for b in batch]
packed_batch_features["speech_len"] = torch.tensor(
[len(s) for s in speeches], dtype=torch.long
)
packed_batch_features["speech"] = pad_sequence(
speeches, batch_first=True, padding_value=0
)
# right-padding 'phone' data
phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch]
packed_batch_features["phone_lens"] = torch.tensor(
[len(phone) for phone in phones], dtype=torch.long
)
packed_batch_features["phone_ids"] = pad_sequence(
phones, batch_first=True, padding_value=0
)
# # Process 'phone' data, with left padding
# phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch] # first reverse the whole sequence
# packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long)
# packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0) # do the right padding
# packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1) # flip back to original order (left padding)
return packed_batch_features
|