# Gradio demo for MAE (Masked Autoencoder) image reconstruction.
import os
import sys

import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

# The MAE repo pins timm==0.4.5; install it and pull the model code at runtime
# (the usual pattern for a Hugging Face Spaces demo script).
os.system("pip install timm==0.4.5")
os.system("git clone https://github.com/facebookresearch/mae.git")
sys.path.append('./mae')

import models_mae

# ImageNet per-channel mean and std, used to normalize inputs
# and to un-normalize them for display.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

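# Draw one image on the current matplotlib axes: undo the ImageNet
# normalization (image * std + mean) and rescale to [0, 255].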
def show_image(image, title=''):
    # image is (H, W, 3), normalized with the ImageNet statistics
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

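# Build an MAE model from the repo's factory function (default: ViT-Large with
# 16x16 patches) and load the pre-trained weights; strict=False keeps loading
# from failing if checkpoint keys and model keys differ.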
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    model = getattr(models_mae, arch)()
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model

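# Run MAE on one normalized (224, 224, 3) image: the model randomly masks 75%
# of the patches and reconstructs them. Plots four panels (original, masked
# input, reconstruction, reconstruction with the visible patches pasted back
# in) and returns the pyplot module so Gradio can render the figure.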

def run_one_image(img, model):
    x = torch.tensor(img)

    # make it a batch of one, channels-first: (1, 3, H, W)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE (no gradients needed at inference time)
    with torch.no_grad():
        loss, y, mask = model(x.float(), mask_ratio=0.75)
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # expand the per-patch mask to pixel space; 1 is removed, 0 is kept
    mask = mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 * 3)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with the visible patches
    im_paste = x * (1 - mask) + y * mask

    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    # return the figure instead of calling plt.show() so Gradio's
    # "plot" output type can render it
    return plt

# Download the MAE ViT-Large visualization checkpoint released by the authors
# (wget -nc skips the download if the file is already present).
os.system("wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth")

chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')

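# Gradio entry point: force RGB (uploads may be grayscale or RGBA), resize to
# the 224x224 input the ViT expects, normalize with the ImageNet statistics,
# and fix the seed so the random mask is reproducible across calls.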
def inference(img):
    img = img.convert('RGB').resize((224, 224))
    img = np.array(img) / 255.

    assert img.shape == (224, 224, 3)

    # normalize by ImageNet mean and std
    img = img - imagenet_mean
    img = img / imagenet_std

    torch.manual_seed(2)
    return run_one_image(img, model_mae)

title = "MAE"
description = "Gradio demo for MAE (Masked Autoencoders Are Scalable Vision Learners). Upload an image and the model reconstructs it after randomly masking 75% of the patches. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a> | <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p>"

gr.Interface(
    inference,
    gr.inputs.Image(type="pil"),
    gr.outputs.Image(type="plot"),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    allow_screenshot=False,
    enable_queue=True,
).launch()
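
# Usage note: save this script (e.g. as app.py -- the filename is an
# assumption, not fixed by the code) and run `python app.py`, then open the
# local URL Gradio prints. The checkpoint is loaded with map_location='cpu'
# and the model is never moved to a GPU, so the demo runs on CPU as written.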