Spaces: Running on Zero
import os | |
from PIL import Image | |
import cv2 | |
import torch | |
from torch.utils import data | |
from torchvision import transforms | |
from torchvision.transforms import functional as F | |
import numbers | |
import numpy as np | |
import random | |
#re_size = (256, 256) | |
#cr_size = (224, 224) | |
class ImageDataTrain(data.Dataset):
    """Saliency training set yielding image / saliency-label / edge-label triples.

    The list file at ``sal_source`` holds one sample per line with at least
    three whitespace-separated paths, relative to ``sal_root``: the image,
    its saliency label and its edge label. Blank lines are ignored.
    """

    def __init__(self, sal_root='/home/liuj/dataset/DUTS/DUTS-TR',
                 sal_source='/home/liuj/dataset/DUTS/DUTS-TR/train_pair_edge.lst'):
        """Read the sample list.

        Args:
            sal_root: directory the paths in the list file are relative to.
            sal_source: path of the list file.
        """
        self.sal_root = sal_root
        self.sal_source = sal_source
        with open(self.sal_source, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line): they would
            # otherwise count towards len() and fail in __getitem__.
            self.sal_list = [line for line in (x.strip() for x in f) if line]
        self.sal_num = len(self.sal_list)

    def __getitem__(self, item):
        """Return sample ``item`` (wrapped modulo the list length) as tensors.

        Returns:
            dict with keys 'sal_image', 'sal_label' and 'sal_edge', each a
            torch.Tensor built from the loaded arrays after one shared random
            horizontal flip.

        Raises:
            ValueError: if the list line has fewer than three fields.
        """
        line = self.sal_list[item % self.sal_num]
        fields = line.split()
        if len(fields) < 3:
            raise ValueError(
                'Expected image, label and edge paths in list line: {!r}'.format(line))
        image_name, label_name, edge_name = fields[:3]
        sal_image = load_image(os.path.join(self.sal_root, image_name))
        sal_label = load_sal_label(os.path.join(self.sal_root, label_name))
        sal_edge = load_edge_label(os.path.join(self.sal_root, edge_name))
        # One flip decision is applied to the image and both label maps together.
        sal_image, sal_label, sal_edge = cv_random_flip(sal_image, sal_label, sal_edge)
        return {
            'sal_image': torch.Tensor(sal_image),
            'sal_label': torch.Tensor(sal_label),
            'sal_edge': torch.Tensor(sal_edge),
        }

    def __len__(self):
        """Return the number of non-blank lines in the sample list."""
        return self.sal_num
