mazpie's picture
Initial commit
2d9a728
raw
history blame
4.7 kB
import torch
import torch.distributed as dist
from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process
import random
import logging
logger = logging.getLogger(__name__)
class MetaLoader(object):
    """Wraps multiple dataloaders and interleaves their batches.

    Builds a randomly shuffled schedule containing one slot per batch of
    each loader, then (under distributed training) broadcasts rank 0's
    schedule so that every process draws from the SAME dataloader at each
    step. Iteration ends once every loader has been drawn ``len(loader)``
    times; it ends early with StopIteration if a loader is shorter than
    its advertised length.

    Args:
        name2loader: Dict[str, DataLoader] mapping a name to a dataloader.
            Each loader must support ``len()`` and expose ``batch_size``.
    """

    def __init__(self, name2loader):
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {v: k for k, v in name2index.items()}

        # One schedule entry (the loader's index) per batch of each loader.
        iter_order = []
        for name, loader in name2loader.items():
            iter_order.extend([name2index[name]] * len(loader))
        random.shuffle(iter_order)

        # torch.tensor with an explicit dtype avoids the deprecated
        # torch.Tensor float round-trip; uint8 keeps the broadcast payload
        # small (supports up to 256 distinct loaders).
        iter_order = torch.tensor(iter_order, dtype=torch.uint8)

        if dist.is_available() and dist.is_initialized():
            # Sync: all ranks must follow rank 0's order so each step they
            # consume a batch from the same dataloader. The NCCL backend
            # requires the tensor to live on the GPU, so move it only here
            # — single-process runs no longer need CUDA at all.
            iter_order = iter_order.to(torch.device("cuda"))
            dist.broadcast(iter_order, src=0)

        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))

    def __str__(self):
        """Return a human-readable summary of the wrapped loaders."""
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
            )
        return "\n".join(output)

    def __len__(self):
        """Total number of batches across all wrapped loaders."""
        return len(self.iter_order)

    def __iter__(self):
        """Yield ``(name, batch)`` pairs following the shuffled, rank-synchronized schedule.

        Yields exactly ``len(self)`` batches (this iterator is finite).
        """
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch
class MetaLoader_rs(object):
    """Wraps multiple dataloaders with resume (skip) support.

    Same interleaving as ``MetaLoader`` — a shuffled, rank-synchronized
    schedule of loader draws — but the first ``skip_num`` global steps can
    be dropped when resuming training: each loader's sampler is told how
    many of ITS batches fall inside the skipped prefix via
    ``sampler.set_start_iter(...)``, and the schedule is truncated to the
    remaining steps.

    Args:
        name2loader: Dict[str, DataLoader] mapping a name to a dataloader.
            Each loader must expose ``batch_size`` and a ``sampler`` that
            implements ``set_start_iter(int)``.
        skip_num: Number of leading global steps to skip (0 = start fresh).
    """

    def __init__(self, name2loader, skip_num=0):
        self.name2loader = name2loader
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {v: k for k, v in name2index.items()}

        # One schedule entry (the loader's index) per batch of each loader.
        iter_order = []
        for name, loader in name2loader.items():
            iter_order.extend([name2index[name]] * len(loader))
        random.shuffle(iter_order)

        # Explicit-dtype construction avoids the deprecated torch.Tensor
        # float round-trip; uint8 supports up to 256 distinct loaders.
        iter_order = torch.tensor(iter_order, dtype=torch.uint8)

        if dist.is_available() and dist.is_initialized():
            # Sync: all ranks follow rank 0's schedule. NCCL broadcasts
            # require a CUDA tensor, so move to GPU only in this branch —
            # single-process runs no longer need CUDA.
            iter_order = iter_order.to(torch.device("cuda"))
            dist.broadcast(iter_order, src=0)

        if skip_num > 0:
            # Tell each sampler how many of its own batches were already
            # consumed (i.e. how often it appears in the skipped prefix),
            # then drop that prefix from the schedule.
            iter_order_skip = iter_order[:skip_num]
            for k, v in index2name.items():
                media_step = (iter_order_skip == k).sum().item()
                name2loader[v].sampler.set_start_iter(media_step)
                logger.info(f"{v} dataloder skip steps: {media_step}")
            iter_order = iter_order[skip_num:]
        else:
            logger.info("Do not skip steps for any dataloader!")
            for k, v in index2name.items():
                name2loader[v].sampler.set_start_iter(0)

        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        # Kept as a tensor so __str__ can count per-loader remaining steps.
        self.iter_idx = iter_order
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))

    def __str__(self):
        """Return a human-readable summary of the wrapped loaders (post-skip lengths)."""
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            # Count remaining scheduled draws for this loader, not len(loader),
            # since the skipped prefix has been removed from the schedule.
            length = (self.iter_idx == idx).sum()
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} "
            )
        return "\n".join(output)

    def __len__(self):
        """Number of remaining batches in the (possibly truncated) schedule."""
        return len(self.iter_order)

    def __iter__(self):
        """Yield ``(name, batch)`` pairs following the remaining schedule.

        Yields exactly ``len(self)`` batches (this iterator is finite).
        """
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch