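"""Gradio demo for the ICLR 2022 paper "Learning Super-Features for Image
Retrieval" (FIRe): https://github.com/naver/fire

Given two images, the app extracts super-features with a FIRe network
(trained on SfM-120k or ImageNet), matches them across the two images, and
paints the attention support of the matched (or user-selected) super-features
over each image, together with a colour legend.
"""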
import gradio as gr
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from matplotlib import colors

import fire_network
# Scales for multi-scale inference; the "Scale" slider indexes this list.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'
# Load the two FIRe networks: one trained on SfM-120k, one on ImageNet.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for the ImageNet-pretrained backbone
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])

# The ImageNet checkpoint lacks the dimensionality-reduction weights, so reuse
# the ones learned by the SfM-120k model.
dim_red_params_dict = {
    name: param
    for name, param in net_sfm.named_parameters()
    if 'dim_reduction' in name
}

state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict)

net_imagenet = fire_network.init_network(**state['net_params']).to(device)
net_imagenet.load_state_dict(state2['state_dict'], strict=False)
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std']))),
])
def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Match two sets of super-features via reciprocal nearest neighbours and
    a Lowe ratio test, keeping only pairs that share the same SF ID."""
    dist = torch.cdist(query_feat, pos_feat)
    # reciprocal nearest-neighbour check
    best1 = torch.argmin(dist, dim=1)
    best2 = torch.argmin(dist, dim=0)
    arange = torch.arange(best2.size(0))
    reciprocal = best1[best2] == arange
    # Lowe ratio test: compare the best distance to the second-best *distance*
    # (not to the argmin index, which has no metric meaning)
    dist2 = dist.clone()
    dist2[best2, arange] = float('inf')
    second2 = torch.argmin(dist2, dim=0)
    ratio1to2 = dist[best2, arange] / dist2[second2, arange]
    valid = torch.logical_and(reciprocal, ratio1to2 <= LoweRatioTh)
    pindices = torch.where(valid)[0]
    qindices = best2[pindices]
    # keep only pairs where both images agree on the super-feature ID
    valid = pindices == qindices
    return pindices[valid]
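
# Illustrative sketch of what match() expects (shapes are assumptions): two
# L2-normalized feature matrices of shape (N, D), one row per super-feature:
#   q = F.normalize(torch.randn(256, 128), dim=1)
#   p = F.normalize(torch.randn(256, 128), dim=1)
#   matched_ids = match(q, p)  # LongTensor of shared super-feature IDs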
def clear_figures():
    # reset matplotlib state so figures do not accumulate across Gradio calls
    plt.figure().clear()
    plt.close()
    plt.cla()
    plt.clf()
def generate_matching_superfeatures(
        im1, im2,
        Imagenet_model=False,
        scale_id=6, threshold=50,
        random_mode=False, sf_ids=''):
    """Extract, match, and visualize the super-features of two PIL images."""
    clear_figures()
    col = plt.get_cmap('tab10')
    net = net_imagenet if Imagenet_model else net_sfm
    im1_tensor = transform(im1).unsqueeze(0)
    im2_tensor = transform(im2).unsqueeze(0)
    # OpenCV copies in BGR order, used for drawing the overlays
    im1_cv = np.array(im1)[:, :, ::-1].copy()
    im2_cv = np.array(im2)[:, :, ::-1].copy()

    # extract super-features and their attention maps at the chosen scale
    # (get_superfeatures also returns per-feature strengths, unused here)
    with torch.no_grad():
        output1 = net.get_superfeatures(im1_tensor.to(device), scales=[scales[scale_id]])
        feats1, attns1 = output1[0][0], output1[1][0]
        output2 = net.get_superfeatures(im2_tensor.to(device), scales=[scales[scale_id]])
        feats2, attns2 = output2[0][0], output2[1][0]

    # L2-normalize and match the two sets of super-features
    feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
    feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
    ind_match = match(feats1n, feats2n)
    # choose which super-feature IDs to display
    n_sf_ids = 10
    if random_mode or sf_ids == '':
        sf_idx_ = np.random.randint(256, size=n_sf_ids)
    else:
        sf_idx_ = [int(s) for s in sf_ids.strip().split(',')]
    # keep only super-features that actually match across the two images
    if random_mode:
        picks = np.random.randint(len(ind_match), size=n_sf_ids)
        sf_idx_ = [int(jj) for jj in ind_match[picks].numpy()]
        sf_idx_ = list(dict.fromkeys(sf_idx_))  # deduplicate, preserving order
    else:
        sf_idx_ = [int(i) for i in sf_idx_ if int(i) in ind_match.tolist()]
    n_sf_ids = len(sf_idx_)  # the legend must match the IDs actually kept
    # store binary attention maps of every selected super-feature,
    # to show them all at once at the end
    all_att_bin1 = []
    all_att_bin2 = []
    for i in sf_idx_:
        att_heat = attns1[0, i, :, :].numpy().astype(np.float32)
        att_heat = np.uint8(att_heat / np.max(att_heat) * 255.0)
        # cast to uint8 so that cv2.resize accepts the array below
        all_att_bin1.append(np.where(att_heat > threshold, 255, 0).astype(np.uint8))
        att_heat = attns2[0, i, :, :].numpy().astype(np.float32)
        att_heat = np.uint8(att_heat / np.max(att_heat) * 255.0)
        all_att_bin2.append(np.where(att_heat > threshold, 255, 0).astype(np.uint8))
    # overlay each super-feature's support on the images, one colour per SF
    # (colours cycle if more than the 10 tab10 entries are requested)
    img1rsz = np.copy(im1_cv)
    for j, att in enumerate(all_att_bin1):
        att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
        col_ = 255 * np.array(colors.to_rgba(col.colors[j % len(col.colors)]))[:3]
        img1rsz[att == 255] = col_[::-1]  # RGB -> BGR

    img2rsz = np.copy(im2_cv)
    for j, att in enumerate(all_att_bin2):
        att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
        col_ = 255 * np.array(colors.to_rgba(col.colors[j % len(col.colors)]))[:3]
        img2rsz[att == 255] = col_[::-1]
    # figure 1: first image with painted super-features
    fig1 = plt.figure(1)
    plt.imshow(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
    plt.gca().axis('off')
    plt.tight_layout()

    # figure 2: second image
    fig2 = plt.figure(2)
    plt.imshow(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
    plt.gca().axis('off')
    plt.tight_layout()

    # figure 3: legend mapping each colour to its super-feature ID
    f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0]
    handles = [f("s", col.colors[i % len(col.colors)]) for i in range(n_sf_ids)]
    fig_leg = plt.figure(3)
    plt.legend(handles, sf_idx_, framealpha=1, frameon=False, facecolor='w', fontsize=25, loc="center")
    plt.gca().axis('off')
    plt.tight_layout()
    return fig1, fig2, fig_leg
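
# Illustrative local call (assumes the example images below are on disk):
#   from PIL import Image
#   f1, f2, leg = generate_matching_superfeatures(
#       Image.open("chateau_1.png").convert("RGB"),
#       Image.open("chateau_2.png").convert("RGB"),
#       Imagenet_model=False, scale_id=3, threshold=150,
#       random_mode=False, sf_ids='170,15,25')
#   f1.savefig("sfs_1.png")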
# GRADIO APP
title = "Visualizing Super-features"
description = "This is a visualization demo for the ICLR 2022 paper <b><a href='https://github.com/naver/fire' target='_blank'>Learning Super-Features for Image Retrieval</a></b>"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
# NOTE: written against the Gradio 3+ component API; the gr.inputs/gr.outputs
# namespaces (and the 'peach' theme and horizontal layout) of the original
# 2.x API no longer exist in current Gradio releases. The preprocessing
# transform already resizes inputs, so no fixed input shape is needed.
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.Image(type="pil", label="First Image"),
        gr.Image(type="pil", label="Second Image"),
        gr.Checkbox(value=False, label="ImageNet Model (Default: SfM-120k)"),
        gr.Slider(minimum=0, maximum=6, step=1, value=4, label="Scale"),
        gr.Slider(minimum=0, maximum=255, step=25, value=150, label="Binarization Threshold"),
        gr.Checkbox(value=True, label="Show random (matching) SFs"),
        gr.Textbox(lines=1, value="", label="...or show specific SF IDs:"),
    ],
    outputs=[
        gr.Plot(label="First Image SFs"),
        gr.Plot(label="Second Image SFs"),
        gr.Plot(label="SF legend"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
        ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
        ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
        ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
        ["mill1.jpeg", "mill2.jpeg", False, 4, 100, False, '177,88,170,190,151,155'],
    ],
)
iface.queue().launch()
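
# A minimal matching environment (an assumption, not pinned by the original):
#   pip install gradio torch torchvision opencv-python-headless matplotlib numpy pillow
# plus fire_network.py and the fire.pth / fire_imagenet.pth checkpoints from
# https://github.com/naver/fire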