|
import sys |
|
import os |
|
import requests |
|
|
|
import torch |
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
# Fetch the reference MAE implementation so that `models_mae` can be imported
# further down. Clone only when the checkout is missing: `git clone` into an
# existing directory fails, and `os.system` silently discarded that (and any
# genuine) failure.
import subprocess

if not os.path.isdir('./mae'):
    # Argument list + check=True: no shell is involved, and a failed clone
    # stops the script here instead of surfacing later as an ImportError.
    subprocess.run(
        ['git', 'clone', 'https://github.com/facebookresearch/mae.git'],
        check=True,
    )

# Make the cloned repository's modules importable.
sys.path.append('./mae')
|
|
|
import models_mae |
|
|
|
|
|
|
|
# Three-value (one per colour channel) mean and standard deviation. They are
# subtracted/divided in `inference` and applied in reverse in `show_image`.
imagenet_mean = np.asarray([0.485, 0.456, 0.406])
imagenet_std = np.asarray([0.229, 0.224, 0.225])
|
|
|
def show_image(image, title=''):
    """Draw one normalised image on the current matplotlib axes.

    The image is de-normalised with ``imagenet_std`` / ``imagenet_mean``,
    scaled to the 0-255 range, clipped and converted to integers before
    being handed to ``plt.imshow``. Axis ticks are hidden.

    Args:
        image: Tensor whose third axis (index 2) has exactly 3 entries.
        title: Caption placed above the image.
    """
    # The channel axis must be last and hold three values.
    assert image.shape[2] == 3

    # Undo the mean/std normalisation, then bring the values to 0..255.
    pixels = (image * imagenet_std + imagenet_mean) * 255
    plt.imshow(torch.clip(pixels, 0, 255).int())

    plt.title(title, fontsize=16)
    plt.axis('off')
|
|
|
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build an MAE model and load checkpoint weights into it.

    Args:
        chkpt_dir: Path passed to ``torch.load`` (a checkpoint file).
        arch: Name of the constructor to look up on ``models_mae``.

    Returns:
        The constructed model after its ``'model'`` state dict has been
        loaded non-strictly from the checkpoint.
    """
    # Resolve the architecture constructor by name and instantiate it.
    build = getattr(models_mae, arch)
    net = build()

    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints obtained from a trusted source.
    state = torch.load(chkpt_dir, map_location='cpu')

    # strict=False tolerates key mismatches; the returned report is printed
    # so that any such mismatch is at least visible in the log.
    print(net.load_state_dict(state['model'], strict=False))
    return net
|
|
|
def run_one_image(img, model, mask_ratio=0.75, out_path="test.png"):
    """Run the model on one image and save a four-panel comparison figure.

    The panels are, left to right: the original image, the image with the
    masked patches zeroed, the model's reconstruction, and the
    reconstruction pasted together with the visible patches.

    Args:
        img: Array-like image in height-width-channel layout (it receives a
            batch axis and is permuted with ``'nhwc->nchw'``).
        model: Callable returning ``(loss, prediction, mask)`` and exposing
            ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: Forwarded to the model's ``mask_ratio`` argument. The
            default keeps the previously hard-coded value.
        out_path: File the figure is written to. The default keeps the
            previously hard-coded ``"test.png"``.
    """
    # Add the batch axis, then move channels in front for the model.
    x = torch.tensor(img).unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # Inference only: without no_grad the forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        _loss, y, mask = model(x.float(), mask_ratio=mask_ratio)
        y = model.unpatchify(y)

        # Repeat each per-patch mask value across that patch's
        # patch_size**2 * 3 entries so unpatchify can lay it out as an image.
        values_per_patch = model.patch_embed.patch_size[0] ** 2 * 3
        mask = mask.unsqueeze(-1).repeat(1, 1, values_per_patch)
        mask = model.unpatchify(mask)

    # Back to channels-last for plotting.
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    x = torch.einsum('nchw->nhwc', x)

    # mask == 1 marks removed pixels: zero them for the "masked" view and
    # fill them from the prediction for the pasted view.
    im_masked = x * (1 - mask)
    im_paste = x * (1 - mask) + y * mask

    panels = (
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    )

    # Use a dedicated figure per call and always close it. Drawing on the
    # implicit global figure kept it alive and stacked new images onto the
    # same axes on every call.
    fig = plt.figure(figsize=(24, 6))
    try:
        for position, (panel, caption) in enumerate(panels, start=1):
            plt.subplot(1, 4, position)
            show_image(panel, caption)
        plt.savefig(out_path, bbox_inches='tight')
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
# Download the pretrained checkpoint once; the download is skipped when the
# file already exists (the behaviour `wget -nc` provided). torch.hub's
# downloader replaces the shell-out, so no `wget` binary is required and a
# failed download raises here instead of being ignored.
chkpt_dir = 'mae_visualize_vit_large.pth'
if not os.path.exists(chkpt_dir):
    torch.hub.download_url_to_file(
        'https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth',
        chkpt_dir,
    )

# Build the model a single time at start-up; `inference` reuses it.
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')
|
|
|
|
|
def inference(img):
    """Handle one demo request: preprocess the image and render the figure.

    Args:
        img: PIL image supplied by the interface.

    Returns:
        ``"test.png"``, the file that ``run_one_image`` writes by default.

    Raises:
        ValueError: If no image was supplied.
    """
    if img is None:
        raise ValueError("No input image was provided.")

    # Force three channels first: grayscale, RGBA or palette uploads would
    # otherwise not produce a (224, 224, 3) array and fail the check below.
    img = img.convert("RGB").resize((224, 224))
    img = np.array(img) / 255.

    # Invariant after the conversion and resize above.
    assert img.shape == (224, 224, 3)

    # Normalise with the module-level per-channel statistics.
    img = (img - imagenet_mean) / imagenet_std

    # Fixed seed on every call; presumably this keeps the model's random
    # masking reproducible between requests (masking code is not in this file).
    torch.manual_seed(2)

    run_one_image(img, model_mae)
    return "test.png"
|
|
|
|
|
# Text shown on the demo page.
title = "MAE"
description = (
    "Gradio Demo for Masked Autoencoders Are Scalable Vision Learners. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)

# HTML footer: links to the paper and the repository, plus a visitor badge.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.06377' target='_blank'>Masked Autoencoders Are Scalable Vision Learners</a>"
    "| <a href='https://github.com/facebookresearch/mae' target='_blank'>Github Repo</a></p>"
    " <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_mae' alt='visitor badge'></center>"
)

# One example row; the file name is resolved by Gradio at launch.
examples = [['147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpeg']]

# Wire `inference` to a PIL-image input and a file-based image output, then
# start the app with request queueing enabled.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil")],
    outputs=gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    allow_screenshot=False,
    examples=examples,
).launch(enable_queue=True)