# SuperFeatures / app.py
# YannisK's picture
# temp
# f1b5b7f
import gradio as gr
import cv2
import torch
import torch.utils.data as data
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors
from mpl_toolkits.axes_grid1 import ImageGrid
import fire_network
import numpy as np
from PIL import Image
# Image scales selectable for inference; callers index into this list.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# ---------------------------------------------------------------------------
# Networks
# ---------------------------------------------------------------------------
# First network: built from, and fully initialised by, the 'fire.pth' checkpoint.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])

# Parameters of net_sfm whose name contains 'dim_reduction'; they are merged
# into the second checkpoint below (overriding same-named entries).
dim_red_params_dict = {
    param_name: param_value
    for param_name, param_value in net_sfm.named_parameters()
    if 'dim_reduction' in param_name
}

# Second network: same architecture parameters as the first, weights from
# 'fire_imagenet.pth' plus the dim-reduction parameters collected above.
# strict=False tolerates keys that do not line up between checkpoint and model.
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
state2['state_dict'] = {**state2['state_dict'], **dim_red_params_dict}
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
net_imagenet.load_state_dict(state2['state_dict'], strict=False)

# ---------------------------------------------------------------------------
# Pre-processing
# ---------------------------------------------------------------------------
# Normalisation statistics are read from the first network's runtime config.
_normalize_kwargs = dict(zip(["mean", "std"], net_sfm.runtime['mean_std']))
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_kwargs),
])
def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Return the indices at which two feature sets match each other.

    A column ``p`` of the pairwise distance matrix (one column per row of
    ``pos_feat``) is accepted when all of the following hold:

    * its nearest query ``q`` also has ``p`` as nearest neighbour
      (reciprocal nearest neighbours);
    * the nearest distance divided by the second-nearest distance in that
      column is at most ``LoweRatioTh`` (Lowe ratio test);
    * ``q == p``, i.e. the match links features with the same index.

    Args:
        query_feat: 2-D tensor with one feature per row.
        pos_feat: 2-D tensor with one feature per row and the same feature
            dimension as ``query_feat``.
        LoweRatioTh: maximum allowed best/second-best distance ratio.

    Returns:
        1-D integer tensor holding the accepted indices (possibly empty).
    """
    num_query, num_pos = query_feat.size(0), pos_feat.size(0)
    if num_query == 0 or num_pos == 0:
        # argmin over an empty dimension raises; there is nothing to match.
        return torch.empty(0, dtype=torch.long, device=query_feat.device)

    # first perform reciprocal nn
    dist = torch.cdist(query_feat, pos_feat)
    best1 = torch.argmin(dist, dim=1)  # nearest pos for every query
    best2 = torch.argmin(dist, dim=0)  # nearest query for every pos
    # Build the column index on the same device as the tensors it indexes.
    arange = torch.arange(num_pos, device=best2.device)
    reciprocal = best1[best2] == arange

    # check Lowe ratio test
    best_dist = dist[best2, arange]
    dist2 = dist.clone()
    dist2[best2, arange] = float('inf')  # mask the best so min() finds the runner-up
    # The denominator must be the second-best *distance*, not its row index.
    second_dist = torch.min(dist2, dim=0).values
    ratio1to2 = best_dist / second_dist
    valid = torch.logical_and(reciprocal, ratio1to2 <= LoweRatioTh)

    pindices = torch.where(valid)[0]
    qindices = best2[pindices]
    # keep only the ones with same indices
    return pindices[pindices == qindices]
def clear_figures():
    """Close every open pyplot figure.

    The figures built for a request are created through pyplot's global
    registry under fixed numbers, so anything left open would be picked up
    again by the next request. Closing all of them guarantees each request
    starts from an empty registry and releases the memory of earlier plots.
    """
    plt.close('all')
def _parse_sf_ids(sf_ids):
    """Parse a comma-separated string of IDs into ints, skipping bad tokens."""
    parsed = []
    for token in sf_ids.split(','):
        try:
            parsed.append(int(token))
        except ValueError:
            # Blank or non-numeric entry typed by the user: ignore it.
            continue
    return parsed


def _binarize_attention(attn, threshold):
    """Rescale an attention map to 0..255 and threshold it to a 0/255 uint8 mask."""
    heat = np.array(attn, dtype=np.float32)
    peak = np.max(heat)
    if peak == 0:
        # Nothing to normalise by; an all-zero map yields an empty mask.
        return np.zeros(heat.shape, dtype=np.uint8)
    heat = np.uint8(heat / peak * 255.0)
    return np.where(heat > threshold, 255, 0).astype(np.uint8)


def _paint_masks(image_bgr, masks, size, palette):
    """Return a copy of a BGR image with every mask painted in its palette colour."""
    painted = np.copy(image_bgr)
    for j, mask in enumerate(masks):
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        # Wrap around the palette so more masks than colours cannot fail.
        rgb = 255 * np.array(colors.to_rgba(palette[j % len(palette)]))[:3]
        # The image is BGR, hence the reversed channel order.
        painted[mask == 255] = rgb[::-1].astype(np.uint8)
    return painted


def generate_matching_superfeatures(
        im1, im2,
        Imagenet_model=False,
        scale_id=6, threshold=50,
        random_mode=False, sf_ids=''):
    """Plot the matching super-features of two images.

    Super-features are extracted from both images at a single scale, matched
    with ``match``, and the attention maps of the selected matching IDs are
    thresholded and painted over the images, one colour per ID.

    Args:
        im1: first PIL image.
        im2: second PIL image.
        Imagenet_model: use ``net_imagenet`` instead of ``net_sfm``.
        scale_id: index into the module-level ``scales`` list.
        threshold: binarisation threshold applied to attention maps rescaled
            to the 0..255 range.
        random_mode: when true, show up to 10 randomly drawn matching IDs and
            ignore ``sf_ids``.
        sf_ids: comma-separated IDs to show; only those that are matches are
            kept. When empty (and ``random_mode`` is false) 10 random IDs are
            drawn and filtered the same way.

    Returns:
        Tuple ``(fig1, fig2, fig_leg)`` of matplotlib figures: the two
        annotated images and a legend mapping colours to IDs.
    """
    clear_figures()
    palette = plt.get_cmap('tab10').colors
    net = net_imagenet if Imagenet_model else net_sfm

    # Normalise the colour mode so both the 3-channel transform and the
    # RGB -> BGR flip below are valid for grayscale / RGBA uploads too.
    im1 = im1.convert('RGB')
    im2 = im2.convert('RGB')
    im1_tensor = transform(im1).unsqueeze(0)
    im2_tensor = transform(im2).unsqueeze(0)
    im1_cv = np.array(im1)[:, :, ::-1].copy()
    im2_cv = np.array(im2)[:, :, ::-1].copy()

    # extract features
    scale = scales[int(scale_id)]  # UI sliders may deliver the index as a float
    with torch.no_grad():
        output1 = net.get_superfeatures(im1_tensor.to(device), scales=[scale])
        output2 = net.get_superfeatures(im2_tensor.to(device), scales=[scale])
    feats1, attns1 = output1[0][0], output1[1][0]
    feats2, attns2 = output2[0][0], output2[1][0]

    feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
    feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
    matched_ids = match(feats1n, feats2n).tolist()

    # which sf
    n_sf_ids = 10
    sf_ids = sf_ids or ''  # tolerate a missing (None) text input
    if random_mode:
        if matched_ids:
            picks = np.random.randint(len(matched_ids), size=n_sf_ids)
            sf_idx_ = [matched_ids[k] for k in picks]
        else:
            # No matches at all: show the plain images instead of failing.
            sf_idx_ = []
    else:
        if sf_ids.strip() == '':
            # Axis 1 of the attention tensor is indexed by super-feature ID.
            candidates = np.random.randint(attns1.shape[1], size=n_sf_ids).tolist()
        else:
            candidates = _parse_sf_ids(sf_ids)
        matched_lookup = set(matched_ids)
        sf_idx_ = [sf for sf in candidates if sf in matched_lookup]
    # Drop repeated IDs while keeping their order.
    sf_idx_ = list(dict.fromkeys(sf_idx_))
    n_sf_ids = len(sf_idx_)

    # Binary attention masks of every selected super-feature, per image.
    all_att_bin1 = [
        _binarize_attention(attns1[0, sf, :, :].cpu().numpy(), threshold)
        for sf in sf_idx_
    ]
    all_att_bin2 = [
        _binarize_attention(attns2[0, sf, :, :].cpu().numpy(), threshold)
        for sf in sf_idx_
    ]

    img1rsz = _paint_masks(im1_cv, all_att_bin1, im1.size, palette)
    img2rsz = _paint_masks(im2_cv, all_att_bin2, im2.size, palette)

    # clear=True keeps a reused numbered figure from stacking old images.
    fig1 = plt.figure(1, clear=True)
    plt.imshow(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
    plt.gca().axis('off')
    plt.tight_layout()

    fig2 = plt.figure(2, clear=True)
    plt.imshow(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
    plt.gca().axis('off')
    plt.tight_layout()

    # Legend figure: empty square-marker plots act as colour swatches. They
    # are created after the legend figure so they do not land in figure 2.
    fig_leg = plt.figure(3, clear=True)
    handles = [
        plt.plot([], [], marker="s", color=palette[k % len(palette)], ls="none")[0]
        for k in range(n_sf_ids)
    ]
    plt.legend(handles, sf_idx_, framealpha=1, frameon=False, facecolor='w', fontsize=25, loc="center")
    plt.gca().axis('off')
    plt.tight_layout()

    return fig1, fig2, fig_leg
# GRADIO APP
title = "Visualizing Super-features"
# HTML shown under the title (the unmatched closing </p> tag was dropped).
description = "This is a visualization demo for the ICLR 2022 paper <b><a href='https://github.com/naver/fire' target='_blank'>Learning Super-Features for Image Retrieval</a></b>"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    # Inputs are passed positionally, so their order must follow the
    # parameters of generate_matching_superfeatures:
    # im1, im2, Imagenet_model, scale_id, threshold, random_mode, sf_ids.
    inputs=[
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
        gr.inputs.Checkbox(default=False, label="ImageNet Model (Default: SfM-120k)"),
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=4, label="Scale"),
        gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
        gr.inputs.Checkbox(default=True, label="Show random (matching) SFs"),
        gr.inputs.Textbox(lines=1, default="", label="...or show specific SF IDs:", optional=True),
    ],
    # One output per figure returned by the function: fig1, fig2, fig_leg.
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Image(type="plot", label="SF legend"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    # Each example row lists one value per input, in the same order as `inputs`.
    examples=[
        ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
        ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
        ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
        ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
        ["mill1.jpeg", "mill2.jpeg", False, 4, 100, False, '177,88,170,190,151,155'],
    ],
)
iface.launch(enable_queue=True)