ravi.naik committed
Commit 4db4d66
1 Parent(s): d028330

Added source

.gitignore ADDED
@@ -0,0 +1,4 @@
+ lightning_logs
+ data
+ .ipynb_checkpoints
+ __pycache__/
app.py ADDED
@@ -0,0 +1,288 @@
+ import gradio as gr
+ import random
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torchvision
+
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ from models.resnet_lightning import ResNet
+ from utils.data import CIFARDataModule
+ from utils.transforms import test_transform
+ from utils.common import get_misclassified_data
+
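+ # Normalize(mean=-m/s, std=1/s) inverts Normalize(mean=m, std=s); 0.50/0.23 is a
+ # rough single-value stand-in for the per-channel CIFAR-10 stats in
+ # utils/transforms.py, presumably close enough for display purposes.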
+ inv_normalize = torchvision.transforms.Normalize(
+     mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23], std=[1 / 0.23, 1 / 0.23, 1 / 0.23]
+ )
+
+ datamodule = CIFARDataModule()
+ datamodule.setup()
+ classes = datamodule.train_dataset.classes
+
+ model = ResNet.load_from_checkpoint("model.ckpt")
+ model = model.to("cpu")
+
+ prediction_image = None
+
+
+ def upload_file(files):
+     file_paths = [file.name for file in files]
+     return file_paths
+
+
+ def read_image(path):
+     img = Image.open(path)
+     img.load()
+     data = np.asarray(img, dtype="uint8")
+     return data
+
+
+ def sample_images():
+     images = []
+     length = len(datamodule.test_dataset)
+     classes = datamodule.train_dataset.classes
+     for i in range(10):
+         idx = random.randint(0, length - 1)
+         image, label = datamodule.test_dataset[idx]
+         image = inv_normalize(image).permute(1, 2, 0).numpy()
+         images.append((image, classes[label]))
+     return images
+
+
+ def get_misclassified_images(misclassified_count):
+     misclassified_images = []
+     misclassified_data = get_misclassified_data(
+         model=model,
+         device="cpu",
+         test_loader=datamodule.test_dataloader(),
+         count=misclassified_count,
+     )
+     for i in range(misclassified_count):
+         img = misclassified_data[i][0].squeeze().to("cpu")
+         img = inv_normalize(img)
+         img = np.transpose(img.numpy(), (1, 2, 0))
+         label = f"Label: {classes[misclassified_data[i][1].item()]} | Prediction: {classes[misclassified_data[i][2].item()]}"
+         misclassified_images.append((img, label))
+     return misclassified_images
+
+
+ def get_gradcam_images(gradcam_layer, gradcam_count, gradcam_opacity):
+     gradcam_images = []
+     if gradcam_layer == "Layer1":
+         target_layers = [model.layer1[-1]]
+     elif gradcam_layer == "Layer2":
+         target_layers = [model.layer2[-1]]
+     else:
+         target_layers = [model.layer3[-1]]
+
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
+     data = get_misclassified_data(
+         model=model,
+         device="cpu",
+         test_loader=datamodule.test_dataloader(),
+         count=gradcam_count,
+     )
+     for i in range(gradcam_count):
+         input_tensor = data[i][0]
+
+         # Get the activations of the layer for the images
+         grayscale_cam = cam(input_tensor=input_tensor, targets=None)
+         grayscale_cam = grayscale_cam[0, :]
+
+         # Get back the original image (convert to numpy before transposing,
+         # since torch.Tensor.transpose does not accept an axes tuple)
+         img = input_tensor.squeeze(0).to("cpu")
+         if inv_normalize is not None:
+             img = inv_normalize(img)
+         rgb_img = np.transpose(img.numpy(), (1, 2, 0))
+
+         # Mix the activations on the original image
+         visualization = show_cam_on_image(
+             rgb_img, grayscale_cam, use_rgb=True, image_weight=gradcam_opacity
+         )
+         label = f"Label: {classes[data[i][1].item()]} | Prediction: {classes[data[i][2].item()]}"
+         gradcam_images.append((visualization, label))
+     return gradcam_images
+
+
+ def show_hide_misclassified(status):
+     if not status:
+         return {misclassified_count: gr.update(visible=False)}
+     return {misclassified_count: gr.update(visible=True)}
+
+
+ def show_hide_gradcam(status):
+     if not status:
+         return [gr.update(visible=False) for i in range(3)]
+     return [gr.update(visible=True) for i in range(3)]
+
+
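+ # A gallery item may arrive as a dict or as a (dict, caption) pair depending on
+ # how the gallery was populated, so both payload shapes are handled below.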
+ def set_prediction_image(evt: gr.SelectData, gallery):
+     global prediction_image
+     if isinstance(gallery[evt.index], dict):
+         prediction_image = gallery[evt.index]["name"]
+     else:
+         prediction_image = gallery[evt.index][0]["name"]
+
+
+ def predict(
+     is_misclassified,
+     misclassified_count,
+     is_gradcam,
+     gradcam_count,
+     gradcam_layer,
+     gradcam_opacity,
+     num_classes,
+ ):
+     misclassified_images = None
+     if is_misclassified:
+         misclassified_images = get_misclassified_images(int(misclassified_count))
+
+     gradcam_images = None
+     if is_gradcam:
+         gradcam_images = get_gradcam_images(
+             gradcam_layer, int(gradcam_count), gradcam_opacity
+         )
+
+     img = read_image(prediction_image)
+     image_transformed = test_transform(image=img)["image"]
+     output = model(image_transformed.unsqueeze(0))
+     preds = torch.softmax(output, dim=1).squeeze().detach().numpy()
+     indices = (
+         output.argsort(descending=True).squeeze().detach().numpy()[: int(num_classes)]
+     )
+     predictions = {classes[i]: round(float(preds[i]), 2) for i in indices}
+
+     return {
+         misclassified_output: gr.update(value=misclassified_images),
+         gradcam_output: gr.update(value=gradcam_images),
+         prediction_label: gr.update(value=predictions),
+     }
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("## ERA Session12 - CIFAR10 Classification with ResNet")
+     with gr.Row():
+         with gr.Column():
+             with gr.Box():
+                 is_misclassified = gr.Checkbox(
+                     label="Misclassified Images", info="Display misclassified images?"
+                 )
+                 misclassified_count = gr.Dropdown(
+                     choices=["10", "20"],
+                     label="Select Number of Images",
+                     info="Number of Misclassified images",
+                     visible=False,
+                     interactive=True,
+                 )
+                 is_misclassified.input(
+                     show_hide_misclassified,
+                     inputs=[is_misclassified],
+                     outputs=[misclassified_count],
+                 )
+             with gr.Box():
+                 is_gradcam = gr.Checkbox(
+                     label="GradCAM Images",
+                     info="Display GradCAM images?",
+                 )
+                 gradcam_count = gr.Dropdown(
+                     choices=["10", "20"],
+                     label="Select Number of Images",
+                     info="Number of GradCAM images",
+                     interactive=True,
+                     visible=False,
+                 )
+                 gradcam_layer = gr.Dropdown(
+                     choices=["Layer1", "Layer2", "Layer3"],
+                     label="Select the layer",
+                     info="Please select the layer for which the GradCAM is required",
+                     interactive=True,
+                     visible=False,
+                 )
+                 gradcam_opacity = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     value=0.6,
+                     label="Opacity",
+                     info="Opacity of GradCAM output",
+                     interactive=True,
+                     visible=False,
+                 )
+
+                 is_gradcam.input(
+                     show_hide_gradcam,
+                     inputs=[is_gradcam],
+                     outputs=[gradcam_count, gradcam_layer, gradcam_opacity],
+                 )
+             with gr.Box():
+                 # file_output = gr.File(file_types=["image"])
+                 with gr.Group():
+                     upload_gallery = gr.Gallery(
+                         value=None,
+                         label="Uploaded images",
+                         show_label=False,
+                         elem_id="gallery_upload",
+                         columns=5,
+                         rows=2,
+                         height="auto",
+                         object_fit="contain",
+                     )
+                     upload_button = gr.UploadButton(
+                         "Click to Upload images",
+                         file_types=["image"],
+                         file_count="multiple",
+                     )
+                     upload_button.upload(upload_file, upload_button, upload_gallery)
+
+                 with gr.Group():
+                     sample_gallery = gr.Gallery(
+                         value=sample_images,
+                         label="Sample images",
+                         show_label=True,
+                         elem_id="gallery_sample",
+                         columns=5,
+                         rows=2,
+                         height="auto",
+                         object_fit="contain",
+                     )
+
+                 upload_gallery.select(set_prediction_image, inputs=[upload_gallery])
+                 sample_gallery.select(set_prediction_image, inputs=[sample_gallery])
+
+             with gr.Box():
+                 num_classes = gr.Dropdown(
+                     choices=[str(i + 1) for i in range(10)],
+                     label="Select Number of Top Classes",
+                     info="Number of Top target classes to be shown",
+                 )
+             run_btn = gr.Button()
+         with gr.Column():
+             with gr.Box():
+                 misclassified_output = gr.Gallery(
+                     value=None, label="Misclassified Images", show_label=True
+                 )
+             with gr.Box():
+                 gradcam_output = gr.Gallery(
+                     value=None, label="GradCAM Images", show_label=True
+                 )
+             with gr.Box():
+                 prediction_label = gr.Label(value=None, label="Predictions")
+
+     run_btn.click(
+         predict,
+         inputs=[
+             is_misclassified,
+             misclassified_count,
+             is_gradcam,
+             gradcam_count,
+             gradcam_layer,
+             gradcam_opacity,
+             num_classes,
+         ],
+         outputs=[misclassified_output, gradcam_output, prediction_label],
+     )
+
+
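+ # Bind to all interfaces so the app is reachable from outside a container/VM.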
+ app.launch(server_name="0.0.0.0", server_port=9998)
config.toml ADDED
@@ -0,0 +1,13 @@
+ [data]
+ batch_size = 512
+ shuffle = true
+ num_workers = 4
+
+ [training]
+ epochs = 20
+ batch_size = 512
+ optimizer = "adam"
+ criterion = "crossentropy"
+ lr = 0.003
+ weight_decay = 1e-4
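+ # Inline table for the LR range test; parsed by the LRFinder schema in utils/config.py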
+ lrfinder = { numiter = 600, endlr = 10, startlr = 1e-2 }
model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f3d4b6359778a6dd0c86e85afb1a522aae822ccfeeea9a6fb82aabb124f518d
+ size 78938183
models/custom_resnet.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, channels):
+         super(ResBlock, self).__init__()
+
+         self.resblock = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=channels,
+                 out_channels=channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(channels),
+             nn.ReLU(),
+             nn.Conv2d(
+                 in_channels=channels,
+                 out_channels=channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(channels),
+             nn.ReLU(),
+         )
+
+     def forward(self, x):
+         return x + self.resblock(x)
+
+
+ class CustomResnet(nn.Module):
+     def __init__(self):
+         super(CustomResnet, self).__init__()
+
+         self.prep = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+         )
+
+         self.layer1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             ResBlock(channels=128),
+         )
+
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+         )
+
+         self.layer3 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             ResBlock(channels=512),
+         )
+
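+         # For 32x32 CIFAR-10 inputs, the three stride-2 max-pools above leave a 4x4
+         # feature map; this 4x4 max-pool collapses it to 1x1 ahead of the linear head.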
+         self.pool = nn.MaxPool2d(kernel_size=4)
+
+         self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
+
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x):
+         x = self.prep(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.pool(x)
+         x = x.view(-1, 512)
+         x = self.fc(x)
+         # x = self.softmax(x)
+         return x
models/resnet_lightning.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ import lightning as L
+ from torchmetrics import Accuracy
+ from typing import Any
+
+ from utils.common import one_cycle_lr
+
+
+ class ResidualBlock(L.LightningModule):
+     def __init__(self, channels):
+         super(ResidualBlock, self).__init__()
+
+         self.residual_block = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=channels,
+                 out_channels=channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(channels),
+             nn.ReLU(),
+             nn.Conv2d(
+                 in_channels=channels,
+                 out_channels=channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(channels),
+             nn.ReLU(),
+         )
+
+     def forward(self, x):
+         return x + self.residual_block(x)
+
+
+ class ResNet(L.LightningModule):
+     def __init__(
+         self,
+         batch_size=512,
+         shuffle=True,
+         num_workers=4,
+         learning_rate=0.003,
+         scheduler_steps=None,
+         maxlr=None,
+         epochs=None,
+     ):
+         super(ResNet, self).__init__()
+         self.data_dir = "./data"
+         self.batch_size = batch_size
+         self.shuffle = shuffle
+         self.num_workers = num_workers
+         self.learning_rate = learning_rate
+         self.scheduler_steps = scheduler_steps
+         self.maxlr = maxlr if maxlr is not None else learning_rate
+         self.epochs = epochs
+
+         self.prep = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+         )
+
+         self.layer1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             ResidualBlock(channels=128),
+         )
+
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+         )
+
+         self.layer3 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=3,
+                 padding=1,
+                 stride=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             ResidualBlock(channels=512),
+         )
+
+         self.pool = nn.MaxPool2d(kernel_size=4)
+
+         self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
+
+         self.softmax = nn.Softmax(dim=-1)
+
+         self.accuracy = Accuracy(task="multiclass", num_classes=10)
+
+     def forward(self, x):
+         x = self.prep(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.pool(x)
+         x = x.view(-1, 512)
+         x = self.fc(x)
+         # x = self.softmax(x)
+         return x
+
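+     # OneCycleLR is stepped once per batch, hence the "interval": "step" below.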
+     def configure_optimizers(self) -> Any:
+         optimizer = torch.optim.Adam(
+             self.parameters(), lr=self.learning_rate, weight_decay=1e-4
+         )
+         scheduler = one_cycle_lr(
+             optimizer=optimizer,
+             maxlr=self.maxlr,
+             steps=self.scheduler_steps,
+             epochs=self.epochs,
+         )
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
+         }
+
+     def training_step(self, batch, batch_idx):
+         X, y = batch
+         y_pred = self(X)
+         loss = nn.CrossEntropyLoss()(y_pred, y)
+
+         preds = torch.argmax(y_pred, dim=1)
+
+         accuracy = self.accuracy(preds, y)
+
+         self.log_dict({"train_loss": loss, "train_acc": accuracy}, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         X, y = batch
+         y_pred = self(X)
+         # Summed (not mean) loss, so the logged value scales with batch size.
+         loss = nn.CrossEntropyLoss(reduction="sum")(y_pred, y)
+
+         preds = torch.argmax(y_pred, dim=1)
+
+         accuracy = self.accuracy(preds, y)
+
+         self.log_dict({"val_loss": loss, "val_acc": accuracy}, prog_bar=True)
+
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         X, y = batch
+         y_pred = self(X)
+         loss = nn.CrossEntropyLoss(reduction="sum")(y_pred, y)
+         preds = torch.argmax(y_pred, dim=1)
+
+         accuracy = self.accuracy(preds, y)
+
+         self.log_dict({"test_loss": loss, "test_acc": accuracy}, prog_bar=True)
requirements.txt ADDED
@@ -0,0 +1,228 @@
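+ # Frozen environment snapshot (note the local conda-build paths); the
+ # torch/torchvision/torchaudio pins are CUDA 11.8 (+cu118) builds.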
+ absl-py==1.4.0
+ adbc-driver-manager==0.5.1
+ adbc-driver-sqlite==0.5.1
+ aiofiles==23.1.0
+ aiohttp==3.8.5
+ aiosignal==1.3.1
+ albumentations==1.3.1
+ altair==5.0.1
+ annotated-types==0.5.0
+ anyio==3.7.1
+ argon2-cffi==21.3.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.2.3
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
+ async-lru==2.0.4
+ async-timeout==4.0.2
+ attrs==23.1.0
+ Babel==2.12.1
+ backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
+ backoff==2.2.1
+ backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
+ beautifulsoup4==4.12.2
+ black==23.7.0
+ bleach==6.0.0
+ blessed==1.20.0
+ cachetools==5.3.1
+ certifi==2022.12.7
+ cffi==1.15.1
+ charset-normalizer==2.1.1
+ click==8.1.6
+ cloudpickle==2.2.1
+ cmake==3.25.0
+ connectorx==0.3.1
+ contourpy==1.1.0
+ croniter==1.4.1
+ cycler==0.11.0
+ dateutils==0.6.12
+ debugpy @ file:///home/builder/ci_310/debugpy_1640789504635/work
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
+ deepdiff==6.3.1
+ defusedxml==0.7.1
+ deltalake==0.10.0
+ entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
+ exceptiongroup==1.1.2
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
+ fastapi==0.100.1
+ fastjsonschema==2.18.0
+ ffmpy==0.3.1
+ filelock==3.12.2
+ fonttools==4.41.0
+ fqdn==1.5.1
+ frozenlist==1.4.0
+ fsspec==2023.6.0
+ google-auth==2.22.0
+ google-auth-oauthlib==1.0.0
+ grad-cam==1.4.8
+ gradio==3.39.0
+ gradio_client==0.3.0
+ greenlet==2.0.2
+ grpcio==1.56.2
+ h11==0.14.0
+ httpcore==0.17.3
+ httpx==0.24.1
+ huggingface-hub==0.16.4
+ idna==3.4
+ imageio==2.31.1
+ inquirer==3.1.3
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1685727741709/work
+ ipywidgets==8.0.7
+ isoduration==20.11.0
+ itsdangerous==2.1.2
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
+ Jinja2==3.1.2
+ joblib==1.3.1
+ json5==0.9.14
+ jsonpointer==2.4
+ jsonschema==4.18.6
+ jsonschema-specifications==2023.7.1
+ jupyter-events==0.7.0
+ jupyter-lsp==2.2.0
+ jupyter_client==8.3.0
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1686775611663/work
+ jupyter_server==2.7.0
+ jupyter_server_terminals==0.4.4
+ jupyterlab==4.0.4
+ jupyterlab-pygments==0.2.2
+ jupyterlab-widgets==3.0.8
+ jupyterlab_server==2.24.0
+ kiwisolver==1.4.4
+ lazy_loader==0.3
+ lightning==2.0.6
+ lightning-cloud==0.5.37
+ lightning-utilities==0.9.0
+ linkify-it-py==2.0.2
+ lit==15.0.7
+ Markdown==3.4.3
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.2
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mistune==3.0.1
+ mpmath==1.2.1
+ multidict==6.0.4
+ mypy-extensions==1.0.0
+ nbclient==0.8.0
+ nbconvert==7.7.3
+ nbformat==5.9.2
+ nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
+ netron==7.0.6
+ networkx==3.0
+ notebook_shim==0.2.3
+ numpy==1.24.1
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-cupti-cu11==11.7.101
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ nvidia-cufft-cu11==10.9.0.58
+ nvidia-curand-cu11==10.2.10.91
+ nvidia-cusolver-cu11==11.4.0.1
+ nvidia-cusparse-cu11==11.7.4.91
+ nvidia-nccl-cu11==2.14.3
+ nvidia-nvtx-cu11==11.7.91
+ oauthlib==3.2.2
+ opencv-python==4.8.0.74
+ opencv-python-headless==4.8.0.74
+ ordered-set==4.1.0
+ orjson==3.9.3
+ overrides==7.3.1
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work
+ pandas==2.0.3
+ pandocfilters==1.5.0
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
+ pathspec==0.11.2
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
+ Pillow==10.0.0
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1689538620473/work
+ polars==0.18.8
+ prometheus-client==0.17.1
+ prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work
+ protobuf==4.23.4
+ psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
+ pyarrow==12.0.1
+ pyasn1==0.5.0
+ pyasn1-modules==0.3.0
+ pycparser==2.21
+ pydantic==2.0.3
+ pydantic_core==2.3.0
+ pydub==0.25.1
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work
+ PyJWT==2.8.0
+ pyparsing==3.0.9
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
+ python-editor==1.0.4
+ python-json-logger==2.0.7
+ python-multipart==0.0.6
+ pytorch-lightning==2.0.6
+ pytz==2023.3
+ PyWavelets==1.4.1
+ PyYAML==6.0.1
+ pyzmq @ file:///croot/pyzmq_1686601365461/work
+ qudida==0.0.4
+ readchar==4.0.5
+ referencing==0.30.2
+ requests==2.28.1
+ requests-oauthlib==1.3.1
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==13.5.0
+ rpds-py==0.9.2
+ rsa==4.9
+ ruff==0.0.280
+ scikit-image==0.21.0
+ scikit-learn==1.3.0
+ scipy==1.11.1
+ semantic-version==2.10.0
+ Send2Trash==1.8.2
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
+ sniffio==1.3.0
+ soupsieve==2.4.1
+ SQLAlchemy==2.0.19
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
+ starlette==0.27.0
+ starsessions==1.3.0
+ sympy==1.11.1
+ tensorboard==2.13.0
+ tensorboard-data-server==0.7.1
+ terminado==0.17.1
+ threadpoolctl==3.2.0
+ tifffile==2023.7.18
+ tinycss2==1.2.1
+ toml==0.10.2
+ tomli==2.0.1
+ toolz==0.12.0
+ torch==2.0.1+cu118
+ torch-lr-finder==0.2.1
+ torch-tb-profiler==0.4.1
+ torchaudio==2.0.2+cu118
+ torchinfo==1.8.0
+ torchmetrics==1.0.1
+ torchvision==0.15.2+cu118
+ tornado==6.3.2
+ tqdm==4.65.0
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
+ triton==2.0.0
+ ttach==0.0.3
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1688315532570/work
+ tzdata==2023.3
+ uc-micro-py==1.0.2
+ uri-template==1.3.0
+ urllib3==1.26.13
+ uvicorn==0.23.1
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
+ webcolors==1.13
+ webencodings==0.5.1
+ websocket-client==1.6.1
+ websockets==11.0.3
+ Werkzeug==2.3.6
+ widgetsnbextension==4.0.8
+ xlsx2csv==0.8.1
+ XlsxWriter==3.1.2
+ yarl==1.9.2
session12.ipynb ADDED
The diff for this file is too large to render.
utils/common.py ADDED
@@ -0,0 +1,185 @@
+ import numpy as np
+ import random
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torchvision
+ from torchinfo import summary
+ from torch_lr_finder import LRFinder
+
+
+ def find_lr(model, optimizer, criterion, device, trainloader, numiter, startlr, endlr):
+     lr_finder = LRFinder(
+         model=model, optimizer=optimizer, criterion=criterion, device=device
+     )
+
+     lr_finder.range_test(
+         train_loader=trainloader,
+         start_lr=startlr,
+         end_lr=endlr,
+         num_iter=numiter,
+         step_mode="exp",
+     )
+
+     lr_finder.plot()
+
+     lr_finder.reset()
+
+
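+ # pct_start=5/epochs warms the LR up over the first five epochs, then anneals it linearly.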
+ def one_cycle_lr(optimizer, maxlr, steps, epochs):
+     scheduler = torch.optim.lr_scheduler.OneCycleLR(
+         optimizer=optimizer,
+         max_lr=maxlr,
+         steps_per_epoch=steps,
+         epochs=epochs,
+         pct_start=5 / epochs,
+         div_factor=100,
+         three_phase=False,
+         final_div_factor=100,
+         anneal_strategy="linear",
+     )
+     return scheduler
+
+
+ def show_random_images_for_each_class(train_data, num_images_per_class=16):
+     for c, cls in enumerate(train_data.classes):
+         rand_targets = random.sample(
+             [n for n, x in enumerate(train_data.targets) if x == c],
+             k=num_images_per_class,
+         )
+         show_img_grid(np.transpose(train_data.data[rand_targets], axes=(0, 3, 1, 2)))
+         plt.title(cls)
+
+
+ def show_img_grid(data):
+     try:
+         grid_img = torchvision.utils.make_grid(data.cpu().detach())
+     except AttributeError:
+         # NumPy arrays have no .cpu(); convert them to tensors first.
+         data = torch.from_numpy(data)
+         grid_img = torchvision.utils.make_grid(data)
+
+     plt.figure(figsize=(10, 10))
+     plt.imshow(grid_img.permute(1, 2, 0))
+
+
+ def show_random_images(data_loader):
+     data, target = next(iter(data_loader))
+     show_img_grid(data)
+
+
+ def show_model_summary(model, batch_size):
+     summary(
+         model=model,
+         input_size=(batch_size, 3, 32, 32),
+         col_names=["input_size", "output_size", "num_params", "kernel_size"],
+         verbose=1,
+     )
+
+
+ def lossacc_plots(results):
+     plt.plot(results["epoch"], results["trainloss"])
+     plt.plot(results["epoch"], results["testloss"])
+     plt.legend(["Train Loss", "Validation Loss"])
+     plt.xlabel("Epochs")
+     plt.ylabel("Loss")
+     plt.title("Loss vs Epochs")
+     plt.show()
+
+     plt.plot(results["epoch"], results["trainacc"])
+     plt.plot(results["epoch"], results["testacc"])
+     plt.legend(["Train Acc", "Validation Acc"])
+     plt.xlabel("Epochs")
+     plt.ylabel("Accuracy")
+     plt.title("Accuracy vs Epochs")
+     plt.show()
+
+
+ def lr_plots(results, length):
+     plt.plot(range(length), results["lr"])
+     plt.xlabel("Epochs")
+     plt.ylabel("Learning Rate")
+     plt.title("Learning Rate vs Epochs")
+     plt.show()
+
+
+ def get_misclassified(model, testloader, device, mis_count=10):
+     misimgs, mistgts, mispreds = [], [], []
+     with torch.no_grad():
+         for data, target in testloader:
+             data, target = data.to(device), target.to(device)
+             output = model(data)
+             pred = output.argmax(dim=1, keepdim=True)
+             misclassified = torch.argwhere(pred.squeeze() != target).squeeze()
+             for idx in misclassified:
+                 if len(misimgs) >= mis_count:
+                     break
+                 misimgs.append(data[idx])
+                 mistgts.append(target[idx])
+                 mispreds.append(pred[idx].squeeze())
+     return misimgs, mistgts, mispreds
+
+
+ # def plot_misclassified(misimgs, mistgts, mispreds, classes):
+ #     fig, axes = plt.subplots(len(misimgs) // 2, 2)
+ #     fig.tight_layout()
+ #     for ax, img, tgt, pred in zip(axes.ravel(), misimgs, mistgts, mispreds):
+ #         ax.imshow((img / img.max()).permute(1, 2, 0).cpu())
+ #         ax.set_title(f"{classes[tgt]} | {classes[pred]}")
+ #         ax.grid(False)
+ #         ax.set_axis_off()
+ #     plt.show()
+
+
+ def get_misclassified_data(model, device, test_loader, count):
+     """
+     Function to run the model on test set and return misclassified images
+     :param model: Network Architecture
+     :param device: CPU/GPU
+     :param test_loader: DataLoader for test set
+     :param count: Number of misclassified samples to collect
+     """
+     # Prepare the model for evaluation i.e. drop the dropout layer
+     model.eval()
+
+     # List to store misclassified Images
+     misclassified_data = []
+
+     # Reset the gradients
+     with torch.no_grad():
+         # Extract images, labels in a batch
+         for data, target in test_loader:
+             # Migrate the data to the device
+             data, target = data.to(device), target.to(device)
+
+             # Extract single image, label from the batch
+             for image, label in zip(data, target):
+                 # Add batch dimension to the image
+                 image = image.unsqueeze(0)
+
+                 # Get the model prediction on the image
+                 output = model(image)
+
+                 # Convert the output from one-hot encoding to a value
+                 pred = output.argmax(dim=1, keepdim=True)
+
+                 # If prediction is incorrect, append the data
+                 if pred != label:
+                     misclassified_data.append((image, label, pred))
+                 if len(misclassified_data) >= count:
+                     break
+
+             # Stop scanning further batches once enough samples are collected
+             if len(misclassified_data) >= count:
+                 break
+
+     return misclassified_data[:count]
+
+
+ def plot_misclassified(data, classes, size=(10, 10), rows=2, cols=5, inv_normalize=None):
+     fig = plt.figure(figsize=size)
+     number_of_samples = len(data)
+     for i in range(number_of_samples):
+         plt.subplot(rows, cols, i + 1)
+         img = data[i][0].squeeze().to("cpu")
+         if inv_normalize is not None:
+             img = inv_normalize(img)
+         plt.imshow(np.transpose(img, (1, 2, 0)))
+         plt.title(f"Label: {classes[data[i][1].item()]} \n Prediction: {classes[data[i][2].item()]}")
+         plt.xticks([])
+         plt.yticks([])
utils/config.py ADDED
@@ -0,0 +1,36 @@
+ import toml
+ from pydantic import BaseModel
+
+ TOML_PATH = "config.toml"
+
+
+ class Data(BaseModel):
+     batch_size: int = 512
+     shuffle: bool = True
+     num_workers: int = 4
+
+
+ class LRFinder(BaseModel):
+     numiter: int = 600
+     endlr: float = 10
+     startlr: float = 1e-2
+
+
+ class Training(BaseModel):
+     epochs: int = 20
+     optimizer: str = "adam"
+     criterion: str = "crossentropy"
+     lr: float = 0.003
+     weight_decay: float = 1e-4
+     lrfinder: LRFinder
+
+
+ class Config(BaseModel):
+     data: Data
+     training: Training
+
+
+ with open(TOML_PATH) as f:
34
+ toml_config = toml.load(f)
35
+
36
+ config = Config(**toml_config)
utils/data.py ADDED
@@ -0,0 +1,68 @@
+ import torchvision
+ import lightning as L
+ from torch.utils.data import DataLoader
+ from utils.transforms import train_transform, test_transform
+
+
+ class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
+     def __init__(self, root="~/data", train=True, download=True, transform=None):
+         super().__init__(root=root, train=train, download=download, transform=transform)
+
+     def __getitem__(self, index):
+         image, label = self.data[index], self.targets[index]
+         if self.transform is not None:
+             transformed = self.transform(image=image)
+             image = transformed["image"]
+
+         return image, label
+
+
+ class CIFARDataModule(L.LightningDataModule):
+     def __init__(
+         self, data_dir="data", batch_size=512, shuffle=True, num_workers=4
+     ) -> None:
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.shuffle = shuffle
+         self.num_workers = num_workers
+
+     def prepare_data(self) -> None:
+         pass
+
+     def setup(self, stage=None):
+         self.train_dataset = Cifar10SearchDataset(
+             root=self.data_dir, train=True, transform=train_transform
+         )
+
+         self.val_dataset = Cifar10SearchDataset(
+             root=self.data_dir, train=False, transform=test_transform
+         )
+
+         self.test_dataset = Cifar10SearchDataset(
+             root=self.data_dir, train=False, transform=test_transform
+         )
+
+     def train_dataloader(self):
+         return DataLoader(
+             dataset=self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=self.shuffle,
+             num_workers=self.num_workers,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             dataset=self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=self.shuffle,
+             num_workers=self.num_workers,
+         )
+
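+     # Note: self.shuffle (True by default) also applies to the val/test loaders,
+     # so the misclassified samples surfaced in app.py vary from run to run.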
+     def test_dataloader(self):
+         return DataLoader(
+             dataset=self.test_dataset,
+             batch_size=self.batch_size,
+             shuffle=self.shuffle,
+             num_workers=self.num_workers,
+         )
utils/gradcam.py ADDED
@@ -0,0 +1,67 @@
+ import numpy as np
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+ import matplotlib.pyplot as plt
+
+
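+ # generate_gradcam computes each CAM against the image's ground-truth class
+ # (ClassifierOutputTarget), not against the model's prediction.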
+ def generate_gradcam(model, target_layers, images, labels, rgb_imgs):
+     results = []
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
+
+     for image, label, np_image in zip(images, labels, rgb_imgs):
+         targets = [ClassifierOutputTarget(label.item())]
+
+         # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
+         grayscale_cam = cam(
+             input_tensor=image.unsqueeze(0), targets=targets, aug_smooth=True
+         )
+
+         # In this example grayscale_cam has only one image in the batch:
+         grayscale_cam = grayscale_cam[0, :]
+         visualization = show_cam_on_image(
+             np_image / np_image.max(), grayscale_cam, use_rgb=True
+         )
+         results.append(visualization)
+     return results
+
+
+ def visualize_gradcam(misimgs, mistgts, mispreds, classes):
+     fig, axes = plt.subplots(len(misimgs) // 2, 2)
+     fig.tight_layout()
+     for ax, img, tgt, pred in zip(axes.ravel(), misimgs, mistgts, mispreds):
+         ax.imshow(img)
+         ax.set_title(f"{classes[tgt]} | {classes[pred]}")
+         ax.grid(False)
+         ax.set_axis_off()
+     plt.show()
+
+
+ def plot_gradcam(
+     model,
+     data,
+     classes,
+     target_layers,
+     number_of_samples,
+     inv_normalize=None,
+     targets=None,
+     transparency=0.60,
+     figsize=(10, 10),
+     rows=2,
+     cols=5,
+ ):
+     fig = plt.figure(figsize=figsize)
+
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
+     for i in range(number_of_samples):
+         plt.subplot(rows, cols, i + 1)
+         input_tensor = data[i][0]
+
+         # Get the activations of the layer for the images
+         grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+         grayscale_cam = grayscale_cam[0, :]
+
+         # Get back the original image (convert to numpy before transposing,
+         # since torch.Tensor.transpose does not accept an axes tuple)
+         img = input_tensor.squeeze(0).to("cpu")
+         if inv_normalize is not None:
+             img = inv_normalize(img)
+         rgb_img = np.transpose(img.numpy(), (1, 2, 0))
+
+         # Mix the activations on the original image
+         visualization = show_cam_on_image(
+             rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
+         )
+
+         # Display the images on the plot
+         plt.imshow(visualization)
+         plt.title(f"Label: {classes[data[i][1].item()]} \n Prediction: {classes[data[i][2].item()]}")
+         plt.xticks([])
+         plt.yticks([])
utils/training.py ADDED
@@ -0,0 +1,90 @@
+ from tqdm import tqdm
+ import torch
+ import torch.nn.functional as F
+
+
+ def train(
+     model,
+     device,
+     train_loader,
+     optimizer,
+     criterion,
+     scheduler,
+     L1=False,
+     l1_lambda=0.01,
+ ):
+     model.train()
+     pbar = tqdm(train_loader)
+
+     train_losses = []
+     train_acc = []
+     lrs = []
+
+     correct = 0
+     processed = 0
+     train_loss = 0
+
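+     # OneCycleLR is scheduled per batch, so scheduler.step() runs inside the loop.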
+     for batch_idx, (data, target) in enumerate(pbar):
+         data, target = data.to(device), target.to(device)
+         optimizer.zero_grad()
+         y_pred = model(data)
+
+         # Calculate loss (optionally with an L1 penalty on the weights)
+         loss = criterion(y_pred, target)
+         if L1:
+             l1_loss = 0
+             for p in model.parameters():
+                 l1_loss = l1_loss + p.abs().sum()
+             loss = loss + l1_lambda * l1_loss
+
+         train_loss += loss.item()
+         train_losses.append(loss.item())
+
+         # Backpropagation
+         loss.backward()
+         optimizer.step()
+         scheduler.step()
+
+         # Update pbar-tqdm
+         pred = y_pred.argmax(
+             dim=1, keepdim=True
+         )  # get the index of the max log-probability
+         correct += pred.eq(target.view_as(pred)).sum().item()
+         processed += len(data)
+
+         pbar.set_description(
+             desc=f"Loss={loss.item():0.2f} Accuracy={100*correct/processed:0.2f}"
+         )
+         train_acc.append(100 * correct / processed)
+         lrs.append(scheduler.get_last_lr())
+
+     return train_losses, train_acc, lrs
+
+
+ def test(model, device, criterion, test_loader):
+     model.eval()
+     test_loss = 0
+     correct = 0
+     with torch.no_grad():
+         for data, target in test_loader:
+             data, target = data.to(device), target.to(device)
+             output = model(data)
+             test_loss += F.cross_entropy(output, target, reduction="sum").item()
+             pred = output.argmax(dim=1, keepdim=True)
+             correct += pred.eq(target.view_as(pred)).sum().item()
+
+     test_loss /= len(test_loader.dataset)
+
+     print(
+         "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
+             test_loss,
+             correct,
+             len(test_loader.dataset),
+             100.0 * correct / len(test_loader.dataset),
+         )
+     )
+     test_acc = 100.0 * correct / len(test_loader.dataset)
+
+     return test_loss, test_acc
utils/transforms.py ADDED
@@ -0,0 +1,31 @@
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
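+ # CoarseDropout fills the cut-out hole with the CIFAR-10 channel means rescaled to
+ # 0-255, so the dropped patch matches the dataset's average colour before Normalize runs.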
+ train_transform = A.Compose(
+     [
+         A.PadIfNeeded(min_height=40, min_width=40, always_apply=True),
+         A.RandomCrop(height=32, width=32, always_apply=True),
+         A.HorizontalFlip(),
+         A.CoarseDropout(
+             min_holes=1,
+             max_holes=1,
+             min_height=8,
+             min_width=8,
+             max_height=8,
+             max_width=8,
+             fill_value=[0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255],  # type: ignore
+             p=0.5,
+         ),
+         A.Normalize(
+             (0.49139968, 0.48215827, 0.44653124),
+             (0.24703233, 0.24348505, 0.26158768),
+         ),
+         ToTensorV2(),
+     ]
+ )
+
+ test_transform = A.Compose(
+     [
+         A.Normalize(
+             (0.49139968, 0.48215827, 0.44653124),
+             (0.24703233, 0.24348505, 0.26158768),
+         ),
+         ToTensorV2(),
+     ]
+ )