Gosula committed on
Commit
be5cc68
1 Parent(s): 72eeb6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -77
app.py CHANGED
@@ -1,98 +1,251 @@
1
- import torch
2
- from torchvision import transforms
3
  import numpy as np
4
  import gradio as gr
5
  from PIL import Image
6
  from pytorch_grad_cam import GradCAM
 
7
  from pytorch_grad_cam.utils.image import show_cam_on_image
8
- from custom_resnet import *
9
- #from resnet import ResNet18 # Assuming you have a custom ResNet18 implementation
10
-
11
- def load_custom_state_dict(model, state_dict):
12
- model_state_dict = model.state_dict()
13
- # Filter out unexpected keys
14
- filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
15
- # Update the model's state_dict
16
- model_state_dict.update(filtered_state_dict)
17
- # Load the updated state_dict to the model
18
- model.load_state_dict(model_state_dict)
19
 
20
-
21
- model = CustomResNet() # Replace this with your CustomResNet if necessary
22
- # Load the state_dict using the custom function
23
- state_dict = torch.load("model_pth.ckpt", map_location=torch.device('cpu'))
24
- load_custom_state_dict(model, state_dict['state_dict'])
25
-
26
- inv_normalize = transforms.Normalize(
27
- mean=[-0.494 / 0.2470, -0.4822 / 0.2435, -0.4465 / 0.2616],
28
- std=[1 / 0.2470, 1 / 0.2435, 1 / 0.2616]
29
- )
30
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
31
- 'dog', 'frog', 'horse', 'ship', 'truck')
32
 
33
- def inference(input_img, transparency=0.5, target_layer_number=-1, num_images=1, num_top_classes=3):
34
- transform = transforms.ToTensor()
35
- org_img = input_img
 
 
 
 
 
 
36
  input_img = transform(input_img)
37
  input_img = input_img.unsqueeze(0)
38
  outputs = model(input_img)
39
- softmax = torch.nn.Softmax(dim=1)
40
- probabilities = softmax(outputs)
41
- confidences = {classes[i]: float(probabilities[0, i]) for i in range(10)}
42
-
 
43
  _, prediction = torch.max(outputs, 1)
44
-
45
- # Get GradCAM for the specified target_layer_number
46
- target_layers = [model.layer_2[target_layer_number]]
 
 
 
 
47
  cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
48
  grayscale_cam = cam(input_tensor=input_img, targets=None)
49
  grayscale_cam = grayscale_cam[0, :]
50
- img = input_img.squeeze(0)
51
- img = inv_normalize(img)
52
- rgb_img = np.transpose(img, (1, 2, 0))
53
- rgb_img = rgb_img.numpy()
54
-
55
- # Convert org_img (PIL image) to a NumPy array before performing arithmetic operations
56
- visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
57
-
58
- # Create a list to store multiple visualizations
59
-
60
-
61
- # # Generate multiple GradCAM visualizations if num_images > 1
62
- # for _ in range(num_images - 1):
63
- # # Get GradCAM for different target_layer_number if provided by the user
64
- # if target_layer_number >= -1:
65
- # target_layers = [model.layer_2[target_layer_number]]
66
- # cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
67
- # grayscale_cam = cam(input_tensor=input_img, targets=None)
68
- # grayscale_cam = grayscale_cam[0, :]
69
-
70
- # visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
71
- # visualizations.append(visualization)
72
-
73
- # Get top classes based on user input (up to a maximum of 10)
74
- top_classes = {k: v for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:min(num_top_classes, 10)]}
75
-
76
- return top_classes, visualization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
80
 
