Spaces:
Runtime error
Runtime error
File size: 6,045 Bytes
35e5468 9eae6e7 bd11a0f 9eae6e7 66475ee 5e41805 9eae6e7 4530503 9eae6e7 4530503 9eae6e7 4530503 9eae6e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import subprocess
# Install the runtime dependencies before the heavy third-party imports below.
# The script is run from an argument list (no shell needed), and success is
# only reported when it actually exited cleanly. A failure is reported but not
# fatal, matching the original best-effort behaviour.
_setup = subprocess.run(['sh', 'setup.sh'])
if _setup.returncode == 0:
    print("Installed the dependencies!")
else:
    print("warn: setup.sh exited with status {}; dependencies may be missing".format(_setup.returncode))
from typing import Tuple
import dnnlib
from PIL import Image
import numpy as np
import torch
import legacy
import paddlehub as hub
import cv2
# PaddleHub U2Net segmentation module; inpaint() uses it to build the mask when
# the "Automatic" masking option is selected.
u2net = hub.Module(
    name="U2Net",
)
# gradio app imports
import gradio as gr
from torchvision.transforms import ToTensor, ToPILImage
# torchvision converters between PIL images and tensors (the inpainting path in
# this file does its own conversions; these are kept for compatibility).
image_to_tensor, tensor_to_image = ToTensor(), ToPILImage()
# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Class label for class-conditional generators; None means "no class".
class_idx = None
# Truncation value passed to the generator on every call.
truncation_psi = 0.1
# Generators already loaded by create_model(), keyed by the path/URL they came
# from, so repeated Gradio requests reuse one network instead of re-reading it.
_MODEL_CACHE = {}


def create_model(network_pkl):
    """Load the ``G_ema`` generator from a network pickle and move it to ``device``.

    Each distinct ``network_pkl`` is loaded once; later calls with the same
    argument return the cached generator.

    Args:
        network_pkl: Path or URL accepted by ``dnnlib.util.open_url``.

    Returns:
        The ``G_ema`` network in eval mode, on ``device``.
    """
    cached = _MODEL_CACHE.get(network_pkl)
    if cached is not None:
        return cached
    print('Loading networks from "%s"...' % network_pkl)
    # NOTE(review): the checkpoint is a .pkl, so load_network_pkl presumably
    # unpickles it (which can run arbitrary code) -- only load trusted files.
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema']  # type: ignore
    G = G.eval().to(device)
    netG_params = sum(p.numel() for p in G.parameters())
    print("Generator Params: {} M".format(netG_params / 1e6))
    _MODEL_CACHE[network_pkl] = G
    return G
def fcf_inpaint(G, org_img, erased_img, mask):
    """Fill the masked region of an image with the FcF generator.

    Args:
        G: Generator (see ``create_model``); called as
            ``G(img=..., c=..., truncation_psi=..., noise_mode='const')``.
        org_img: Original image tensor (the caller scales it to [-1, 1]).
        erased_img: ``org_img`` with the region to fill zeroed out.
        mask: Mask tensor, 1 where pixels are generated and 0 where the
            original is kept (assumed (B, 1, H, W) -- TODO confirm).

    Returns:
        Composite tensor: generator output where ``mask`` is 1, ``org_img``
        where it is 0.

    Raises:
        ValueError: if ``G`` is class-conditional and ``class_idx`` is None.
    """
    # Put every input on the model's device before they are combined.
    mask = mask.to(device)
    erased_img = erased_img.to(device)
    org_img = org_img.to(device)

    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            # Must raise: indexing with None below would set every label entry.
            raise ValueError("class_idx can't be None.")
        label[:, class_idx] = 1
    elif class_idx is not None:
        print('warn: --class=lbl ignored when running on an unconditional network')

    # Inference only: no autograd graph is needed, which saves memory and lets
    # the result be converted to numpy directly.
    with torch.no_grad():
        pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label,
                     truncation_psi=truncation_psi, noise_mode='const')
        comp_img = mask * pred_img + (1 - mask) * org_img
    return comp_img
def show_images(img):
    """Wrap a single image array in a PIL image (e.g. for a Gradio output).

    Despite the name, nothing is displayed and no batch is handled: ``img``
    (in this file, the uint8 height x width x channel array from ``denorm``)
    is passed straight to ``Image.fromarray``.
    """
    return Image.fromarray(img)
def denorm(img):
    """Convert the first image of a [-1, 1] batch tensor to a uint8 HWC array.

    Args:
        img: 4-D tensor indexed (batch, channel, height, width); only
            ``img[0]`` is used. It may be on any device and may require grad.

    Returns:
        ``np.uint8`` array of shape (height, width, channel) holding
        ``rint((x + 1) * 127.5)`` clipped to [0, 255].
    """
    # detach() first: np.asarray refuses tensors that are part of an autograd graph.
    img = np.asarray(img[0].detach().cpu(), dtype=np.float32).transpose(1, 2, 0)
    img = (img + 1) * 127.5
    return np.rint(img).clip(0, 255).astype(np.uint8)
def pil_to_numpy(pil_img: "Image.Image") -> torch.Tensor:
    """Convert an image to a (1, C, H, W) float tensor scaled to [-1, 1].

    Despite the name, the result is a single torch tensor, not a numpy array
    or a tuple. Anything ``np.array`` accepts works as input. A 2-D
    (grayscale) image gets a one-channel axis instead of failing in ``permute``.
    """
    img = np.array(pil_img)
    if img.ndim == 2:
        img = img.reshape(*img.shape, 1)
    return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