class ImageDataTest(data.Dataset):
    """Test-time image dataset selected by ``test_mode`` / ``sal_mode``.

    ``test_mode`` 0 selects the HED-BSDS list and 2 the SK-LARGE list.
    ``test_mode`` 1 selects one of the saliency benchmarks, chosen by
    ``sal_mode`` (see ``_SAL_DATASETS``). The list file holds one image path
    per line, relative to ``image_root``; blank lines are ignored.
    """

    # sal_mode -> (image_root, image_source, test_fold) used when test_mode == 1.
    # 'm' (MSRA) has no configured output folder, so its test_fold is None.
    _SAL_DATASETS = {
        'e': ('/home/liuj/dataset/saliency_test/ECSSD/Imgs/',
              '/home/liuj/dataset/saliency_test/ECSSD/test.lst',
              '/media/ubuntu/disk/Result/saliency/ECSSD/'),
        'p': ('/home/liuj/dataset/saliency_test/PASCALS/Imgs/',
              '/home/liuj/dataset/saliency_test/PASCALS/test.lst',
              '/media/ubuntu/disk/Result/saliency/PASCALS/'),
        'd': ('/home/liuj/dataset/saliency_test/DUTOMRON/Imgs/',
              '/home/liuj/dataset/saliency_test/DUTOMRON/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTOMRON/'),
        'h': ('/home/liuj/dataset/saliency_test/HKU-IS/Imgs/',
              '/home/liuj/dataset/saliency_test/HKU-IS/test.lst',
              '/media/ubuntu/disk/Result/saliency/HKU-IS/'),
        's': ('/home/liuj/dataset/saliency_test/SOD/Imgs/',
              '/home/liuj/dataset/saliency_test/SOD/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOD/'),
        'm': ('/home/liuj/dataset/saliency_test/MSRA/Imgs/',
              '/home/liuj/dataset/saliency_test/MSRA/test.lst',
              None),
        'o': ('/home/liuj/dataset/saliency_test/SOC/TestSet/Imgs/',
              '/home/liuj/dataset/saliency_test/SOC/TestSet/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOC/'),
        't': ('/home/liuj/dataset/DUTS/DUTS-TE/DUTS-TE-Image/',
              '/home/liuj/dataset/DUTS/DUTS-TE/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTS/'),
    }

    # test_mode -> (image_root, image_source, test_fold) for the non-saliency lists.
    _OTHER_DATASETS = {
        0: ('/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test/',
            '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test.lst',
            None),
        2: ('/home/liuj/dataset/SK-LARGE/images/test/',
            '/home/liuj/dataset/SK-LARGE/test.lst',
            None),
    }

    def __init__(self, test_mode=1, sal_mode='e', image_root=None,
                 image_source=None, test_fold=None):
        """Resolve the dataset paths for the mode and read the image list.

        Args:
            test_mode: 0 (HED-BSDS), 1 (saliency benchmark) or 2 (SK-LARGE).
            sal_mode: benchmark key used when ``test_mode == 1``.
            image_root, image_source, test_fold: optional overrides for the
                configured paths of the selected mode.

        Raises:
            ValueError: for an unknown ``test_mode`` / ``sal_mode``.
        """
        default_root, default_source, default_fold = self._default_paths(test_mode, sal_mode)
        self.image_root = default_root if image_root is None else image_root
        self.image_source = default_source if image_source is None else image_source
        # Always defined, so save_folder() never raises AttributeError; None
        # means no output folder is configured for this mode.
        self.test_fold = default_fold if test_fold is None else test_fold
        with open(self.image_source, 'r') as f:
            self.image_list = [line for line in (x.strip() for x in f) if line]
        self.image_num = len(self.image_list)

    @classmethod
    def _default_paths(cls, test_mode, sal_mode):
        """Return the configured (image_root, image_source, test_fold) for a mode."""
        if test_mode == 1:
            if sal_mode not in cls._SAL_DATASETS:
                raise ValueError('Unknown sal_mode: {!r}'.format(sal_mode))
            return cls._SAL_DATASETS[sal_mode]
        if test_mode not in cls._OTHER_DATASETS:
            raise ValueError('Unknown test_mode: {!r}'.format(test_mode))
        return cls._OTHER_DATASETS[test_mode]

    def __getitem__(self, item):
        """Load image ``item`` for inference.

        Returns:
            dict with 'image' (torch.Tensor of the array from load_image_test),
            'name' (the list entry, also used to build the path) and 'size'
            (the size value returned by load_image_test).
        """
        # The same entry is used for both the path and the reported name.
        name = self.image_list[item]
        image, im_size = load_image_test(os.path.join(self.image_root, name))
        return {'image': torch.Tensor(image), 'name': name, 'size': im_size}

    def save_folder(self):
        """Return the result output folder for this mode, or None if none is configured."""
        return self.test_fold

    def __len__(self):
        """Return the number of non-blank entries in the image list."""
        return self.image_num
def get_loader(batch_size, mode='train', num_thread=1, test_mode=0, sal_mode='e'):
    """Build a DataLoader and return it together with its dataset.

    No data augmentation is applied, except the random flip done by the
    saliency training set.

    Args:
        batch_size: samples per batch.
        mode: 'train' builds a shuffled ImageDataTrain loader; any other value
            builds an unshuffled ImageDataTest loader.
        num_thread: number of DataLoader worker processes.
        test_mode, sal_mode: forwarded to ImageDataTest when not training.

    Returns:
        (data_loader, dataset)
    """
    training = mode == 'train'
    if training:
        dataset = ImageDataTrain()
    else:
        dataset = ImageDataTest(test_mode=test_mode, sal_mode=sal_mode)
    loader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_thread,
    )
    return loader, dataset
