EditGuard / data /__init__.py
Ricoooo's picture
'folder'
5d21dd2
raw
history blame
No virus
1.78 kB
'''create dataset and dataloader'''
import logging
import torch
import torch.utils.data
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to wrap.
        dataset_opt: Per-dataset options. ``'phase'`` is always read; in the
            ``'train'`` phase ``'n_workers'`` and ``'batch_size'`` are read too.
        opt: Global options. Required in the ``'train'`` phase, where
            ``opt['dist']`` selects distributed mode and, when not
            distributed, ``len(opt['gpu_ids'])`` scales the worker count.
        sampler: Optional sampler forwarded to the training DataLoader.
            Not used for non-training phases.

    Returns:
        A DataLoader. Training loaders drop the last incomplete batch and do
        not pin memory; every other phase uses batch size 1, no shuffling,
        a single worker and pinned memory.

    Raises:
        ValueError: if ``opt`` is missing in the training phase, or if (in
            distributed mode) ``dataset_opt['batch_size']`` is not divisible
            by the world size.
    """
    phase = dataset_opt['phase']
    if phase != 'train':
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)

    if opt is None:
        raise ValueError("create_dataloader: 'opt' is required for the 'train' phase")

    if opt['dist']:
        world_size = torch.distributed.get_world_size()
        num_workers = dataset_opt['n_workers']
        total_batch_size = dataset_opt['batch_size']
        # Explicit check instead of `assert`, so it is not stripped under `python -O`.
        if total_batch_size % world_size != 0:
            raise ValueError(
                'batch_size ({}) must be divisible by the world size ({})'.format(
                    total_batch_size, world_size))
        # The configured batch size is split evenly across the processes.
        batch_size = total_batch_size // world_size
        # No DataLoader-level shuffling in distributed mode; any shuffling
        # presumably comes from the sampler supplied by the caller.
        shuffle = False
    else:
        num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
        batch_size = dataset_opt['batch_size']
        # DataLoader rejects shuffle=True combined with an explicit sampler,
        # so only shuffle here when no sampler was given.
        shuffle = sampler is None
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=False)
def create_dataset(dataset_opt):
    """Instantiate the dataset class selected by ``dataset_opt['mode']``.

    Dataset modules are imported inside this function, so only the module
    for the requested mode is loaded.

    Args:
        dataset_opt: Dataset options. ``'mode'`` selects the class
            (``'test'``, ``'train'`` or ``'td'``), ``'name'`` is used in the
            creation log message, and the whole dict is passed to the
            dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported values.
    """
    logger = logging.getLogger('base')
    mode = dataset_opt['mode']
    if mode == 'test':
        from data.coco_test_dataset import imageTestDataset as D
    elif mode == 'train':
        from data.coco_dataset import CoCoDataset as D
    elif mode == 'td':
        from data.test_dataset_td import imageTestDataset as D
    else:
        # '{}' rather than '{:s}': a non-string mode (e.g. None) must still
        # surface as this error instead of a TypeError from str.format.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    # Debug-level log instead of printing to stdout.
    logger.debug('Dataset mode: %s', mode)
    dataset = D(dataset_opt)
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__,
                dataset_opt['name'])
    return dataset