import gradio as gr import random import pathlib import numpy as np from PIL import Image import torch import torchvision from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image from models.resnet_lightning import ResNet from utils.data import CIFARDataModule from utils.transforms import test_transform from utils.common import get_misclassified_data inv_normalize = torchvision.transforms.Normalize( mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23], std=[1 / 0.23, 1 / 0.23, 1 / 0.23] ) datamodule = CIFARDataModule() datamodule.setup() classes = datamodule.train_dataset.classes model = ResNet.load_from_checkpoint("model.ckpt") model = model.to("cpu") prediction_image = None def upload_file(files): file_paths = [file.name for file in files] return file_paths def read_image(path): img = Image.open(path) img.load() data = np.asarray(img, dtype="uint8") return data # def sample_images(): # images = [] # length = len(datamodule.test_dataset) # classes = datamodule.train_dataset.classes # for i in range(10): # idx = random.randint(0, length - 1) # image, label = datamodule.test_dataset[idx] # image = inv_normalize(image).permute(1, 2, 0).numpy() # images.append((image, classes[label])) # return images def sample_images(): sample_imges_dir = pathlib.Path("./sample_images") sample_images = list(sample_imges_dir.iterdir()) sample_image_labels = [image.stem for image in sample_images] return list(zip(sample_images, sample_image_labels)) def get_misclassified_images(misclassified_count): misclassified_images = [] misclassified_data = get_misclassified_data( model=model, device="cpu", test_loader=datamodule.test_dataloader(), count=misclassified_count, ) for i in range(misclassified_count): img = misclassified_data[i][0].squeeze().to("cpu") img = inv_normalize(img) img = np.transpose(img.numpy(), (1, 2, 0)) label = f"Label: {classes[misclassified_data[i][1].item()]} | Prediction: {classes[misclassified_data[i][2].item()]}" misclassified_images.append((img, label)) return misclassified_images def get_gradcam_images(gradcam_layer, gradcam_count, gradcam_opacity): gradcam_images = [] if gradcam_layer == "Layer1": target_layers = [model.layer1[-1]] elif gradcam_layer == "Layer2": target_layers = [model.layer2[-1]] else: target_layers = [model.layer3[-1]] cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False) data = get_misclassified_data( model=model, device="cpu", test_loader=datamodule.test_dataloader(), count=gradcam_count, ) for i in range(gradcam_count): input_tensor = data[i][0] # Get the activations of the layer for the images grayscale_cam = cam(input_tensor=input_tensor, targets=None) grayscale_cam = grayscale_cam[0, :] # Get back the original image img = input_tensor.squeeze(0).to("cpu") if inv_normalize is not None: img = inv_normalize(img) rgb_img = np.transpose(img, (1, 2, 0)) rgb_img = rgb_img.numpy() # Mix the activations on the original image visualization = show_cam_on_image( rgb_img, grayscale_cam, use_rgb=True, image_weight=gradcam_opacity ) label = f"Label: {classes[data[i][1].item()]} | Prediction: {classes[data[i][2].item()]}" gradcam_images.append((visualization, label)) return gradcam_images def show_hide_misclassified(status): if not status: return {misclassified_count: gr.update(visible=False)} return {misclassified_count: gr.update(visible=True)} def show_hide_gradcam(status): if not status: return [gr.update(visible=False) for i in range(3)] return [gr.update(visible=True) for i in range(3)] def set_prediction_image(evt: gr.SelectData, gallery): global prediction_image if isinstance(gallery[evt.index], dict): prediction_image = gallery[evt.index]["name"] else: prediction_image = gallery[evt.index][0]["name"] def predict( is_misclassified, misclassified_count, is_gradcam, gradcam_count, gradcam_layer, gradcam_opacity, num_classes, ): if prediction_image is None: raise gr.Error( "Please select one of the sample image or upload an image for prediction!" ) misclassified_images = None if is_misclassified: misclassified_images = get_misclassified_images(int(misclassified_count)) gradcam_images = None if is_gradcam: gradcam_images = get_gradcam_images( gradcam_layer, int(gradcam_count), gradcam_opacity ) img = read_image(prediction_image) image_transformed = test_transform(image=img)["image"] output = model(image_transformed.unsqueeze(0)) preds = torch.softmax(output, dim=1).squeeze().detach().numpy() indices = ( output.argsort(descending=True).squeeze().detach().numpy()[: int(num_classes)] ) predictions = {classes[i]: round(float(preds[i]), 2) for i in indices} return { miscalssfied_output: gr.update(value=misclassified_images), gradcam_output: gr.update(value=gradcam_images), prediction_label: gr.update(value=predictions), } with gr.Blocks() as app: gr.Markdown("## CIFAR10 Classification with ResNet") with gr.Row(): with gr.Column(): with gr.Box(): is_misclassified = gr.Checkbox( label="Misclassified Images", info="Display misclassified images?" ) misclassified_count = gr.Dropdown( choices=[str(i + 1) for i in range(20)], value=5, label="Select Number of Images", info="Number of Misclassified images (Default:5)", visible=False, interactive=True, ) is_misclassified.input( show_hide_misclassified, inputs=[is_misclassified], outputs=[misclassified_count], ) with gr.Box(): is_gradcam = gr.Checkbox( label="GradCAM Images", info="Display GradCAM images?", ) gradcam_count = gr.Dropdown( choices=[str(i + 1) for i in range(20)], label="Select Number of Images", info="Number of GradCAM images (Default:5)", value=5, interactive=True, visible=False, ) gradcam_layer = gr.Dropdown( choices=["Layer1", "Layer2", "Layer3"], label="Select the layer", info="Please select the layer for which the GradCAM is required (Default:Layer3)", interactive=True, value="Layer3", visible=False, ) gradcam_opacity = gr.Slider( minimum=0, maximum=1, value=0.6, label="Opacity", info="Opacity of GradCAM output", interactive=True, visible=False, ) is_gradcam.input( show_hide_gradcam, inputs=[is_gradcam], outputs=[gradcam_count, gradcam_layer, gradcam_opacity], ) with gr.Box(): # file_output = gr.File(file_types=["image"]) with gr.Group(): upload_gallery = gr.Gallery( value=None, label="Uploaded images", show_label=False, elem_id="gallery_upload", columns=5, rows=2, height="auto", object_fit="contain", ) upload_button = gr.UploadButton( "Click to Upload images", file_types=["image"], file_count="multiple", ) upload_button.upload(upload_file, upload_button, upload_gallery) with gr.Group(): sample_gallery = gr.Gallery( value=sample_images, label="Sample images", show_label=True, elem_id="gallery_sample", columns=5, rows=2, height="auto", object_fit="contain", ) upload_gallery.select(set_prediction_image, inputs=[upload_gallery]) sample_gallery.select(set_prediction_image, inputs=[sample_gallery]) with gr.Box(): num_classes = gr.Dropdown( choices=[str(i + 1) for i in range(10)], label="Select Number of Top Classes", value=5, interactive=True, info="Number of Top target classes to be shown (Default:5)", ) run_btn = gr.Button() with gr.Column(): with gr.Box(): miscalssfied_output = gr.Gallery( value=None, label="Misclassified Images", show_label=True ) with gr.Box(): gradcam_output = gr.Gallery( value=None, label="GradCAM Images", show_label=True ) with gr.Box(): prediction_label = gr.Label(value=None, label="Predictions") run_btn.click( predict, inputs=[ is_misclassified, misclassified_count, is_gradcam, gradcam_count, gradcam_layer, gradcam_opacity, num_classes, ], outputs=[miscalssfied_output, gradcam_output, prediction_label], ) app.launch()