def load_image(pah):
    """Load an image with OpenCV as a mean-subtracted float32 channel-first array.

    Args:
        pah: path of the image file.

    Returns:
        np.ndarray of dtype float32, indexed (C, H, W). The channels are in
        cv2.imread order (BGR). The per-channel mean
        (104.00699, 116.66877, 122.67892) is subtracted.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(pah):
        # Fail immediately with the path. Previously this only printed, and
        # cv2.imread then returned None, which surfaced as an obscure
        # broadcasting error further down.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        raise OSError('Cannot read image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    # The subtracted means are in BGR order, matching cv2.imread. A PIL loader
    # would first need in_ = in_[:, :, ::-1].
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    in_ = in_.transpose((2, 0, 1))  # H, W, C -> C, H, W
    return in_
def load_image_test(pah):
    """Load a test image as a mean-subtracted float32 array plus its original size.

    Args:
        pah: path of the image file.

    Returns:
        (in_, im_size):
        - ``in_`` is a float32 array indexed (C, H, W). The channels are in
          cv2.imread order (BGR), and the per-channel mean
          (104.00699, 116.66877, 122.67892) is subtracted.
        - ``im_size`` is the (H, W) of the loaded image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(pah):
        # Fail here with the path instead of printing and letting the None
        # returned by cv2.imread fail later with an unrelated error.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        raise OSError('Cannot read image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    im_size = tuple(in_.shape[:2])  # (H, W), taken before the transpose
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    in_ = in_.transpose((2, 0, 1))  # H, W, C -> C, H, W
    return in_, im_size
def load_edge_label(pah):
    """Load an edge label map as a float32 array of shape (1, H, W).

    Pixel values are divided by 255. Every value above 0.5 is then set to 1;
    values at or below 0.5 keep their scaled value. For multi-channel images
    only the first channel is used. The leading singleton dimension is
    required by the loss.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        # Raise with the offending path rather than printing and continuing.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = Image.open(pah)
    label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    label[label > 0.5] = 1.
    return label[np.newaxis, ...]
def load_skel_label(pah):
    """Load a skeleton label map as a binary float32 array of shape (1, H, W).

    Pixel values are divided by 255 and every positive value is set to 1, so
    the result contains only 0 and 1. For multi-channel images only the first
    channel is used. The leading singleton dimension is required by the loss.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        # Raise with the offending path rather than printing and continuing.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = Image.open(pah)
    label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    label[label > 0.] = 1.
    return label[np.newaxis, ...]
def load_sal_label(pah):
    """Load a saliency label map as a float32 array of shape (1, H, W).

    Pixel values are divided by 255. For 8-bit images this gives values in
    [0, 1]; the 8-bit input is an assumption, so confirm it against the data.
    No thresholding is applied. For multi-channel images only the first
    channel is used. The leading singleton dimension is required by the loss.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        # Raise with the offending path rather than printing and continuing.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = Image.open(pah)
    label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    return label[np.newaxis, ...]
def load_sem_label(pah):
    """Load a label image unscaled as a float32 array of shape (1, H, W).

    Pixel values are kept as-is; unlike the other label loaders, there is no
    division by 255. They are presumably class indices, which should be
    verified against the data. For multi-channel images only the first
    channel is used. The leading singleton dimension is required by the loss.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        # Raise with the offending path rather than printing and continuing.
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = Image.open(pah)
    label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    return label[np.newaxis, ...]
def edge_thres_transform(x, thres):
    """Saturate ``x``: elements ``>= thres`` become 1, all others are unchanged.

    Args:
        x: input tensor.
        thres: threshold compared element-wise with ``x``.

    Returns:
        A new tensor on the same device as ``x``.
    """
    # Allocate the ones on x's device. A bare torch.ones(x.size()) is always
    # on the CPU and makes torch.where fail for CUDA inputs.
    y1 = torch.ones(x.size(), device=x.device)
    return torch.where(x >= thres, y1, x)
def skel_thres_transform(x, thres):
    """Binarize ``x``: 1 where ``x > thres``, otherwise 0.

    Args:
        x: input tensor.
        thres: threshold compared element-wise with ``x``.

    Returns:
        A new tensor of torch's default floating dtype, on the same device
        as ``x``.
    """
    # Allocate on x's device so torch.where also works for CUDA inputs.
    y0 = torch.zeros(x.size(), device=x.device)
    y1 = torch.ones(x.size(), device=x.device)
    return torch.where(x > thres, y1, y0)