81
- title = "CIFAR10 trained on ResNet18 Model with GradCAM"
82
- description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
83
- examples = [["car_1.jpg",0.5,-1],["car_2.jpg",0.5,-1],["cat_1.jpg",0.5,-1],["cat_2.jpg",0.5,-1],["dog_1.jpg",0.5,-1],["dog_2.jpg",0.5,-1],["frog_1.jpg",0.5,-1],["frog_2.jpg",0.5,-1],["horse_1.jpg",0.5,-1],["horse_2.jpg",0.5,-1]]
84
- demo = gr.Interface(
85
- inference,
86
- inputs = [gr.Image(shape=(32, 32), label="Input Image"),
87
- gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
88
- gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
89
- gr.Number(default=1, label="Number of GradCAM Images to Show"),
90
- gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")],
91
- outputs = [gr.Label(num_top_classes=5), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
92
- title = title,
93
- description = description,
94
- examples = examples,
95
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
- demo.launch()
 
 
 
 
1
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
from torchvision import datasets, transforms
# BUG FIX: this import was commented out, but CustomResNet is instantiated
# below — without it the app dies with a NameError at startup.
from model import CustomResNet
import random


# Build the model on CPU and load the trained weights. strict=False tolerates
# checkpoint keys that do not match the module exactly.
model = CustomResNet()
model.load_state_dict(torch.load('CustomResNet.pth', map_location=torch.device('cpu')), strict=False)
# eval() freezes dropout/batch-norm statistics for deterministic inference.
model.eval()
16
+
 
17
 
 
 
 
 
 
 
 
 
 
 
18
# CIFAR-10 class names, ordered by the model's output logit index.
classes = tuple('plane car bird cat deer dog frog horse ship truck'.split())
20
 
21
+
22
def inference(input_img, input_slider_grad_or_not, transparency=0.5, target_layer_number=3, topk=3):
    """Classify one CIFAR-10-sized image and optionally add a Grad-CAM overlay.

    Args:
        input_img: RGB image as a numpy array (H, W, C), values 0-255.
        input_slider_grad_or_not: "Yes" computes a Grad-CAM visualization;
            "No" returns the untouched input image instead.
        transparency: blend weight of the original image in the overlay.
        target_layer_number: 1-3 selects model.layer_{n}; any other value
            falls back to the deepest stage (layer_3).
        topk: accepted for interface compatibility; the Label widget decides
            how many classes are displayed.

    Returns:
        (confidences, image) where confidences maps class name -> probability.
    """
    # CIFAR-10 per-channel statistics used at training time.
    cifar_mean = [0.49139968, 0.48215827, 0.44653124]
    cifar_std = [0.24703233, 0.24348505, 0.26158768]
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    original_img = input_img
    batch = preprocess(input_img).unsqueeze(0)  # add batch dimension
    outputs = model(batch)
    probs = torch.nn.Softmax(dim=0)(outputs.flatten())
    confidences = {classes[i]: float(probs[i]) for i in range(10)}
    if input_slider_grad_or_not == "No":
        # No overlay requested: hand back the original image unchanged.
        return confidences, original_img
    _, prediction = torch.max(outputs, 1)  # kept for parity; unused below
    # Map the layer selector onto the model's stages, defaulting to layer_3.
    stage_by_number = {1: model.layer_1, 2: model.layer_2, 3: model.layer_3}
    target_layers = [stage_by_number.get(target_layer_number, model.layer_3)[-1]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=batch, targets=None)[0, :]
    visualization = show_cam_on_image(original_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization
52
+
53
+
54
def show_gradcam_images(n, a, b):
    """Build Grad-CAM overlays for the bundled example images.

    Args:
        n: number of gallery items to show.
        a: Grad-CAM opacity, forwarded to inference().
        b: Grad-CAM target layer number, forwarded to inference().

    Returns:
        Gradio updates that reveal the gallery row and fill the gallery
        with (overlay, label) pairs.
    """
    # Every example lives at examples/<label>.jpg, so the list can be derived
    # from the labels alone.
    labels = ['car', 'cat', 'dog', 'horse', 'ship',
              'bird', 'frog', 'plane', 'truck', 'deer']
    images_with_gradcam = []
    for label in labels:
        pixels = np.asarray(Image.open(f'examples/{label}.jpg'))
        overlay = inference(pixels, "Yes", a, b)[-1]
        images_with_gradcam.append((overlay, label))

    return {
        grad1_block: gr.update(visible=True),
        gallery3: images_with_gradcam[:n]
    }
78
+
79
+
80
def change_grad_view(choice):
    """Show the Grad-CAM configuration row only when the user opts in."""
    return grad_block.update(visible=(choice == "Yes"))
85
+
86
+
87
def show_misclassified_images(n, grad_cam, a, b):
    """Show the first n stored misclassified images, optionally with Grad-CAM.

    Args:
        n: number of gallery items to show.
        grad_cam: "Yes" overlays a Grad-CAM heatmap on each image.
        a: Grad-CAM opacity, forwarded to inference().
        b: Grad-CAM target layer number, forwarded to inference().

    Returns:
        Gradio updates that reveal the gallery block and fill the gallery
        with (image, "ground-truth/prediction") pairs.
    """
    images = [
        ('misclassified_images/misclassified_0_GT_bird_Pred_cat.jpg', 'bird/cat'),
        # BUG FIX: directory was misspelled 'misclassfied_images', which made
        # Image.open raise FileNotFoundError for this entry.
        ('misclassified_images/misclassified_1_GT_car_Pred_truck.jpg', 'car/truck'),
        ('misclassified_images/misclassified_2_GT_plane_Pred_truck.jpg', 'plane/truck'),
        ('misclassified_images/misclassified_3_GT_deer_Pred_dog.jpg', 'deer/dog'),
        ('misclassified_images/misclassified_4_GT_frog_Pred_cat.jpg', 'frog/cat'),
        ('misclassified_images/misclassified_5_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_6_GT_cat_Pred_dog.jpg', 'cat/dog'),
        ('misclassified_images/misclassified_7_GT_dog_Pred_horse.jpg', 'dog/horse'),
        ('misclassified_images/misclassified_8_GT_bird_Pred_dog.jpg', 'bird/dog'),
        ('misclassified_images/misclassified_9_GT_ship_Pred_plane.jpg', 'ship/plane')
    ]
    if grad_cam != "Yes":
        # Plain images requested: skip the expensive Grad-CAM forward passes
        # that the previous version ran unconditionally.
        return {
            miscls1_block: gr.update(visible=True),
            gallery: images[:n]
        }
    images_with_gradcam = []
    for image_path, label in images:
        image_array = np.asarray(Image.open(image_path))
        visualization = inference(image_array, "Yes", a, b)[-1]
        images_with_gradcam.append((visualization, label))
    return {
        miscls1_block: gr.update(visible=True),
        gallery: images_with_gradcam[:n]
    }
116
+
117
+
118
def change_miscls_view(choice):
    """Show or hide the misclassification section based on the radio choice."""
    return miscls_block.update(visible=(choice == "Yes"))
123
+
124
+
125
def change_textbox(choice):
    """Toggle both Grad-CAM tuning sliders in lockstep."""
    show = choice == "Yes"
    return [gr.Slider.update(visible=show), gr.Slider.update(visible=show)]
130
+
131
+
132
def update_num_top_classes(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk):
    """Re-run inference and update the label widget to show `topk` classes."""
    # NOTE(review): mutating num_top_classes on an already-built component
    # assumes the running UI re-reads it — confirm with the installed gradio.
    output_classes.num_top_classes = topk
    confidences, _ = inference(input_img, input_slider_grad_or_not, transparency, target_layer_number, topk)
    return confidences
135
+
136
+
137
def change_mygrad_view(choice):
    """Reveal the Grad-CAM option sliders only when an overlay is requested."""
    return grad_or_not.update(visible=(choice == "Yes"))
142
+
143
+
144
# ---------------------------------------------------------------------------
# Gradio UI: three sections — Grad-CAM example gallery, misclassified-image
# gallery, and a single-image "Try it Out" panel. Event wiring follows each
# section's layout.
# ---------------------------------------------------------------------------
with gr.Blocks(theme='abidlabs/dracula_revamped') as demo:
    gr.Markdown("""

# CustomResNet with GradCAM - Interactive Interface

### A simple Gradio interface to infer on CustomResNet model and get GradCAM results

""")
    gr.Markdown("# Analyse the Model")
    gr.Markdown("## Grad-CAM")
    with gr.Row():
        grad_yes_no = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to see GradCAM images")
    with gr.Row(visible=False) as grad_block:
        with gr.Column(scale=1):
            input_grad = gr.Slider(1, 10, value=3, step=1, label="Number of GradCAM images to view")
            input_overlay = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to configure gradcam")
            with gr.Row():
                clear_btn3 = gr.ClearButton()
                submit_btn3 = gr.Button("Submit")
        with gr.Column(scale=1):
            input_slider31 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider32 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?", interactive=True, visible=False)
    with gr.Row(visible=False) as grad1_block:
        gallery3 = gr.Gallery(
            label="GradCAM images", show_label=True, elem_id="gallery3"
        ).style(columns=[4], rows=[3], object_fit="contain", height="auto")

    submit_btn3.click(fn=show_gradcam_images, inputs=[input_grad, input_slider31, input_slider32], outputs=[grad1_block, gallery3])
    # BUG FIX: the outputs list previously named `input_grad` twice; the
    # second slot resets the overlay radio.
    clear_btn3.click(lambda: [None, None, None, None, None], outputs=[input_grad, input_overlay, input_slider31, input_slider32, gallery3])
    input_overlay.change(fn=change_textbox, inputs=input_overlay, outputs=[input_slider31, input_slider32])
    grad_yes_no.change(fn=change_grad_view, inputs=grad_yes_no, outputs=[grad_block])

    ###############################################

    gr.Markdown("## Misclassification")
    with gr.Row():
        miscls_yes_no = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to see misclassified images")
    with gr.Row(visible=False) as miscls_block:
        with gr.Column(scale=1):
            input_miscn = gr.Slider(1, 10, value=3, step=1, label="Number of misclassified images to view")
            with gr.Row():
                clear_btn2 = gr.ClearButton()
                submit_btn2 = gr.Button("Submit")
        with gr.Column(scale=1):
            input_grad2 = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to overlay gradcam")
            input_slider21 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM", interactive=True, visible=False)
            input_slider22 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?", interactive=True, visible=False)
    with gr.Column(visible=False) as miscls1_block:
        gallery = gr.Gallery(
            label="Misclassified images", show_label=True, elem_id="gallery"
        ).style(columns=[4], rows=[3], object_fit="contain", height="auto")

    submit_btn2.click(fn=show_misclassified_images, inputs=[input_miscn, input_grad2, input_slider21, input_slider22], outputs=[miscls1_block, gallery])
    # BUG FIX: the reset list referenced `input_grad` from the Grad-CAM
    # section; this section's components are `input_miscn` / `input_grad2`.
    clear_btn2.click(lambda: [None, None, None, None, None], outputs=[input_miscn, input_grad2, input_slider21, input_slider22, gallery])
    input_grad2.change(fn=change_textbox, inputs=input_grad2, outputs=[input_slider21, input_slider22])
    miscls_yes_no.change(fn=change_miscls_view, inputs=miscls_yes_no, outputs=[miscls_block])

    ###############################################

    gr.Markdown("## Try it Out")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(shape=(32, 32), label="Input Image")
            input_topk = gr.Slider(1, 10, value=3, step=1, label="Top N Classes")
            input_slider_grad_or_not = gr.Radio(choices=["Yes", "No"], value="No", label="Do you want to overlay GradCAM output")
            with gr.Column(visible=False) as grad_or_not:
                input_slider1 = gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM")
                input_slider2 = gr.Slider(1, 3, value=3, step=1, label="Which Layer?")
            with gr.Row():
                clear_btn = gr.ClearButton()
                submit_btn = gr.Button("Submit")
        with gr.Column(scale=1):
            output_classes = gr.Label(num_top_classes=3)
            output_image = gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[["examples/car.jpg", "Yes", 0.5, 3, 3],
                  ["examples/cat.jpg", "Yes", 0.7, 2, 5],
                  ["examples/dog.jpg", "Yes", 0.9, 1, 4],
                  ["examples/truck.jpg", "Yes", 0.3, 1, 7],
                  ["examples/horse.jpg", "Yes", 0.7, 3, 4],
                  ["examples/frog.jpg", "Yes", 0.8, 3, 6],
                  ["examples/bird.jpg", "Yes", 0.9, 1, 7],
                  ["examples/deer.jpg", "Yes", 0.3, 1, 3],
                  ["examples/plane.jpg", "Yes", 0.4, 3, 4],
                  ["examples/ship.jpg", "Yes", 0.5, 2, 5]
                  ],
        inputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, input_topk],
        outputs=[output_classes, output_image],
        fn=inference,
        cache_examples=True,
    )

    submit_btn.click(fn=inference, inputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, input_topk], outputs=[output_classes, output_image])
    # BUG FIX: the lambda yields 7 values but only 6 outputs were listed;
    # the trailing 3 is the reset value for `input_topk`, now included.
    clear_btn.click(lambda: [None, "No", 0.5, 3, None, None, 3], outputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, output_classes, output_image, input_topk])
    input_topk.change(update_num_top_classes, inputs=[input_image, input_slider_grad_or_not, input_slider1, input_slider2, input_topk], outputs=[output_classes])
    input_slider_grad_or_not.change(fn=change_mygrad_view, inputs=input_slider_grad_or_not, outputs=[grad_or_not])
248
 
249
 
250
# Script entry point: launch the Gradio server. debug=True prints full
# tracebacks to the console while the app is running.
if __name__ == "__main__":
    demo.launch(debug=True)