Spaces:
Runtime error
Runtime error
File size: 6,972 Bytes
35e5468 9eae6e7 bd11a0f 9eae6e7 9e43b47 5e41805 9e43b47 9eae6e7 4530503 9eae6e7 4530503 9eae6e7 4530503 9eae6e7 659310d 9e43b47 9eae6e7 9e43b47 9eae6e7 659310d 9eae6e7 659310d 9eae6e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import subprocess

# Install the Space's runtime dependencies before the imports below need them.
# The command is a fixed argument list, so no shell is needed to run it.
_setup = subprocess.run(['sh', 'setup.sh'])
if _setup.returncode == 0:
    print("Installed the dependencies!")
else:
    # Report the failure instead of claiming success; startup continues.
    print("warn: setup.sh exited with status %d" % _setup.returncode)
from typing import Tuple
import dnnlib
from PIL import Image
import numpy as np
import torch
import legacy
import paddlehub as hub
import cv2
# PaddleHub U2Net module; inpaint() calls u2net.Segmentation on it to build the
# mask when the "Automatic" option is selected.
u2net = hub.Module(name='U2Net')

# gradio app imports
import gradio as gr
from torchvision.transforms import ToTensor, ToPILImage

# PIL <-> tensor converters (they appear unused elsewhere in this file).
image_to_tensor, tensor_to_image = ToTensor(), ToPILImage()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Generator settings read by fcf_inpaint(): class_idx selects the one-hot label
# for class-conditional generators; truncation_psi is forwarded to G unchanged.
class_idx = None
truncation_psi = 0.1
# Generators already loaded by create_model(), keyed by the network_pkl argument.
_model_cache = {}


def create_model(network_pkl):
    """Load the ``G_ema`` generator from ``network_pkl`` and move it to ``device``.

    The result is cached per ``network_pkl``. inpaint() calls this on every
    request, and without the cache each request would re-read and
    re-deserialize the network file.

    Args:
        network_pkl: path or URL accepted by ``dnnlib.util.open_url``.

    Returns:
        The generator module, in eval mode, on ``device``.
    """
    cached = _model_cache.get(network_pkl)
    if cached is not None:
        return cached
    print('Loading networks from "%s"...' % network_pkl)
    # NOTE(review): legacy.load_network_pkl presumably unpickles the file, so only
    # point this at trusted network files.
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema']  # type: ignore
    G = G.eval().to(device)
    netG_params = sum(p.numel() for p in G.parameters())
    print("Generator Params: {} M".format(netG_params / 1e6))
    _model_cache[network_pkl] = G
    return G
def fcf_inpaint(G, org_img, erased_img, mask):
    """Run the generator on the erased image and composite its output into the hole.

    Args:
        G: generator exposing ``c_dim`` and called as
            ``G(img=..., c=..., truncation_psi=..., noise_mode='const')``.
        org_img: original image tensor.
        erased_img: ``org_img`` with the masked region zeroed out.
        mask: mask tensor, 1 inside the hole and 0 elsewhere; it is
            concatenated (as ``0.5 - mask``) with ``erased_img`` along dim 1.

    Returns:
        ``mask * prediction + (1 - mask) * org_img``: generated pixels inside the
        hole, original pixels outside it.

    Raises:
        ValueError: if ``G`` is class-conditional (``c_dim != 0``) and the
            module-level ``class_idx`` is None.
    """
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError("class_idx can't be None.")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print('warn: --class=lbl ignored when running on an unconditional network')
    # Put every input on the same device as the label before the forward pass.
    mask = mask.to(device)
    org_img = org_img.to(device)
    erased_img = erased_img.to(device)
    # Inference only: don't build an autograd graph for the generator call.
    with torch.no_grad():
        pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label,
                     truncation_psi=truncation_psi, noise_mode='const')
    comp_img = mask * pred_img + (1 - mask) * org_img
    return comp_img
def show_images(img):
    """Wrap an image array (e.g. the uint8 output of ``denorm``) in a PIL Image."""
    as_pil = Image.fromarray(img)
    return as_pil
def denorm(img):
    """Convert the first image of a batch from ``[-1, 1]`` floats to a uint8 array.

    Args:
        img: tensor where ``img[0]`` is ``(C, H, W)``. Values map linearly from
            ``[-1, 1]`` to ``[0, 255]``; anything outside is clipped.

    Returns:
        ``np.uint8`` array of shape ``(H, W, C)``.
    """
    # detach() so tensors that still track gradients can be converted to NumPy.
    img = np.asarray(img[0].detach().cpu(), dtype=np.float32).transpose(1, 2, 0)
    img = (img + 1) * 127.5
    img = np.rint(img).clip(0, 255).astype(np.uint8)
    return img
def pil_to_numpy(pil_img: "Image.Image") -> torch.Tensor:
    """Convert an image to a ``(1, C, H, W)`` float tensor scaled to ``[-1, 1]``.

    Despite the name, this returns a torch tensor, not a NumPy array.

    Args:
        pil_img: anything ``np.array`` accepts as an image: a PIL image, or an
            ``(H, W)`` / ``(H, W, C)`` array with values in ``[0, 255]``.

    Returns:
        float32 tensor of shape ``(1, C, H, W)``. Two-dimensional (grayscale)
        input yields ``C == 1``.
    """
    img = np.array(pil_img)
    if img.ndim == 2:
        # Grayscale images have no channel axis; add one so the permute below works.
        img = img[:, :, None]
    return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
def _u2net_mask(img):
    """Predict a mask for ``img`` with the module-level U2Net module ("Automatic" mode)."""
    result = u2net.Segmentation(
        # convert('RGB') so RGBA/greyscale uploads still give a 3-channel array.
        images=[cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_RGB2BGR)],
        paths=None,
        batch_size=1,
        input_size=320,
        output_dir='output',
        visualization=True)
    return Image.fromarray(result[0]['mask'])


