"""Helpers for LR search/scheduling, dataset and prediction visualisation."""

import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchinfo import summary
from torch_lr_finder import LRFinder


def find_lr(model, optimizer, criterion, device, trainloader, numiter, startlr, endlr):
    """Run an exponential learning-rate range test and plot loss vs. LR.

    ``lr_finder.reset()`` is called afterwards so that LRFinder restores the
    model/optimizer state it saved before the sweep.

    :param model: network to probe
    :param optimizer: optimizer whose learning rate is swept
    :param criterion: loss function
    :param device: device to run the sweep on
    :param trainloader: DataLoader yielding training batches
    :param numiter: number of sweep iterations
    :param startlr: learning rate at the start of the sweep
    :param endlr: learning rate at the end of the sweep
    """
    lr_finder = LRFinder(
        model=model, optimizer=optimizer, criterion=criterion, device=device
    )
    lr_finder.range_test(
        train_loader=trainloader,
        start_lr=startlr,
        end_lr=endlr,
        num_iter=numiter,
        step_mode="exp",
    )
    lr_finder.plot()
    lr_finder.reset()


def one_cycle_lr(optimizer, maxlr, steps, epochs, max_lr_epoch=5):
    """Build a two-phase, linearly annealed OneCycleLR scheduler.

    The LR rises linearly from ``maxlr / 100`` to ``maxlr`` over the first
    ``max_lr_epoch`` epochs, then falls linearly to ``maxlr / 100 / 100``
    over the remaining epochs.

    :param optimizer: optimizer to schedule
    :param maxlr: peak learning rate
    :param steps: scheduler steps (batches) per epoch
    :param epochs: total number of epochs
    :param max_lr_epoch: epoch at which the peak LR is reached (default 5)
    :returns: the configured ``torch.optim.lr_scheduler.OneCycleLR``
    :raises ValueError: if ``max_lr_epoch`` is negative or not smaller than
        ``epochs``. ``pct_start > 1`` is rejected by OneCycleLR, and
        ``pct_start == 1`` leaves a zero-length decay phase that divides by
        zero on the final ``step()``.
    """
    if not 0 <= max_lr_epoch < epochs:
        raise ValueError(
            f"max_lr_epoch must satisfy 0 <= max_lr_epoch < epochs, "
            f"got max_lr_epoch={max_lr_epoch}, epochs={epochs}"
        )
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=maxlr,
        steps_per_epoch=steps,
        epochs=epochs,
        pct_start=max_lr_epoch / epochs,
        div_factor=100,
        three_phase=False,
        final_div_factor=100,
        anneal_strategy="linear",
    )


def show_random_images_for_each_class(train_data, num_images_per_class=16):
    """Plot one grid of randomly chosen training images per class.

    Expects a dataset exposing ``classes``, ``targets`` and a ``data`` array
    laid out as (N, H, W, C). That is assumed from the transpose below
    (CIFAR-style datasets); TODO confirm for other datasets.

    :param train_data: dataset to sample from
    :param num_images_per_class: images drawn (without replacement) per class
    :raises ValueError: if a class has fewer than ``num_images_per_class``
        samples (raised by ``random.sample``)
    """
    targets = np.asarray(train_data.targets)
    for class_idx, class_name in enumerate(train_data.classes):
        # Ascending plain-int list of this class's indices, the population
        # random.sample draws from.
        class_indices = np.flatnonzero(targets == class_idx).tolist()
        rand_targets = random.sample(class_indices, k=num_images_per_class)
        # (N, H, W, C) -> (N, C, H, W), the layout make_grid expects.
        show_img_grid(np.transpose(train_data.data[rand_targets], axes=(0, 3, 1, 2)))
        plt.title(class_name)


def show_img_grid(data):
    """Draw a batch of images as a single grid in a new 10x10 figure.

    :param data: a ``torch.Tensor`` (any device; it is detached and moved to
        the CPU) or anything ``torch.as_tensor`` accepts, such as a NumPy
        array. Laid out as (N, C, H, W), as ``make_grid`` expects.
    """
    if isinstance(data, torch.Tensor):
        images = data.detach().cpu()
    else:
        images = torch.as_tensor(data)
    grid_img = torchvision.utils.make_grid(images)
    plt.figure(figsize=(10, 10))
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(grid_img.permute(1, 2, 0))


def show_random_images(data_loader):
    """Show one batch drawn from ``data_loader`` as an image grid."""
    images, _ = next(iter(data_loader))
    show_img_grid(images)


def show_model_summary(model, batch_size, input_shape=(3, 32, 32)):
    """Print a torchinfo summary of ``model``.

    :param model: network to summarise
    :param batch_size: batch dimension of the dummy input
    :param input_shape: per-sample input shape; defaults to (3, 32, 32)
    """
    summary(
        model=model,
        input_size=(batch_size, *input_shape),
        col_names=["input_size", "output_size", "num_params", "kernel_size"],
        verbose=1,
    )


