# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""

import torch


class PretrainingSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of the input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data_parallel_size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        # Number of complete global batches in the dataset.
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def get_start_end_idx(self):
        # Each data-parallel rank reads its own contiguous micro-batch-sized
        # slice of the global batch, e.g. with micro_batch_size=4 rank 0
        # reads [0:4], rank 1 reads [4:8], and so on.
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # The last batch is dropped unless drop_last is set to False.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Yield this rank's slice of the last partial batch if drop_last
        # is not set.
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
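
# A minimal usage sketch, assuming a simple map-style dataset; the values
# below are hypothetical. The sampler yields per-rank lists of dataset
# indices, so it plugs into torch.utils.data.DataLoader through the
# `batch_sampler` argument:
#
#     dataset = list(range(1000))          # any map-style dataset works
#     sampler = PretrainingSampler(total_samples=len(dataset),
#                                  consumed_samples=0,
#                                  micro_batch_size=4,
#                                  data_parallel_rank=0,
#                                  data_parallel_size=2)
#     loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
#     next(iter(loader))                   # -> tensor([0, 1, 2, 3]) on rank 0
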
class PretrainingRandomSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, epoch):
        # Keep a copy of the input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size
        self.epoch = epoch

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data_parallel_size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def __iter__(self):
        # Exclude the trailing partial global batch, then work out how far
        # into the current epoch we already are (for mid-epoch resumption).
        active_total_samples = self.total_samples - self.last_batch_size
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each rank owns a contiguous
        # bucket of indices and shuffles within it. Seeding with the epoch
        # keeps the permutation identical across resumptions, so skipping
        # already-consumed samples (bucket_offset) is deterministic.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                      * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch is dropped if it is not complete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

    def set_epoch(self, epoch):
        self.epoch = epoch
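
# A small, runnable sketch of driving the random sampler across epochs; the
# sizes below are hypothetical. Calling set_epoch with a new value before
# each epoch reseeds the shuffle while keeping all ranks in agreement.
if __name__ == '__main__':
    sampler = PretrainingRandomSampler(total_samples=10,
                                       consumed_samples=0,
                                       micro_batch_size=2,
                                       data_parallel_rank=0,
                                       data_parallel_size=2,
                                       epoch=0)
    for epoch in range(2):
        sampler.set_epoch(epoch)
        for micro_batch in sampler:
            print('epoch', epoch, 'rank 0 indices:', micro_batch)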