def cv_random_flip(img, label, edge):
    """With probability 0.5, reverse the last axis of all three arrays together.

    One random.randint(0, 1) draw decides the flip for all inputs. When no
    flip happens, the original objects are returned untouched. When a flip
    happens, each result is a contiguous copy.
    """
    if random.randint(0, 1) != 1:
        return img, label, edge
    flipped = [arr[:, :, ::-1].copy() for arr in (img, label, edge)]
    return flipped[0], flipped[1], flipped[2]
def cv_random_crop_flip(img, label, resize_size, crop_size, random_flip=True):
    """Resize, randomly crop and optionally horizontally flip an image/label pair.

    Args:
        img: array indexed (C, H, W); it is transposed to (H, W, C) for
            cv2.resize.
        label: array indexed (1, H, W); only index 0 of the first axis is used.
        resize_size: (height, width) that both arrays are resized to first.
        crop_size: (height, width) of the random crop window.
        random_flip: if True, reverse the width axis of both outputs with
            probability 0.5; if False, never flip.

    Returns:
        (img, label) of shapes (C, crop_h, crop_w) and (1, crop_h, crop_w).

    Raises:
        ValueError: if ``crop_size`` exceeds ``resize_size`` in either dimension.
    """
    def get_params(img_size, output_size):
        # Pick the top-left corner of a uniformly random crop window.
        h, w = img_size
        th, tw = output_size
        if th > h or tw > w:
            raise ValueError(
                'crop_size {} exceeds resize_size {}'.format(output_size, img_size))
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    # Draw the flip decision first, in the same RNG order as before. Default
    # to "no flip": flip_flag used to be undefined when random_flip=False,
    # which raised NameError below.
    flip_flag = random.randint(0, 1) if random_flip else 0
    img = img.transpose((1, 2, 0))  # C, H, W -> H, W, C
    label = label[0, :, :]  # H, W
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
    label = cv2.resize(label, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
    i, j, h, w = get_params(resize_size, crop_size)
    img = img[i:i + h, j:j + w, :].transpose((2, 0, 1))  # C, H, W
    label = label[i:i + h, j:j + w][np.newaxis, ...]  # 1, H, W
    if flip_flag == 1:
        img = img[:, :, ::-1].copy()
        label = label[:, :, ::-1].copy()
    return img, label
def random_crop(img, label, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0, padding_mode='constant'):
    """Crop the same random window out of an image and its label.

    Works on objects exposing ``.size`` as (width, height), as used with the
    torchvision functional ``pad``/``crop`` helpers.

    Args:
        img, label: inputs to crop with an identical window.
        size: (height, width) of the crop, or a single number for a square.
        padding: optional padding applied to both inputs before cropping.
        pad_if_needed: pad either dimension that is smaller than the crop.
        fill_img, fill_label: constant fill values for padding.
        padding_mode: padding mode passed through to ``pad``.

    Returns:
        [cropped_img, cropped_label]
    """
    def choose_window(image, out_size):
        # Top-left corner plus extent of a uniformly random window.
        width, height = image.size
        target_h, target_w = out_size
        if (width, height) == (target_w, target_h):
            return 0, 0, height, width
        top = random.randint(0, height - target_h)
        left = random.randint(0, width - target_w)
        return top, left, target_h, target_w

    if isinstance(size, numbers.Number):
        size = (int(size), int(size))

    if padding is not None:
        img = F.pad(img, padding, fill_img, padding_mode)
        label = F.pad(label, padding, fill_label, padding_mode)

    if pad_if_needed:
        # Width: pad left and right by the same amount.
        if img.size[0] < size[1]:
            img_pad_w = int((1 + size[1] - img.size[0]) / 2)
            label_pad_w = int((1 + size[1] - label.size[0]) / 2)
            img = F.pad(img, (img_pad_w, 0), fill_img, padding_mode)
            label = F.pad(label, (label_pad_w, 0), fill_label, padding_mode)
        # Height: pad top and bottom by the same amount.
        if img.size[1] < size[0]:
            img_pad_h = int((1 + size[0] - img.size[1]) / 2)
            label_pad_h = int((1 + size[0] - label.size[1]) / 2)
            img = F.pad(img, (0, img_pad_h), fill_img, padding_mode)
            label = F.pad(label, (0, label_pad_h), fill_label, padding_mode)

    top, left, crop_h, crop_w = choose_window(img, size)
    cropped_img = F.crop(img, top, left, crop_h, crop_w)
    cropped_label = F.crop(label, top, left, crop_h, crop_w)
    return [cropped_img, cropped_label]