# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)


class ScheduledSampler(Sampler):
    """A sampler that samples data from a given concat-dataset.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
        batch_size (int): batch size
        holistic_shuffle (bool): whether to shuffle the whole dataset or not
        logger (logging.Logger): logger used to print warning messages
        loader_type (str): "train" or "valid"; for "valid", incomplete batches are kept

    Usage:
        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True,
        a ConcatDataset of [[0, 1, 2], [3, 4, 5], [6, 7, 8]] is iterated in per-dataset batches
        whose order is shuffled on every pass, e.g.:
        [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self,
        concat_dataset,
        batch_size,
        holistic_shuffle,
        logger=None,
        loader_type="train",
    ):
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle

        affected_dataset_name = []
        affected_dataset_len = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected_dataset_name.append(dataset_name)
                affected_dataset_len.append(dataset_len)

        self.type = loader_type
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if not loader_type == "valid":
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        loader_type, dataset_name, dataset_len, batch_size
                    )
                )

    def __len__(self):
        # number of full batches per dataset (i.e., with drop-last)
        num_of_batches = sum(
            [
                math.floor(len(dataset) / self.batch_size)
                for dataset in self.concat_dataset.datasets
            ]
        )
        # for validation, if there are not enough samples for even one full batch, keep them all
        if self.type == "valid" and num_of_batches < 1:
            return len(self.concat_dataset)
        return num_of_batches * self.batch_size

    def __iter__(self):
        iters = []
        for dataset in self.concat_dataset.datasets:
            iters.append(
                iter(RandomSampler(dataset))
                if self.holistic_shuffle
                else iter(SequentialSampler(dataset))
            )
        # global starting offset of each sub-dataset within the concat dataset, e.g. [0, 200, 400]
        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
        output_batches = []
        for dataset_idx in range(len(self.concat_dataset.datasets)):
            cur_batch = []
            for idx in iters[dataset_idx]:
                cur_batch.append(idx + init_indices[dataset_idx])
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # keep the incomplete last batch only for validation; in training it is effectively dropped
            if self.type == "valid" and len(cur_batch) > 0:
                output_batches.append(cur_batch)

        # shuffle the order of the per-dataset batches, then flatten into a single index stream
        random.shuffle(output_batches)
        output_indices = [item for sublist in output_batches for item in sublist]
        return iter(output_indices)
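

# A minimal, hypothetical usage sketch (not part of the original module): it shows how
# ScheduledSampler keeps indices grouped per sub-dataset while shuffling the order of
# the batches. The _DemoDataset class below is an assumption for illustration only; it
# implements just __len__, __getitem__, and the get_dataset_name() hook that
# ScheduledSampler expects from its datasets.
def _demo_scheduled_sampler():
    class _DemoDataset(Dataset):
        def __init__(self, name, items):
            self.name = name
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

        def get_dataset_name(self):
            return self.name

    concat = ConcatDataset(
        [_DemoDataset("a", [10, 11, 12]), _DemoDataset("b", [20, 21, 22])]
    )
    sampler = ScheduledSampler(concat, batch_size=3, holistic_shuffle=False)
    # Yields global indices in per-dataset batches, e.g. [3, 4, 5, 0, 1, 2]
    # or [0, 1, 2, 3, 4, 5], depending on the batch-order shuffle.
    return list(sampler)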


def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
    sampler = ScheduledSampler(
        concat_dataset,
        cfg.train.batch_size,
        cfg.train.sampler.holistic_shuffle,
        logger,
        loader_type,
    )
    batch_sampler = BatchSampler(
        sampler,
        cfg.train.batch_size,
        cfg.train.sampler.drop_last if loader_type != "valid" else False,
    )
    return sampler, batch_sampler
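

# A minimal, hypothetical sketch (not part of the original module) of wiring build_samplers
# into a DataLoader. The SimpleNamespace cfg below is an assumption that mirrors only the
# fields build_samplers reads (cfg.train.batch_size, cfg.train.sampler.holistic_shuffle,
# cfg.train.sampler.drop_last); real configs carry many more options.
def _demo_build_dataloader(concat_dataset):
    import logging
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    cfg = SimpleNamespace(
        train=SimpleNamespace(
            batch_size=3,
            sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
        )
    )
    logger = logging.getLogger(__name__)
    _, batch_sampler = build_samplers(concat_dataset, cfg, logger, "train")
    # The batch sampler yields lists of indices, so it is passed as batch_sampler,
    # not as sampler, and batch_size/shuffle must be left unset on the DataLoader.
    return DataLoader(concat_dataset, batch_sampler=batch_sampler)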


class VariableSampler(BatchSampler):
    """A batch sampler that yields pre-computed, variable-sized batches.

    Args:
        sampler: a sequence of batches, where each batch is a list of dataset indices
        drop_last (bool): passed through to the parent BatchSampler
        use_random_sampler (bool): whether the underlying index sampler is random or sequential
    """

    def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
        self.data_list = sampler
        if use_random_sampler:
            self.sampler = RandomSampler(sampler)
        else:
            self.sampler = SequentialSampler(sampler)

        super().__init__(self.sampler, 1, drop_last)

    def __iter__(self):
        for batch_ids in self.data_list:
            yield batch_ids

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
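

# A minimal, hypothetical demo (not part of the original module) of VariableSampler:
# the "sampler" argument is already a list of batches (each a list of dataset indices),
# for instance produced by an offline dynamic-batching step, and each element is handed
# to the DataLoader as one variable-sized batch.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    pre_batched_indices = [[0, 1], [2, 3, 4], [5]]  # variable batch sizes
    batch_sampler = VariableSampler(pre_batched_indices, drop_last=False)
    loader = DataLoader(list(range(10, 16)), batch_sampler=batch_sampler)
    for batch in loader:
        # Prints tensors of length 2, 3, and 1, in the pre-batched order.
        print(batch)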