def _build_mask(input_img, mask, option):
    """Return the PIL mask to inpaint: U2Net output or the user-drawn canvas."""
    if option == "Automatic":
        # convert('RGB') so grayscale/palette uploads still give cvtColor the
        # colour channels its RGB->BGR conversion needs.
        bgr = cv2.cvtColor(np.array(input_img.convert('RGB')), cv2.COLOR_RGB2BGR)
        result = u2net.Segmentation(
            images=[bgr],
            paths=None,
            batch_size=1,
            input_size=320,
            output_dir='output',
            visualization=True)
        return Image.fromarray(result[0]['mask'])
    if mask is None:
        raise ValueError("No mask provided: draw a mask or choose the Automatic option.")
    return mask.resize(input_img.size)


def _mask_to_model_tensor(mask):
    """Convert a PIL mask to a (1, 1, H, W) float32 tensor in [0, 1] on ``device``."""
    values = np.array(mask.convert('L')) / 255.
    return torch.from_numpy(values).to(torch.float32)[None, None].to(device)


def _rgb_to_model_tensor(img):
    """Convert a PIL image to a (1, 3, H, W) float32 RGB tensor in [-1, 1] on ``device``."""
    chw = np.array(img.convert('RGB')).transpose(2, 0, 1)
    rgb = torch.from_numpy(chw.astype(np.float32)).unsqueeze(0)
    return (rgb / 127.5 - 1).to(device)


def inpaint(input_img, mask, option):
    """Gradio callback: erase the masked region of ``input_img`` and inpaint it.

    Args:
        input_img: PIL image uploaded by the user.
        mask: PIL image drawn on the canvas; ignored when ``option`` is
            "Automatic", where U2Net produces the mask instead.
        option: Masking choice, "Automatic" or "Manual".

    Returns:
        ``(erased, inpainted)`` PIL images at 512x512: the input with the masked
        pixels zeroed, and the generator's composite result.

    Raises:
        ValueError: if ``option`` is not "Automatic" and ``mask`` is None.
    """
    mask = _build_mask(input_img, mask, option)
    # The network runs at 512x512. Always resize both: PIL returns an unchanged
    # copy when the size already matches, and the U2Net mask may not match
    # the input size.
    input_img = input_img.resize((512, 512))
    mask = mask.resize((512, 512))

    mask_tensor = _mask_to_model_tensor(mask)
    rgb = _rgb_to_model_tensor(input_img)
    rgb_erased = rgb * (1 - mask_tensor)  # zero out the region to be filled

    model = create_model("models/places_512.pkl")
    comp_img = fcf_inpaint(G=model, org_img=rgb, erased_img=rgb_erased, mask=mask_tensor)
    return show_images(denorm(rgb_erased)), show_images(denorm(comp_img))
# Interface inputs: the photo to edit, a hand-drawn mask, and the mask source.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component API --
# confirm the pinned gradio version still provides them.
input_image_component = gr.inputs.Image(type='pil', tool=None, label="Input Image")
mask_component = gr.inputs.Image(type='pil', source="canvas", label="Mask", invert_colors=True)
masking_choice_component = gr.inputs.Radio(
    choices=["Automatic", "Manual"], type="value", default="Manual", label="Masking Choice")
gradio_inputs = [input_image_component, mask_component, masking_choice_component]

# Interface outputs, in the order inpaint() returns its two images.
gradio_outputs = [
    gr.outputs.Image(label='Image with Hole'),
    gr.outputs.Image(label='Inpainted Image'),
]
# Demo rows as [image, mask, masking option]. The first row is segmented
# automatically; the rest pair "<image>_org.png" with "<mask>_mask.png"
# (images c and b deliberately use each other's masks, as before).
_MANUAL_EXAMPLE_PAIRS = (
    ('a', 'a'), ('c', 'b'), ('b', 'c'), ('d', 'd'), ('e', 'e'),
    ('f', 'f'), ('g', 'g'), ('h', 'h'), ('i', 'i'),
)
examples = [['test_512/person512.png', 'test_512/mask_auto.png', 'Automatic']]
examples.extend(
    ['test_512/{}_org.png'.format(image), 'test_512/{}_mask.png'.format(mask_name), 'Manual']
    for image, mask_name in _MANUAL_EXAMPLE_PAIRS
)
# Text and styling shown around the Gradio interface.
title = "FcF-Inpainting"
description = (
    "[Note: Queue time may take upto 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: \n "
    "(1) Upload an Image; \n "
    "(2) Draw (Manual) a Mask on the White Canvas or Generate a mask using U2Net by selecting the Automatic option; \n "
    "(3) Click on Submit and witness the MAGIC! 🪄 ✨ ✨"
)
# arXiv:2208.03382 is the FcF-Inpainting paper named in the link text.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2208.03382' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github Repo</a></p>"
css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
# Build the demo around inpaint() and serve it on all network interfaces.
# NOTE(review): `layout` and `enable_queue` are legacy Gradio arguments --
# confirm the pinned gradio version still accepts them.
_interface_options = dict(
    outputs=gradio_outputs,
    css=css,
    layout="vertical",
    examples_per_page=5,
    thumbnail="fcf_gan.png",
    allow_flagging="never",
    examples=examples,
    title=title,
    description=description,
    article=article,
)
iface = gr.Interface(fn=inpaint, inputs=gradio_inputs, **_interface_options)
iface.launch(
    enable_queue=True,
    share=True,
    server_name="0.0.0.0",
)