Ahsen Khaliq commited on
Commit
3c6d2ed
1 Parent(s): da3fa5d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import requests
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ import matplotlib.pyplot as plt
9
+ from PIL import Image
10
+ import gradio as gr
11
+
12
+
13
+ os.system("pip install timm==0.4.5")
14
+ os.system("git clone https://github.com/facebookresearch/mae.git")
15
+ sys.path.append('./mae')
16
+
17
+ import models_mae
18
+
19
+ # define the utils
20
+
21
+ imagenet_mean = np.array([0.485, 0.456, 0.406])
22
+ imagenet_std = np.array([0.229, 0.224, 0.225])
23
+
24
+ def show_image(image, title=''):
25
+ # image is [H, W, 3]
26
+ assert image.shape[2] == 3
27
+ plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
28
+ plt.title(title, fontsize=16)
29
+ plt.axis('off')
30
+ return
31
+
32
+ def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
33
+ # build model
34
+ model = getattr(models_mae, arch)()
35
+ # load model
36
+ checkpoint = torch.load(chkpt_dir, map_location='cpu')
37
+ msg = model.load_state_dict(checkpoint['model'], strict=False)
38
+ print(msg)
39
+ return model
40
+
41
+ def run_one_image(img, model):
42
+ x = torch.tensor(img)
43
+
44
+ # make it a batch-like
45
+ x = x.unsqueeze(dim=0)
46
+ x = torch.einsum('nhwc->nchw', x)
47
+
48
+ # run MAE
49
+ loss, y, mask = model(x.float(), mask_ratio=0.75)
50
+ y = model.unpatchify(y)
51
+ y = torch.einsum('nchw->nhwc', y).detach().cpu()
52
+
53
+ # visualize the mask
54
+ mask = mask.detach()
55
+ mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3) # (N, H*W, p*p*3)
56
+ mask = model.unpatchify(mask) # 1 is removing, 0 is keeping
57
+ mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
58
+
59
+ x = torch.einsum('nchw->nhwc', x)
60
+
61
+ # masked image
62
+ im_masked = x * (1 - mask)
63
+
64
+ # MAE reconstruction pasted with visible patches
65
+ im_paste = x * (1 - mask) + y * mask
66
+
67
+ # make the plt figure larger
68
+ plt.rcParams['figure.figsize'] = [24, 24]
69
+
70
+ plt.subplot(1, 4, 1)
71
+ show_image(x[0], "original")
72
+
73
+ plt.subplot(1, 4, 2)
74
+ show_image(im_masked[0], "masked")
75
+
76
+ plt.subplot(1, 4, 3)
77
+ show_image(y[0], "reconstruction")
78
+
79
+ plt.subplot(1, 4, 4)
80
+ show_image(im_paste[0], "reconstruction + visible")
81
+
82
+ plt.show()
83
+
84
+
85
+ # download checkpoint if not exist
86
+ os.system("wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth")
87
+
88
+ chkpt_dir = 'mae_visualize_vit_large.pth'
89
+ model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
90
+ print('Model loaded.')
91
+
92
+
93
+ def inference(img):
94
+ img = img.resize((224, 224))
95
+ img = np.array(img) / 255.
96
+
97
+ assert img.shape == (224, 224, 3)
98
+
99
+ # normalize by ImageNet mean and std
100
+ img = img - imagenet_mean
101
+ img = img / imagenet_std
102
+
103
+
104
+ torch.manual_seed(2)
105
+ return run_one_image(img, model_mae)
106
+
107
+
108
+ title = "MAE"
109
+ description = "Gradio Demo for MAE. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
110
+
111
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
112
+
113
+
114
+ gr.Interface(inference, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="plot"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False,enable_queue=True).launch()