""" reference: https://github.com/xuebinqin/DIS """ import PIL.Image import numpy as np import torch import torch.nn.functional as F from PIL import Image from torch import nn from torch.autograd import Variable from torchvision import transforms from torchvision.transforms.functional import normalize from .models import ISNetDIS # Helpers device = 'cuda' if torch.cuda.is_available() else 'cpu' class GOSNormalize(object): """ Normalize the Image using torch.transforms """ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): self.mean = mean self.std = std def __call__(self, image): image = normalize(image, self.mean, self.std) return image def im_preprocess(im, size): if len(im.shape) < 3: im = im[:, :, np.newaxis] if im.shape[2] == 1: im = np.repeat(im, 3, axis=2) im_tensor = torch.tensor(im.copy(), dtype=torch.float32) im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1) if len(size) < 2: return im_tensor, im.shape[0:2] else: im_tensor = torch.unsqueeze(im_tensor, 0) im_tensor = F.upsample(im_tensor, size, mode="bilinear") im_tensor = torch.squeeze(im_tensor, 0) return im_tensor.type(torch.uint8), im.shape[0:2] class IsNetPipeLine: def __init__(self, model_path=None, model_digit="full"): self.model_digit = model_digit self.model = ISNetDIS() self.cache_size = [1024, 1024] self.transform = transforms.Compose([ GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) ]) # Build Model self.build_model(model_path) def load_image(self, image: PIL.Image.Image): im = np.array(image.convert("RGB")) im, im_shp = im_preprocess(im, self.cache_size) im = torch.divide(im, 255.0) shape = torch.from_numpy(np.array(im_shp)) return self.transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape def build_model(self, model_path=None): if model_path is not None: self.model.load_state_dict(torch.load(model_path, map_location=device)) # convert to half precision if self.model_digit == "half": self.model.half() for layer in self.model.modules(): if isinstance(layer, nn.BatchNorm2d): layer.float() self.model.to(device) self.model.eval() def __call__(self, image: PIL.Image.Image): image_tensor, orig_size = self.load_image(image) mask = self.predict(image_tensor, orig_size) pil_mask = Image.fromarray(mask).convert('L') im_rgb = image.convert("RGB") im_rgba = im_rgb.copy() im_rgba.putalpha(pil_mask) return [im_rgba, pil_mask] def predict(self, inputs_val: torch.Tensor, shapes_val): """ Given an Image, predict the mask """ if self.model_digit == "full": inputs_val = inputs_val.type(torch.FloatTensor) else: inputs_val = inputs_val.type(torch.HalfTensor) inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable ds_val = self.model(inputs_val_v)[0] # list of 6 results # B x 1 x H x W # we want the first one which is the most accurate prediction pred_val = ds_val[0][0, :, :, :] # recover the prediction spatial size to the orignal image size pred_val = torch.squeeze( F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear')) ma = torch.max(pred_val) mi = torch.min(pred_val) pred_val = (pred_val - mi) / (ma - mi) # max = 1 if device == 'cuda': torch.cuda.empty_cache() return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need # a = IsNetPipeLine(model_path="save_models/isnet.pth") # input_image = Image.open("image_0mx.png") # rgb, mask = a(input_image) # # rgb.save("rgb.png") # mask.save("mask.png")