import copy
import glob
import os
from multiprocessing.dummy import Pool as ThreadPool

import torch
from torch import nn
from PIL import Image
from torchvision.transforms.functional import to_tensor

from ..Models import *


class ImageSplitter:
    # key points:
    # Border padding and overlapping image splitting to avoid instability at patch edges.
    # Thanks to Waifu2x's author nagadomi for the suggestion (https://github.com/nagadomi/waifu2x/issues/238)

    def __init__(self, seg_size=48, scale_factor=2, boarder_pad_size=3):
        self.seg_size = seg_size
        self.scale_factor = scale_factor
        self.pad_size = boarder_pad_size
        self.height = 0
        self.width = 0
        self.upsampler = nn.Upsample(scale_factor=scale_factor, mode="bilinear")

    def split_img_tensor(self, pil_img, scale_method=Image.BILINEAR, img_pad=0):
        # convert the image into a tensor and replicate-pad its borders
        img_tensor = to_tensor(pil_img).unsqueeze(0)
        img_tensor = nn.ReplicationPad2d(self.pad_size)(img_tensor)
        batch, channel, height, width = img_tensor.size()
        self.height = height
        self.width = width

        if scale_method is not None:
            # pre-upscale the whole image so each patch has a matching up-scaled counterpart
            img_up = pil_img.resize(
                (self.scale_factor * pil_img.size[0], self.scale_factor * pil_img.size[1]),
                scale_method,
            )
            img_up = to_tensor(img_up).unsqueeze(0)
            img_up = nn.ReplicationPad2d(self.pad_size * self.scale_factor)(img_up)

        patch_box = []
        # avoid a residual strip that is smaller than the padded size
        if (
            height % self.seg_size < self.pad_size
            or width % self.seg_size < self.pad_size
        ):
            self.seg_size += self.scale_factor * self.pad_size

        # split the image into overlapping pieces
        for i in range(self.pad_size, height, self.seg_size):
            for j in range(self.pad_size, width, self.seg_size):
                part = img_tensor[
                    :,
                    :,
                    (i - self.pad_size) : min(i + self.pad_size + self.seg_size, height),
                    (j - self.pad_size) : min(j + self.pad_size + self.seg_size, width),
                ]
                if img_pad > 0:
                    part = nn.ZeroPad2d(img_pad)(part)
                if scale_method is not None:
                    # part_up = self.upsampler(part)
                    part_up = img_up[
                        :,
                        :,
                        self.scale_factor * (i - self.pad_size) : min(
                            i + self.pad_size + self.seg_size, height
                        )
                        * self.scale_factor,
                        self.scale_factor * (j - self.pad_size) : min(
                            j + self.pad_size + self.seg_size, width
                        )
                        * self.scale_factor,
                    ]

                    patch_box.append((part, part_up))
                else:
                    patch_box.append(part)
        return patch_box

    def merge_img_tensor(self, list_img_tensor):
        out = torch.zeros(
            (1, 3, self.height * self.scale_factor, self.width * self.scale_factor)
        )
        img_tensors = copy.copy(list_img_tensor)
        rem = self.pad_size * 2

        pad_size = self.scale_factor * self.pad_size
        seg_size = self.scale_factor * self.seg_size
        height = self.scale_factor * self.height
        width = self.scale_factor * self.width
        for i in range(pad_size, height, seg_size):
            for j in range(pad_size, width, seg_size):
                part = img_tensors.pop(0)
                # crop the overlapping border from each up-scaled patch
                part = part[:, :, rem:-rem, rem:-rem]
                # guard against tensors that are not 4-D and cannot be placed into the output
                if len(part.size()) > 3:
                    _, _, p_h, p_w = part.size()
                    out[:, :, i : i + p_h, j : j + p_w] = part
                # out[:, :,
                #     self.scale_factor * i : self.scale_factor * i + p_h,
                #     self.scale_factor * j : self.scale_factor * j + p_w] = part
        out = out[:, :, rem:-rem, rem:-rem]
        return out


def load_single_image(
    img_file,
    up_scale=False,
    up_scale_factor=2,
    up_scale_method=Image.BILINEAR,
    zero_padding=False,
):
    img = Image.open(img_file).convert("RGB")
    out = to_tensor(img).unsqueeze(0)
    if zero_padding:
        out = nn.ZeroPad2d(zero_padding)(out)
    if up_scale:
        size = tuple(map(lambda x: x * up_scale_factor, img.size))
        img_up = img.resize(size, up_scale_method)
        img_up = to_tensor(img_up).unsqueeze(0)
        out = (out, img_up)
    return out


def standardize_img_format(img_folder):
    def process(img_file):
        img_path = os.path.dirname(img_file)
        # os.path.splitext is safer than split(".") when a file name contains extra dots
        img_name, _ = os.path.splitext(os.path.basename(img_file))
        out = os.path.join(img_path, img_name + ".JPEG")
        os.rename(img_file, out)

    list_imgs = []
    for i in ["png", "jpeg", "jpg"]:
        list_imgs.extend(
            glob.glob(os.path.join(img_folder, "**", "*." + i), recursive=True)
        )
    print("Found {} images.".format(len(list_imgs)))
    pool = ThreadPool(4)
    pool.map(process, list_imgs)
    pool.close()
    pool.join()
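

# Example usage (a minimal sketch, not part of the original module): `model` stands in
# for a 2x super-resolution network; here an nn.Upsample layer is used as a placeholder,
# and "input.png" is a hypothetical file path.
#
#     img = Image.open("input.png").convert("RGB")
#     splitter = ImageSplitter(seg_size=48, scale_factor=2, boarder_pad_size=3)
#     patches = splitter.split_img_tensor(img, scale_method=None, img_pad=0)
#     model = nn.Upsample(scale_factor=2, mode="bilinear")  # replace with a real model
#     with torch.no_grad():
#         out_patches = [model(p) for p in patches]
#     result = splitter.merge_img_tensor(out_patches)  # 1 x 3 x 2H x 2W tensor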