Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,98 +1,251 @@
|
|
1 |
-
import torch
|
2 |
-
from torchvision import transforms
|
3 |
import numpy as np
|
4 |
import gradio as gr
|
5 |
from PIL import Image
|
6 |
from pytorch_grad_cam import GradCAM
|
|
|
7 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
model.load_state_dict(model_state_dict)
|
19 |
|
20 |
-
|
21 |
-
model = CustomResNet() # Replace this with your CustomResNet if necessary
|
22 |
-
# Load the state_dict using the custom function
|
23 |
-
state_dict = torch.load("model_pth.ckpt", map_location=torch.device('cpu'))
|
24 |
-
load_custom_state_dict(model, state_dict['state_dict'])
|
25 |
-
|
26 |
-
inv_normalize = transforms.Normalize(
|
27 |
-
mean=[-0.494 / 0.2470, -0.4822 / 0.2435, -0.4465 / 0.2616],
|
28 |
-
std=[1 / 0.2470, 1 / 0.2435, 1 / 0.2616]
|
29 |
-
)
|
30 |
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
input_img = transform(input_img)
|
37 |
input_img = input_img.unsqueeze(0)
|
38 |
outputs = model(input_img)
|
39 |
-
softmax = torch.nn.Softmax(dim=
|
40 |
-
|
41 |
-
confidences = {classes[i]: float(
|
42 |
-
|
|
|
43 |
_, prediction = torch.max(outputs, 1)
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
48 |
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
49 |
grayscale_cam = grayscale_cam[0, :]
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
|
|
|
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
gr.
|
90 |
-
gr.Slider(1,
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
|
98 |
-
|
|
|
|
|
|
|
|
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
from torchvision import datasets, transforms
# BUGFIX: this import was commented out, but CustomResNet() is constructed
# below — without it the module raises NameError on import.
from model import CustomResNet
import random


# Build the network and load the trained weights on CPU.
# strict=False tolerates benign key mismatches in the checkpoint
# (e.g. Lightning-style prefixes) — presumably intentional; verify
# against how CustomResNet.pth was saved.
model = CustomResNet()
model.load_state_dict(torch.load('CustomResNet.pth', map_location=torch.device('cpu')), strict=False)
model.eval()  # inference mode: freezes dropout/batch-norm statistics
|
16 |
+
|
|
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# CIFAR-10 labels, ordered to match the model's output indices.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
20 |
|
21 |
+
|
22 |
+
def inference(input_img, input_slider_grad_or_not, transparency = 0.5, target_layer_number = 3, topk = 3):
    """Classify one image and optionally overlay a Grad-CAM heatmap.

    Returns a (confidences, image) pair: a class->probability dict plus
    either the untouched input (overlay declined) or the Grad-CAM
    visualization blended at the given transparency.
    """
    # CIFAR-10 channel statistics used at training time.
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    original_img = input_img
    batch = preprocess(input_img).unsqueeze(0)  # add a batch dimension
    outputs = model(batch)
    probs = torch.nn.Softmax(dim=0)(outputs.flatten())
    confidences = {classes[idx]: float(probs[idx]) for idx in range(10)}
    if input_slider_grad_or_not == "No":
        return confidences, original_img
    _, prediction = torch.max(outputs, 1)
    # Select which conv stage the heatmap is computed from (default: layer_3).
    target_layers = [model.layer_3[-1]]
    if target_layer_number == 1:
        target_layers = [model.layer_1[-1]]
    if target_layer_number == 2:
        target_layers = [model.layer_2[-1]]
    if target_layer_number == 3:
        target_layers = [model.layer_3[-1]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=batch, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    # assumes input_img is a uint8 RGB array — /255 rescales to [0, 1] for blending
    visualization = show_cam_on_image(original_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)

    return confidences, visualization
|
52 |
+
|
53 |
+
|
54 |
+
def show_gradcam_images(n, a, b):
    """Build Grad-CAM overlays for the bundled example images.

    n: number of gallery entries to display; a: overlay opacity;
    b: target layer index. Returns gradio updates that reveal the
    gallery block populated with the first n overlays.
    """
    samples = [
        ('examples/car.jpg', 'car'),
        ('examples/cat.jpg', 'cat'),
        ('examples/dog.jpg', 'dog'),
        ('examples/horse.jpg', 'horse'),
        ('examples/ship.jpg', 'ship'),
        ('examples/bird.jpg', 'bird'),
        ('examples/frog.jpg', 'frog'),
        ('examples/plane.jpg', 'plane'),
        ('examples/truck.jpg', 'truck'),
        ('examples/deer.jpg', 'deer'),
    ]
    images_with_gradcam = []
    for path, label in samples:
        pixels = np.asarray(Image.open(path))
        # inference returns (confidences, visualization); keep the image only.
        overlay = inference(pixels, "Yes", a, b)[-1]
        images_with_gradcam.append((overlay, label))
    return {
        grad1_block: gr.update(visible=True),
        gallery3: images_with_gradcam[:n]
    }
|
78 |
+
|
79 |
+
|
80 |
+
def change_grad_view(choice):
    """Show the Grad-CAM configuration row when the user answers "Yes"."""
    return grad_block.update(visible=(choice == "Yes"))
|
85 |
+
|
86 |
+
|
87 |
+
def show_misclassified_images(n, grad_cam, a, b):
    """Show up to *n* misclassified images, optionally with Grad-CAM overlays.

    grad_cam: "Yes"/"No" toggle for the overlay; a: overlay opacity;
    b: target layer index. Returns gradio updates revealing the gallery
    with the selected (overlaid or raw) images, labelled "GT/Pred".
    """
    images = [
        ('misclassified_images/misclassified_0_GT_bird_Pred_cat.jpg', 'bird/cat'),
        # BUGFIX: directory was misspelled "misclassfied_images", so this
        # entry failed to load.
        ('misclassified_images/misclassified_1_GT_car_Pred_truck.jpg', 'car/truck'),
        ('misclassified_images/misclassified_2_GT_plane_Pred_truck.jpg', 'plane/truck'),
        ('misclassified_images/misclassified_3_GT_deer_Pred_dog.jpg', 'deer/dog'),
        ('misclassified_images/misclassified_4_GT_frog_Pred_cat.jpg', 'frog/cat'),
        ('misclassified_images/misclassified_5_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_6_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_7_GT_dog_Pred_horse.jpg', 'dog/horse'),
        ('misclassified_images/misclassified_8_GT_bird_Pred_dog.jpg', 'bird/dog'),
        ('misclassified_images/misclassified_9_GT_ship_Pred_plane.jpg', 'ship/plane')
    ]
    if grad_cam == "Yes":
        # Only compute the (expensive) Grad-CAM overlays when they are
        # actually requested; the original ran inference unconditionally.
        images_with_gradcam = []
        for image_path, label in images:
            image_array = np.asarray(Image.open(image_path))
            visualization = inference(image_array, "Yes", a, b)[-1]
            images_with_gradcam.append((visualization, label))
        return {
            miscls1_block: gr.update(visible=True),
            gallery: images_with_gradcam[:n]
        }

    return {
        miscls1_block: gr.update(visible=True),
        gallery: images[:n]
    }
|
116 |
+
|
117 |
+
|
118 |
+
def change_miscls_view(choice):
    """Show the misclassification row only when the user answers "Yes"."""
    return miscls_block.update(visible=(choice == "Yes"))
|
123 |
+
|
124 |
+
|
125 |
+
def change_textbox(choice):
    """Toggle both Grad-CAM configuration sliders together ("Yes" = visible)."""
    show = choice == "Yes"
    return [gr.Slider.update(visible=show), gr.Slider.update(visible=show)]
|
130 |
+
|
131 |
+
|
132 |
+
def update_num_top_classes(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk):
    """Re-run inference when the top-k slider moves; return only the label dict."""
    output_classes.num_top_classes = topk
    confidences, _ = inference(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk)
    return confidences
|
135 |
+
|
136 |
+
|
137 |
+
def change_mygrad_view(choice):
    """Reveal the Grad-CAM option column only when the user opts in."""
    visible = choice == "Yes"
    return grad_or_not.update(visible=visible)
|
142 |
+
|
143 |
+
|
144 |
+
# ---------------------------------------------------------------------------
# Gradio UI: three sections (Grad-CAM gallery, misclassified gallery,
# single-image playground), each toggled by a Yes/No radio.
# ---------------------------------------------------------------------------
with gr.Blocks(theme='abidlabs/dracula_revamped') as demo:
    gr.Markdown("""

    # CustomResNet with GradCAM - Interactive Interface

    ### A simple Gradio interface to infer on CustomResNet model and get GradCAM results

    """)
    gr.Markdown("# Analyse the Model")
    gr.Markdown("## Grad-CAM")
    with gr.Row():
        grad_yes_no = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see GradCAM images")
    with gr.Row(visible=False) as grad_block:
        with gr.Column(scale=1):
            input_grad = gr.Slider(1, 10, value = 3, step=1, label="Number of GradCAM images to view")
            input_overlay = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to configure gradcam")
            with gr.Row():
                clear_btn3 = gr.ClearButton()
                submit_btn3 = gr.Button("Submit")
        with gr.Column(scale=1):
            input_slider31 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider32 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?", interactive=True, visible=False)
    with gr.Row(visible=False) as grad1_block:
        gallery3 = gr.Gallery(
            label="GradCAM images", show_label=True, elem_id="gallery3"
        ).style(columns=[4], rows=[3], object_fit="contain", height="auto")

    submit_btn3.click(fn=show_gradcam_images, inputs=[input_grad, input_slider31, input_slider32], outputs = [grad1_block, gallery3])
    # BUGFIX: the outputs list previously contained input_grad twice; each
    # component now receives exactly one reset value.
    clear_btn3.click(lambda: [None, None, None, None], outputs=[input_grad, input_slider31, input_slider32, gallery3])
    input_overlay.change(fn=change_textbox, inputs=input_overlay, outputs=[input_slider31, input_slider32])
    grad_yes_no.change(fn=change_grad_view, inputs=grad_yes_no, outputs=[grad_block])

    ###############################################

    gr.Markdown("## Misclassification")
    with gr.Row():
        miscls_yes_no = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to see misclassified images")
    with gr.Row(visible=False) as miscls_block:
        with gr.Column(scale=1):
            input_miscn = gr.Slider(1, 10, value = 3, step=1, label="Number of misclassified images to view")
            with gr.Row():
                clear_btn2 = gr.ClearButton()
                submit_btn2 = gr.Button("Submit")
        with gr.Column(scale=1):
            input_grad2 = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay gradcam")
            input_slider21 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider22 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?", interactive=True, visible=False)
        with gr.Column(visible=False) as miscls1_block:
            gallery = gr.Gallery(
                label="Misclassified images", show_label=True, elem_id="gallery"
            ).style(columns=[4], rows=[3], object_fit="contain", height="auto")


    submit_btn2.click(fn=show_misclassified_images, inputs=[input_miscn, input_grad2, input_slider21, input_slider22], outputs = [miscls1_block, gallery])
    # BUGFIX: outputs previously referenced input_grad (the Grad-CAM
    # section's slider) instead of this section's input_grad2 radio.
    clear_btn2.click(lambda: [None, None, None, None, None], outputs=[input_miscn, input_grad2, input_slider21, input_slider22, gallery])
    input_grad2.change(fn=change_textbox, inputs=input_grad2, outputs=[input_slider21, input_slider22])
    miscls_yes_no.change(fn=change_miscls_view, inputs=miscls_yes_no, outputs=[miscls_block])


    ###############################################

    gr.Markdown("## Try it Out")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            input_topk = gr.Slider(1, 10, value = 3, step=1, label="Top N Classes")
            input_slider_grad_or_not = gr.Radio(choices = ["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
            with gr.Column(visible=False) as grad_or_not:
                input_slider1 = gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM")
                input_slider2 = gr.Slider(1, 3, value = 3, step=1, label="Which Layer?")
            with gr.Row():
                clear_btn = gr.ClearButton()
                submit_btn = gr.Button("Submit")
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3)
            output_image = gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)


    gr.Markdown("## Examples")
    gr.Examples(
        examples=[["examples/car.jpg", "Yes", 0.5, 3, 3],
                  ["examples/cat.jpg", "Yes", 0.7, 2, 5],
                  ["examples/dog.jpg", "Yes", 0.9, 1, 4],
                  ["examples/truck.jpg", "Yes", 0.3, 1, 7],
                  ["examples/horse.jpg", "Yes", 0.7, 3, 4],
                  ["examples/frog.jpg", "Yes", 0.8, 3, 6],
                  ["examples/bird.jpg", "Yes", 0.9, 1, 7],
                  ["examples/deer.jpg", "Yes", 0.3, 1, 3],
                  ["examples/plane.jpg", "Yes", 0.4, 3, 4],
                  ["examples/ship.jpg", "Yes", 0.5, 2, 5]
                  ],
        inputs=[input_image,input_slider_grad_or_not,input_slider1,input_slider2, input_topk],
        outputs=[output_classes,output_image],
        fn=inference,
        cache_examples=True,
    )

    submit_btn.click(fn=inference, inputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, input_topk], outputs=[output_classes, output_image])
    # BUGFIX: the reset lambda returned 7 values for 6 output components;
    # counts now match (one value per component).
    clear_btn.click(lambda: [None, "No", 0.5, 3, None, None], outputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, output_classes, output_image])
    input_topk.change(update_num_top_classes, inputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, input_topk], outputs=[output_classes])
    input_slider_grad_or_not.change(fn=change_mygrad_view, inputs=input_slider_grad_or_not, outputs=[grad_or_not])


if __name__ == "__main__":
    demo.launch(debug=True)
|