# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.nn.utils.rnn import pad_sequence


class VALLECollator:
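    """Collator for VALL-E training that pads variable-length samples.

    Pads the 'speech' and 'phone' sequences in a batch to the batch maximum
    and records their original lengths.
    """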
    def __init__(self, cfg=None):
        self.cfg = cfg

    def __call__(self, batch):
        """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
        speech: [B, T]
        speech_len: [B]
        phone_ids: [B, T]
        phone_lens: [B]
        """
        assert len(batch) != 0, "batch is empty before filtering out None entries"
        batch = [b for b in batch if b is not None]
        assert len(batch) != 0, "batch is empty after filtering out None entries"
        packed_batch_features = {}

        # Convert raw data to a tensor; existing tensors are detached from
        # the autograd graph (their dtype is kept as-is)
        def process_tensor(data, dtype=torch.float32):
            if isinstance(data, torch.Tensor):
                return data.detach()
            else:
                return torch.tensor(data, dtype=dtype)

        # Right-pad 'speech' data and record the unpadded lengths
        speeches = [process_tensor(b["speech"]) for b in batch]
        packed_batch_features["speech_len"] = torch.tensor(
            [len(s) for s in speeches], dtype=torch.long
        )
        packed_batch_features["speech"] = pad_sequence(
            speeches, batch_first=True, padding_value=0
        )

        # Right-pad 'phone' data and record the unpadded lengths
        phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch]
        packed_batch_features["phone_lens"] = torch.tensor(
            [len(phone) for phone in phones], dtype=torch.long
        )
        packed_batch_features["phone_ids"] = pad_sequence(
            phones, batch_first=True, padding_value=0
        )

        # # Alternative: left-pad 'phone' data by flipping, right-padding, then flipping back
        # phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch]  # reverse each sequence
        # packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long)
        # packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0)  # right-pad the reversed sequences
        # packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1)  # flip back, yielding left padding

        return packed_batch_features
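

# A minimal usage sketch: assumes each dataset sample is a dict with a 1-D
# float 'speech' sequence and a 1-D integer 'phone' sequence; the dummy
# tensors below are illustrative only. In training, the collator is typically
# passed to a DataLoader, e.g.
# DataLoader(dataset, batch_size=8, collate_fn=VALLECollator(cfg)).
if __name__ == "__main__":
    collator = VALLECollator()
    dummy_batch = [
        {"speech": torch.randn(16000), "phone": torch.randint(0, 100, (12,))},
        {"speech": torch.randn(12000), "phone": torch.randint(0, 100, (9,))},
        None,  # None entries are filtered out by the collator
    ]
    features = collator(dummy_batch)
    print(features["speech"].shape)     # torch.Size([2, 16000])
    print(features["speech_len"])       # tensor([16000, 12000])
    print(features["phone_ids"].shape)  # torch.Size([2, 12])
    print(features["phone_lens"])       # tensor([12, 9])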