def _manual_mask(mask_img, img):
    """Return a 0/1 float64 array marking where ``mask_img`` differs from ``img``, dilated 5x5."""
    # Both arrays are uint8 greyscale. Every pixel where the drawn-on copy differs
    # from the image is treated as part of the hole.
    drawn = np.array(mask_img) != np.array(img.convert('L'))
    kernel = np.ones((5, 5), np.uint8)
    # Grow the hole slightly so stroke edges are covered too.
    return cv2.dilate(drawn.astype(np.float64), kernel)


def inpaint(input_img, mask, option):
    """Erase the masked region of ``input_img`` and fill it with the FcF generator.

    Args:
        input_img: PIL image to inpaint.
        mask: PIL image. For "Manual", this is a copy of ``input_img`` with the
            hole drawn on it. It is ignored for "Automatic", where U2Net
            predicts the mask.
        option: "Automatic" or "Manual".

    Returns:
        ``(image with hole, inpainted image)`` as PIL images. Both are 512x512,
        because inputs are resized to that size first.

    Raises:
        ValueError: if ``input_img`` is missing, or ``mask`` is missing when the
            option is not "Automatic".
    """
    if input_img is None:
        raise ValueError("An input image is required.")
    width, height = input_img.size
    if option == "Automatic":
        mask = _u2net_mask(input_img)
    elif mask is None:
        raise ValueError("A mask image is required unless the 'Automatic' option is selected.")
    # Align the mask with the input; PIL returns a plain copy when the size already matches.
    mask = mask.resize((width, height)).convert('L')
    if (width, height) != (512, 512):
        input_img = input_img.resize((512, 512))
        mask = mask.resize((512, 512))

    if option == 'Manual':
        mask_arr = _manual_mask(mask, input_img)
    else:
        mask_arr = np.array(mask) / 255.

    # (1, 1, H, W) mask and (1, 3, H, W) image in [-1, 1], both on `device`.
    mask_tensor = torch.from_numpy(mask_arr).to(torch.float32)[None, None].to(device)
    rgb = np.array(input_img.convert('RGB')).transpose(2, 0, 1)
    rgb = torch.from_numpy(rgb.astype(np.float32)).unsqueeze(0)
    rgb = (rgb / 127.5 - 1).to(device)
    rgb_erased = rgb * (1 - mask_tensor)  # erase rgb

    model = create_model("models/places_512.pkl")
    with torch.no_grad():
        comp_img = fcf_inpaint(G=model, org_img=rgb, erased_img=rgb_erased, mask=mask_tensor)
    return show_images(denorm(rgb_erased)), show_images(denorm(comp_img))
# UI inputs, in the order inpaint() receives them: the source image, the mask
# image (a copy of the source with the hole drawn on, used in "Manual" mode),
# and the masking mode.
# NOTE(review): gr.inputs / gr.outputs are legacy Gradio namespaces; confirm
# they exist in the Gradio version this Space pins.
gradio_inputs = [
    gr.inputs.Image(type='pil', tool=None, label="Image"),
    gr.inputs.Image(type='pil', tool="editor", label="Mask"),
    gr.inputs.Radio(
        choices=["Automatic", "Manual"],
        type="value",
        default="Manual",
        label="Masking Choice",
    ),
]

# UI outputs, matching the (hole, inpainted) pair inpaint() returns.
gradio_outputs = [
    gr.outputs.Image(label='Image with Hole'),
    gr.outputs.Image(label='Inpainted Image'),
]
# Example rows: [image, mask image, masking mode]. The first row lets U2Net
# predict the mask; the rest pair each original with its overlay image.
examples = [['test_512/person512.png', 'test_512/person512.png', 'Automatic']]
examples += [
    ['test_512/%s_org.png' % letter, 'test_512/%s_overlay.png' % letter, 'Manual']
    for letter in 'abcdefghi'
]
# Text and styling shown around the Gradio demo.
title = "FcF-Inpainting"

# HTML intro shown above the demo. font-weight takes a bare number (300), not "w300".
description = (
    "<p style='color:royalblue; font-weight: 300;'> "
    "[Note: Queueing may take up to 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: <br> "
    "(1) <span style='color:#E0B941;'>Upload </span> an Image to <span style='color:#E0B941;'>both</span> input boxes (Image and Mask) below. <br> "
    "(2a) <span style='color:#E0B941;'>Manual Option:</span> Draw a mask (hole) using the brush (click on the edit button in the top right of the Mask View and select draw option). <br> "
    "(2b) <span style='color:#E0B941;'>Automatic Option:</span> This option will generate a mask using a pretrained U2Net model. <br> "
    "(3) Click on Submit and witness the MAGIC! 🪄 ✨ ✨</p>"
)

# Footer linking to the paper's repository.
article = "<p style='color: #E0B941; text-align: center'><a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github Repo</a></p>"

# Extra CSS passed to gr.Interface.
css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
# Wire inpaint() into a Gradio Interface and start serving it.
# NOTE(review): `layout`, `theme="dark-huggingface"` and `enable_queue` are
# legacy Gradio options; confirm them against the Gradio version this Space pins.
iface = gr.Interface(
    fn=inpaint,
    inputs=gradio_inputs,
    outputs=gradio_outputs,
    examples=examples,
    examples_per_page=5,
    title=title,
    description=description,
    article=article,
    thumbnail="fcf_gan.png",
    css=css,
    layout="vertical",
    theme="dark-huggingface",
    allow_flagging="never",
)
iface.launch(
    enable_queue=True,
    share=True,
    server_name="0.0.0.0",
)