def _plot_train_vs_val(results, train_key, val_key, labels, ylabel, title):
    """Plot two per-epoch series from ``results`` on one axis and show it."""
    plt.plot(results["epoch"], results[train_key])
    plt.plot(results["epoch"], results[val_key])
    plt.legend(labels)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


def lossacc_plots(results):
    """Show train/validation loss, then train/validation accuracy, per epoch.

    :param results: mapping with keys ``epoch``, ``trainloss``, ``testloss``,
        ``trainacc`` and ``testacc``
    """
    _plot_train_vs_val(
        results, "trainloss", "testloss",
        ["Train Loss", "Validation Loss"], "Loss", "Loss vs Epochs",
    )
    _plot_train_vs_val(
        results, "trainacc", "testacc",
        ["Train Acc", "Validation Acc"], "Accuracy", "Accuracy vs Epochs",
    )


def lr_plots(results, length=None):
    """Plot the recorded learning rates in ``results["lr"]``.

    :param results: mapping with an ``lr`` sequence
    :param length: number of points to plot. It must equal
        ``len(results["lr"])``; defaults to that length.
    """
    lrs = results["lr"]
    if length is None:
        length = len(lrs)
    plt.plot(range(length), lrs)
    plt.xlabel("Epochs")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate vs Epochs")
    plt.show()


def get_misclassified(model, testloader, device, mis_count=10):
    """Collect up to ``mis_count`` misclassified samples from ``testloader``.

    Runs ``model`` under ``torch.no_grad()`` without changing its train/eval
    mode; call ``model.eval()`` beforehand if that matters. Scanning stops as
    soon as enough samples have been collected.

    :returns: three parallel lists: images, true targets and predicted class
        indices, each element a tensor on ``device``
    """
    misimgs, mistgts, mispreds = [], [], []
    if mis_count <= 0:
        return misimgs, mistgts, mispreds
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            # flatten() rather than squeeze(): with exactly one mismatch,
            # squeeze() yields a 0-d tensor, which cannot be iterated.
            wrong = torch.nonzero(pred.flatten() != target).flatten()
            for idx in wrong:
                misimgs.append(data[idx])
                mistgts.append(target[idx])
                mispreds.append(pred[idx].squeeze())
                if len(misimgs) >= mis_count:
                    return misimgs, mistgts, mispreds
    return misimgs, mistgts, mispreds


def get_misclassified_data(model, device, test_loader, count):
    """Run the model on the test set and return misclassified images.

    The model is put in eval mode (dropout off, batch-norm running stats) and
    evaluated one image at a time under ``torch.no_grad()``. Iteration stops
    as soon as ``count`` misclassified samples have been found.

    :param model: network architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for the test set
    :param count: maximum number of samples to return
    :returns: list of at most ``count`` ``(image, label, pred)`` tuples.
        ``image`` carries a leading batch dimension of 1. ``pred`` is the
        ``argmax(dim=1, keepdim=True)`` of the model output.
    """
    model.eval()
    misclassified_data = []
    if count <= 0:
        return misclassified_data
    # Disable gradient tracking; only inference is needed here.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            for image, label in zip(data, target):
                # Add the batch dimension the model expects.
                image = image.unsqueeze(0)
                # Index of the highest-scoring class.
                pred = model(image).argmax(dim=1, keepdim=True)
                if pred != label:
                    misclassified_data.append((image, label, pred))
                    if len(misclassified_data) >= count:
                        return misclassified_data
    return misclassified_data


def plot_misclassified(data, classes, size=(10, 10), rows=2, cols=5, inv_normalize=None):
    """Plot misclassified samples in a ``rows`` x ``cols`` grid.

    Only the first ``rows * cols`` samples are drawn.

    :param data: sequence of ``(image, label, pred)`` tuples as returned by
        ``get_misclassified_data``
    :param classes: class names indexed by class id
    :param size: figure size in inches
    :param rows: number of subplot rows
    :param cols: number of subplot columns
    :param inv_normalize: optional callable applied to each (C, H, W) CPU
        image before display, e.g. to undo input normalisation
    """
    plt.figure(figsize=size)
    for i, (image, label, pred) in enumerate(data[: rows * cols]):
        plt.subplot(rows, cols, i + 1)
        # Drop only the batch dimension so single-channel images keep C.
        img = image.squeeze(0).to("cpu")
        if inv_normalize is not None:
            img = inv_normalize(img)
        img = np.transpose(np.asarray(img), (1, 2, 0))
        if img.shape[-1] == 1:
            # imshow wants (H, W) for single-channel data.
            img = img[..., 0]
        plt.imshow(img)
        plt.title(f"Label: {classes[label.item()]} \n Prediction: {classes[pred.item()]}")
        plt.xticks([])
        plt.yticks([])