import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler

from utils import *

from .atlas import Atlas
from .brat import Brat
from .ddti import DDTI
from .isic import ISIC2016
from .kits import KITS
from .lidc import LIDC
from .lnq import LNQ
from .pendal import Pendal
from .refuge import REFUGE
from .segrap import SegRap
from .stare import STARE
from .toothfairy import ToothFairy
from .wbc import WBC

# Worker count shared by every DataLoader built in this module.
_NUM_WORKERS = 8

# Datasets whose classes provide their own split through ``mode='Training'`` / ``mode='Test'``.
_FIXED_SPLIT_DATASETS = {
    'isic': ISIC2016,
    'REFUGE': REFUGE,
    'DDTI': DDTI,
}

# Datasets built as a single pool and split randomly in this module:
# name -> (dataset class, fraction of samples held out for testing).
# 'LIDC' is resolved separately inside get_dataloader (see the note there).
_RANDOM_SPLIT_DATASETS = {
    'Brat': (Brat, 0.3),
    'STARE': (STARE, 0.2),
    'kits': (KITS, 0.3),
    'WBC': (WBC, 0.3),
    'segrap': (SegRap, 0.3),
    'toothfairy': (ToothFairy, 0.3),
    'atlas': (Atlas, 0.3),
    'pendal': (Pendal, 0.3),
    'lnq': (LNQ, 0.3),
}


def _resize_to_tensor(size):
    """Return a transform that resizes to ``(size, size)`` and then applies ``ToTensor``."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])


def _make_loader(dataset, batch_size, **loader_kwargs):
    """Wrap *dataset* in a DataLoader with the worker / pin-memory settings used for every dataset."""
    return DataLoader(dataset, batch_size=batch_size, num_workers=_NUM_WORKERS,
                      pin_memory=True, **loader_kwargs)


def _fixed_split_loaders(dataset_cls, args, transform_train, transform_train_seg,
                         transform_test, transform_test_seg):
    """Build (train, test) loaders from the dataset class's own 'Training' and 'Test' modes.

    The train loader shuffles; the test loader keeps dataset order.
    """
    train_set = dataset_cls(args, args.data_path, transform=transform_train,
                            transform_msk=transform_train_seg, mode='Training')
    test_set = dataset_cls(args, args.data_path, transform=transform_test,
                           transform_msk=transform_test_seg, mode='Test')
    return (_make_loader(train_set, args.b, shuffle=True),
            _make_loader(test_set, args.b, shuffle=False))


def _random_split_loaders(dataset, batch_size, test_fraction):
    """Hold out ``floor(test_fraction * len(dataset))`` random samples for testing.

    Both loaders draw from the same *dataset* object via ``SubsetRandomSampler``.
    The permutation comes from NumPy's global RNG (``np.random.shuffle``), so
    the split is only reproducible if ``np.random`` was seeded beforehand.
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(test_fraction * dataset_size))
    np.random.shuffle(indices)
    train_loader = _make_loader(dataset, batch_size,
                                sampler=SubsetRandomSampler(indices[split:]))
    test_loader = _make_loader(dataset, batch_size,
                               sampler=SubsetRandomSampler(indices[:split]))
    return train_loader, test_loader


def get_dataloader(args):
    """Build the (train, test) DataLoaders for the dataset named by ``args.dataset``.

    Attributes read from ``args``:

    * ``dataset``: the dataset name.
    * ``data_path``: passed to the dataset constructors.
    * ``b``: batch size.
    * ``image_size``: side length images are resized to.
    * ``out_size``: side length masks are resized to.

    ``args`` itself is also forwarded to the dataset constructors and to
    ``get_decath_loader``.

    Supported names and how each is split:

    * 'isic', 'REFUGE', 'DDTI': the dataset's own 'Training' / 'Test' modes.
    * 'decathlon': delegated entirely to ``get_decath_loader``.
    * 'LIDC', 'STARE': one dataset built with the training transforms; 20% of
      samples are held out at random for testing.
    * 'Brat', 'kits', 'WBC', 'segrap', 'toothfairy', 'atlas', 'pendal', 'lnq':
      same as above, but 30% of samples are held out.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    transform_train = _resize_to_tensor(args.image_size)
    transform_train_seg = _resize_to_tensor(args.out_size)
    transform_test = _resize_to_tensor(args.image_size)
    transform_test_seg = _resize_to_tensor(args.out_size)

    name = args.dataset

    if name == 'decathlon':
        # get_decath_loader (presumably provided by ``from utils import *`` -- confirm)
        # returns six values; only the two loaders are needed here.
        train_loader, test_loader, _, _, _, _ = get_decath_loader(args)
        return train_loader, test_loader

    if name in _FIXED_SPLIT_DATASETS:
        return _fixed_split_loaders(_FIXED_SPLIT_DATASETS[name], args,
                                    transform_train, transform_train_seg,
                                    transform_test, transform_test_seg)

    if name == 'LIDC':
        # NOTE(review): only ``LIDC`` is imported in this module. ``MyLIDC`` must come from
        # ``from utils import *``, otherwise this branch raises NameError -- confirm.
        # It is looked up here rather than in _RANDOM_SPLIT_DATASETS so that a
        # missing name cannot break importing this module.
        dataset_cls, test_fraction = MyLIDC, 0.2
    elif name in _RANDOM_SPLIT_DATASETS:
        dataset_cls, test_fraction = _RANDOM_SPLIT_DATASETS[name]
    else:
        # Fail loudly: falling through would return never-assigned loader variables.
        raise ValueError(f"the dataset {name!r} is not supported now")

    dataset = dataset_cls(args, data_path=args.data_path,
                          transform=transform_train, transform_msk=transform_train_seg)
    return _random_split_loaders(dataset, args.b, test_fraction)