zhengrongzhang's picture
init model
faac7d4
raw
history blame contribute delete
No virus
3.38 kB
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import torch
import torch.utils.data as data
__all__ = ['BaseDataset']
class BaseDataset(data.Dataset):
    """Base dataset providing joint image/mask transforms for segmentation.

    The class only implements the paired preprocessing used in the
    ``train``, ``val`` and ``testval`` style pipelines; loading samples
    (``__getitem__`` / ``__len__``) is left to subclasses. Subclasses must
    also define a ``NUM_CLASS`` class attribute, which :attr:`num_class`
    reads.

    The transform helpers call ``.size``, ``.resize``, ``.crop`` and
    ``.transpose`` on their inputs, i.e. they expect PIL images.
    """

    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=1024, crop_size=512):
        """Store the dataset configuration.

        Args:
            root: Dataset location; stored as-is and not accessed here.
            split: Name of the split; stored as-is.
            mode: Processing mode. Falls back to ``split`` when ``None``.
            transform: Stored on the instance; not applied by this class.
            target_transform: Stored on the instance; not applied by this
                class.
            base_size: Reference size from which the random long-side
                length is drawn in :meth:`_train_transform`.
            crop_size: Side length of the square crops and the target
                short-side length for the evaluation transforms.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'.format(
                base_size, crop_size))

    @property
    def num_class(self):
        """Number of classes, taken from the subclass's ``NUM_CLASS``."""
        return self.NUM_CLASS

    def _short_side_resize_size(self, img):
        """Return ``(ow, oh)`` that scales ``img``'s shorter side to ``crop_size``.

        The aspect ratio is kept; the longer side is truncated to an int.
        Shared by :meth:`_val_transform` and :meth:`_testval_transform` so
        both modes always resize identically.
        """
        short_size = self.crop_size
        w, h = img.size
        if w > h:
            # Landscape: height is the short side.
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            # Portrait or square: width is the short side.
            ow = short_size
            oh = int(1.0 * h * ow / w)
        return ow, oh

    def _val_transform(self, img, mask):
        """Resize the short side to ``crop_size`` and center-crop a square.

        Args:
            img: Input image (bilinear resampling is used).
            mask: Label mask of the same size (nearest-neighbour resampling
                is used so label values are not interpolated).

        Returns:
            Tuple ``(img, mask)`` where ``img`` is the cropped image and
            ``mask`` is the cropped mask converted by
            :meth:`_mask_transform`.
        """
        outsize = self.crop_size
        ow, oh = self._short_side_resize_size(img)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        box = (x1, y1, x1 + outsize, y1 + outsize)
        img = img.crop(box)
        mask = mask.crop(box)
        # final transform
        return img, self._mask_transform(mask)

    def _testval_transform(self, img, mask):
        """Resize only the image so its short side equals ``crop_size``.

        The mask is NOT resized or cropped here; it is only converted by
        :meth:`_mask_transform`, so image and mask may differ in size.
        NOTE(review): presumably predictions are compared against the
        full-resolution mask downstream -- confirm with the caller.

        Returns:
            Tuple ``(img, mask)`` with the resized image and the converted,
            original-resolution mask.
        """
        ow, oh = self._short_side_resize_size(img)
        img = img.resize((ow, oh), Image.BILINEAR)
        return img, self._mask_transform(mask)

    def _train_transform(self, img, mask):
        """Apply random flip, random scale, pad and random crop jointly.

        Steps (image and mask always receive the same geometry):
          1. Horizontal mirror with probability 0.5.
          2. Rescale so the long side is a random integer in
             ``[int(0.5 * base_size), int(2.0 * base_size)]`` (inclusive).
          3. If the short side is below ``crop_size``, pad on the right and
             bottom with zeros (the mask is padded with label 0 as well).
          4. Random ``crop_size`` x ``crop_size`` crop.

        Returns:
            Tuple ``(img, mask)`` where ``img`` is the cropped image and
            ``mask`` is the cropped mask converted by
            :meth:`_mask_transform`.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        w, h = img.size
        long_size = random.randint(int(self.base_size * 0.5),
                                   int(self.base_size * 2.0))
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            # border is (left, top, right, bottom): grow right/bottom only.
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        box = (x1, y1, x1 + crop_size, y1 + crop_size)
        img = img.crop(box)
        mask = mask.crop(box)
        # final transform
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        """Convert ``mask`` to a ``torch.int64`` tensor via a NumPy array."""
        return torch.from_numpy(np.array(mask)).long()