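"""Gradio demo: MAE-ViT image reconstruction on STL-10.

Three tabs compare a masked autoencoder trained without PCA mode against two
variants trained with PCA mode (bottom 25% and top 75%).
Model card: https://huggingface.co/turhancan97/MAE-Models
Code: https://github.com/turhancan97/Learning-by-Reconstruction-with-MAE
"""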
import gradio as gr
import torch
from PIL import Image
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor
# Example images shared by all three demo tabs
example_paths = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
device = torch.device("cpu")

def load_model(checkpoint_path):
    model = torch.load(checkpoint_path, map_location='cpu')
    model.eval()
    model.to(device)
    return model

model_no_mode = load_model("model/no_mode/vit-t-mae-pretrain.pt")
model_pca_mode_bottom = load_model("model/bottom_25/vit-t-mae-pretrain.pt")
model_pca_mode_top = load_model("model/top_75/vit-t-mae-pretrain.pt")
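# Note: these checkpoints are full pickled nn.Module objects, so torch.load
# can only unpickle them if the MAE classes imported from model.py above are
# importable under the same names.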
transform = v2.Compose([
    v2.Resize((96, 96)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
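# Normalizing with mean=std=0.5 maps pixel values to [-1, 1]; show_image
# below undoes this with (img + 1) / 2 before plotting.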
# Load and preprocess the image
def load_image(image_path, transform):
    img = Image.open(image_path).convert('RGB')
    img = transform(img).unsqueeze(0)  # Add batch dimension
    return img
def show_image(img, title):
    img = rearrange(img, "c h w -> h w c")
    img = (img.cpu().detach().numpy() + 1) / 2  # Undo normalization: map [-1, 1] back to [0, 1]
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
# Visualize a single image: original, masked input, and reconstruction
def visualize_single_image(image_path, model, paste_visible_patches=False):
    img = load_image(image_path, transform).to(device)

    # Run inference
    with torch.no_grad():
        predicted_img, mask = model(img)

    # Masked image (visible patches only)
    im_masked = img * (1 - mask)
    # MAE reconstruction with the visible patches pasted back in
    im_paste = img * (1 - mask) + predicted_img * mask

    # Remove the batch dimension
    img = img[0]
    im_masked = im_masked[0]
    predicted_img = predicted_img[0]
    im_paste = im_paste[0]

    reconstruction = im_paste if paste_visible_patches else predicted_img

    # Make the plt figure larger
    plt.figure(figsize=(18, 8))
    plt.subplot(1, 3, 1)
    show_image(img, "original")
    plt.subplot(1, 3, 2)
    show_image(im_masked, "masked")
    plt.subplot(1, 3, 3)
    show_image(reconstruction, "reconstruction")
    plt.tight_layout()

    # Render the figure to disk and return it as a numpy array
    plt.savefig("output.png")
    plt.close()  # free the figure so repeated calls don't accumulate open figures
    return plt.imread("output.png")

def visualize_single_image_no_mode(image_path):
    return visualize_single_image(image_path, model_no_mode, paste_visible_patches=True)

def visualize_single_image_pca_mode_bottom(image_path):
    return visualize_single_image(image_path, model_pca_mode_bottom)

def visualize_single_image_pca_mode_top(image_path):
    return visualize_single_image(image_path, model_pca_mode_top)
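# A possible in-memory alternative to the savefig/imread round trip above
# (not used by the demo; a minimal sketch assuming an Agg-based canvas, the
# Matplotlib default in headless environments):
def figure_to_array(fig):
    fig.canvas.draw()
    # buffer_rgba() exposes the rendered canvas as an (H, W, 4) RGBA buffer
    return np.asarray(fig.canvas.buffer_rgba())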
# Each gr.Interface gets its own input/output component instances
inputs_image_1 = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image_1 = [
    gr.components.Image(type="numpy", label="Output Image"),
]
inputs_image_2 = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image_2 = [
    gr.components.Image(type="numpy", label="Output Image"),
]
inputs_image_3 = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image_3 = [
    gr.components.Image(type="numpy", label="Output Image"),
]
inference_no_mode = gr.Interface(
    fn=visualize_single_image_no_mode,
    inputs=inputs_image_1,
    outputs=outputs_image_1,
    examples=example_paths,
    cache_examples=False,
    title="MAE-ViT Image Reconstruction",
    description="Demo of the MAE-ViT model for image reconstruction. This model was trained on the STL-10 dataset without PCA mode. See the [model card](https://huggingface.co/turhancan97/MAE-Models) and the [GitHub repository](https://github.com/turhancan97/Learning-by-Reconstruction-with-MAE) for more information.",
)
inference_pca_mode_bottom = gr.Interface(
    fn=visualize_single_image_pca_mode_bottom,
    inputs=inputs_image_2,
    outputs=outputs_image_2,
    examples=example_paths,
    cache_examples=False,
    title="MAE-ViT Image Reconstruction",
    description="Demo of the MAE-ViT model for image reconstruction. This model was trained on the STL-10 dataset with PCA mode (bottom 25%).",
)
inference_pca_mode_top = gr.Interface(
    fn=visualize_single_image_pca_mode_top,
    inputs=inputs_image_3,
    outputs=outputs_image_3,
    examples=example_paths,
    cache_examples=False,
    title="MAE-ViT Image Reconstruction",
    description="Demo of the MAE-ViT model for image reconstruction. This model was trained on the STL-10 dataset with PCA mode (top 75%).",
)
gr.TabbedInterface(
    [inference_no_mode, inference_pca_mode_bottom, inference_pca_mode_top],
    tab_names=['Normal Mode', 'PCA Mode (Bottom 25%)', 'PCA Mode (Top 75%)']
).queue().launch()
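# Quick local check without the UI (assumes the example images exist next to
# this script):
#   out = visualize_single_image_no_mode("images/cat.jpg")
#   print(out.shape)  # RGBA array read back from output.png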