|
import os |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from skimage import io |
|
from skimage.transform import resize |
|
from torch.utils.data import Dataset |
|
|
|
from saicinpainting.evaluation.evaluator import InpaintingEvaluator |
|
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore |
|
|
|
|
|
class SimpleImageDataset(Dataset):
    """Dataset of images read from a flat directory and resized to a fixed size.

    Each item is a float32 tensor of shape (3, height, width). Grayscale images are
    replicated to three channels and any alpha channel is dropped, so every item has
    the same channel count regardless of the source file format. Pixel values are
    whatever skimage's ``resize`` produces (it rescales integer images to [0, 1]
    because ``preserve_range`` defaults to False).
    """

    def __init__(self, root_dir, image_size=(400, 600)):
        """
        Args:
            root_dir: directory whose regular, non-hidden files are the images.
            image_size: target (height, width) passed to skimage's ``resize``.
        """
        self.root_dir = root_dir
        # Skip sub-directories and hidden files (e.g. .DS_Store): io.imread cannot
        # decode them and would fail mid-evaluation.
        self.files = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(root_dir, name))
        )
        self.image_size = image_size

    @staticmethod
    def _to_rgb(image):
        """Return *image* as an (H, W, 3) array: gray is replicated, alpha is dropped."""
        if image.ndim == 2:
            image = image[..., None]
        if image.shape[-1] < 3:
            # Grayscale (1 channel) or grayscale + alpha (2 channels): replicate gray.
            image = np.repeat(image[..., :1], 3, axis=-1)
        return image[..., :3]

    def __getitem__(self, index):
        """Load, normalise to 3 channels, resize, and return image *index* as (3, H, W)."""
        img_name = os.path.join(self.root_dir, self.files[index])
        image = self._to_rgb(io.imread(img_name))
        image = resize(image, self.image_size, anti_aliasing=True)
        # HWC numpy -> CHW float32 tensor.
        tensor = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
        return tensor.permute(2, 0, 1)

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.files)
|
|
|
|
|
def create_rectangle_mask(height, width, border_divisor=4):
    """Build a (height, width) float64 mask of ones with a centred rectangle of zeros.

    The zero rectangle is inset from the top/bottom edges by ``height // border_divisor``
    rows and from the left/right edges by ``width // border_divisor`` columns. With the
    default divisor of 4 it covers roughly the central half of each dimension. Note
    that ``Model`` in this file treats 1 as kept pixels and 0 as the region to fill.

    Args:
        height: number of rows of the mask.
        width: number of columns of the mask.
        border_divisor: inset of the zero rectangle, as a divisor of each side length.
            Must be >= 1. Small divisors (1, or 2 with an even side) yield an empty
            rectangle along that side.

    Returns:
        np.ndarray of dtype float64: 1.0 outside the rectangle, 0.0 inside it.

    Raises:
        ValueError: if ``border_divisor`` is less than 1.
    """
    if border_divisor < 1:
        raise ValueError(f'border_divisor must be >= 1, got {border_divisor}')
    mask = np.ones((height, width), dtype=np.float64)
    top = height // border_divisor
    left = width // border_divisor
    # Half-open slices [top, height - top) x [left, width - left) cover exactly the
    # inclusive corner range (left, top)..(width - left - 1, height - top - 1) that a
    # filled cv2.rectangle would draw.
    mask[top:height - top, left:width - left] = 0.0
    return mask
|
|
|
|
|
class Model:
    """Trivial inpainting baseline: fill the masked-out region with a per-channel mean.

    Mask convention, as implemented below: 1 marks pixels to keep, 0 marks pixels to
    replace. The fill value is the mask-weighted mean of the kept pixels.
    """

    def __call__(self, img_batch, mask_batch):
        """Inpaint ``img_batch`` using ``mask_batch``.

        Args:
            img_batch: tensor of shape (B, C, H, W).
            mask_batch: tensor of shape (B, H, W) or (B, 1, H, W); it weights the
                pixels kept from ``img_batch``.

        Returns:
            Tensor of shape (B, C, H, W) equal to ``mean * (1 - mask) + img * mask``,
            where ``mean`` is the per-sample, per-channel mask-weighted mean of the
            image. A sample whose mask sums to zero gets a mean of 0 instead of NaN.
        """
        # Accept both (B, H, W) and (B, 1, H, W) masks; broadcast over channels.
        mask = mask_batch if mask_batch.dim() == 4 else mask_batch[:, None, :, :]
        weighted_sum = (img_batch * mask).sum(dim=(2, 3))  # (B, C)
        mask_area = mask.sum(dim=(2, 3))  # (B, 1)
        # Guard against 0/0 -> NaN for fully masked samples.
        mean = torch.where(mask_area > 0, weighted_sum / mask_area, torch.zeros_like(weighted_sum))
        return mean[:, :, None, None] * (1 - mask) + img_batch * mask
|
|
|
|
|
class SimpleImageSquareMaskDataset(Dataset):
    """Pair every image of a wrapped dataset with one fixed rectangular mask.

    Items are dicts with keys ``image``, ``mask`` and ``inpainted``. ``inpainted`` is the
    output of this file's mean-fill ``Model`` for that image and mask.
    """

    def __init__(self, dataset):
        """
        Args:
            dataset: image dataset exposing ``image_size`` as (height, width), e.g.
                ``SimpleImageDataset``. Its items are assumed to be (C, H, W) tensors
                of that size (TODO confirm for other datasets).
        """
        self.dataset = dataset
        self.mask = torch.as_tensor(
            create_rectangle_mask(*self.dataset.image_size), dtype=torch.float32
        )
        self.model = Model()

    def __getitem__(self, index):
        """Return ``dict(image=..., mask=..., inpainted=...)`` for sample *index*."""
        img = self.dataset[index]
        # Clone so consumers that mutate a sample's mask cannot corrupt the shared one.
        mask = self.mask.clone()
        # The model works on batches: add a batch axis, then strip it again so
        # ``inpainted`` has the same shape as ``image`` and collates the same way.
        inpainted = self.model(img[None, ...], mask[None, ...])[0]
        return dict(image=img, mask=mask, inpainted=inpainted)

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.dataset)
|
|
|
|
|
# Evaluation script: score the mean-fill baseline on every image in the 'imgs'
# directory (resolved relative to the current working directory).
# NOTE(review): Model in this file treats mask == 1 as kept pixels. Confirm that
# InpaintingEvaluator (and its area_grouping) expects the same mask convention.
dataset = SimpleImageDataset('imgs')
mask_dataset = SimpleImageSquareMaskDataset(dataset)
model = Model()

metrics = {'ssim': SSIMScore(), 'lpips': LPIPSScore(), 'fid': FIDScore()}

evaluator = InpaintingEvaluator(mask_dataset, scores=metrics, batch_size=3, area_grouping=True)
results = evaluator.evaluate(model)
print(results)
|
|