Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader, Subset | |
from torchvision import transforms, datasets | |
from PIL import Image | |
from tqdm.auto import tqdm | |
import torch.nn.functional as F | |
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, FullGrad | |
from matplotlib import colormaps | |
import numpy as np | |
import gradio as gr | |
class CNN(nn.Module):
    """Small three-stage convolutional classifier for concrete crack images.

    Each stage is a 3x3 convolution (padding 1, so height/width are kept),
    a ReLU, and a 2x2 max-pool that halves height and width. Two fully
    connected layers then map the pooled features to class logits.

    Args:
        num_classes: Number of output logits. Defaults to 2 (binary task).
        image_size: Side length of the square input images; ``fc1`` is sized
            for this resolution, so inputs must be ``image_size x image_size``.
            Defaults to 224, matching the app's preprocessing.

    Submodule names (``conv1``..``conv3``, ``pool``, ``fc1``, ``fc2``) and, with
    the defaults, all parameter shapes are unchanged, so state dicts saved
    from the original 2-class / 224px model still load.
    """

    def __init__(self, num_classes=2, image_size=224):
        super().__init__()
        # Convolutional feature extractor: 3 -> 16 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Shared 2x2 pooling layer, applied after each convolution.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Three successive floor-halvings of the side length equal
        # floor(image_size / 8), for any image_size.
        pooled_side = image_size // 8
        self.fc1 = nn.Linear(64 * pooled_side * pooled_side, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return raw (un-normalized) logits of shape (batch, num_classes).

        ``x`` must be a (batch, 3, image_size, image_size) tensor.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension for the dense layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Preprocessing applied to every input image before it reaches the model.
transform = transforms.Compose(
    [
        # 224x224 is the input size CNN.fc1 is dimensioned for.
        transforms.Resize((224, 224)),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # Per-channel normalization with the widely used ImageNet statistics;
        # presumably what the model was trained with (confirm against training).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the network and restore the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint file; that is only safe for
# a trusted file bundled with the app — never point it at user-supplied paths.
model = CNN()
model.load_state_dict(
    torch.load("trained-cnn-concrete-crack.model", map_location=torch.device("cpu"))
)
# This module only does inference: switch to eval mode so train-time-only
# layers (dropout, batch-norm) would behave correctly if ever added.
model.eval()
# Matplotlib colormap used to turn a grayscale CAM into RGBA heatmap colours.
magmaify = colormaps['magma']
# Supported pytorch-grad-cam explainers, keyed by the name shown in the UI.
# Built once at import time instead of on every request.
_CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAMPlusPlus": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "FullGrad": FullGrad,
}


def compute_gradcam(img_tensor, layer_idx, typeCAM):
    """Return a magma-coloured class-activation heatmap for the module-level ``model``.

    Args:
        img_tensor: Preprocessed input batch; only the first image's map is
            returned. Presumably (1, 3, 224, 224) as produced by
            ``transform(img).unsqueeze(0)`` — confirm for other callers.
        layer_idx: 1-based index of the convolution to explain
            (1 -> conv1, 2 -> conv2, 3 -> conv3). Integral floats such as 2.0
            are accepted.
        typeCAM: Name of the CAM method, a key of ``_CAM_METHODS``.

    Returns:
        The colormapped heatmap (RGBA values) for the first image in the batch.

    Raises:
        ValueError: if ``layer_idx`` is not 1, 2 or 3 (previously 0 or negative
            values silently selected the wrong layer via negative indexing).
        KeyError: if ``typeCAM`` is not a supported method name.
    """
    conv_layers = (model.conv1, model.conv2, model.conv3)
    # The UI slider value may arrive as a number such as 2.0; lists need an int index.
    layer_idx = int(layer_idx)
    if not 1 <= layer_idx <= len(conv_layers):
        raise ValueError(
            f"layer_idx must be between 1 and {len(conv_layers)}, got {layer_idx}"
        )
    cam_cls = _CAM_METHODS[typeCAM]
    # Use the CAM as a context manager so the hooks it registers on `model`
    # are released as soon as we are done, even if computing the map fails.
    with cam_cls(model=model, target_layers=[conv_layers[layer_idx - 1]]) as cam:
        # targets=None: explain the highest-scoring class.
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)
    # Take the first map of the batch instead of hard-coding a 224x224 reshape.
    return magmaify(grayscale_cam[0])
def predict_and_gradcam(model, img, layer_idx, typeCAM):
    """Classify ``img`` and compute a class-activation heatmap for it.

    Args:
        model: Classifier used for the prediction. NOTE(review): the heatmap
            itself comes from ``compute_gradcam``, which uses the module-level
            ``model`` rather than this argument.
        img: Input image, either a numpy array (e.g. from the Gradio image
            input) or a PIL image. Grayscale and RGBA inputs are converted to
            3-channel RGB.
        layer_idx: 1-based convolution index forwarded to ``compute_gradcam``.
        typeCAM: CAM method name forwarded to ``compute_gradcam``.

    Returns:
        ``(predicted_label, gradcam)``: the argmax class index as a string
        (e.g. "0" or "1") and the heatmap returned by ``compute_gradcam``.

    Raises:
        ValueError: if ``img`` is None.
    """
    if img is None:
        raise ValueError("No image provided")
    if isinstance(img, np.ndarray):
        # Let PIL infer the mode from the array shape (L / RGB / RGBA);
        # forcing 'RGB' misreads grayscale or RGBA buffers.
        img = Image.fromarray(img.astype('uint8'))
    # The 3-channel Normalize step in `transform` requires RGB input.
    img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0)

    # Predicted class = index of the largest logit; no gradients needed here.
    with torch.no_grad():
        logits = model(img_tensor)
    predicted_label = str(logits.argmax(dim=1).item())

    gradcam = compute_gradcam(img_tensor, layer_idx, typeCAM)
    return predicted_label, gradcam
# Display names keyed by the string class index that predict_and_gradcam returns.
idx_to_lbl = dict(zip(("0", "1"), ("Cracked", "Uncracked")))
def classify_image(image, layer_idx, typeCAM):
    """Gradio callback: return the display label and CAM heatmap for ``image``.

    Args:
        image: Value of the Gradio image input; may be None when the user
            submits without uploading anything.
        layer_idx: 1-based convolution layer chosen on the slider.
        typeCAM: CAM method name chosen in the dropdown.

    Returns:
        ``(label, heatmap)`` where ``label`` is the display name from
        ``idx_to_lbl`` ("Cracked" or "Uncracked").

    Raises:
        gr.Error: if no image was supplied, so the UI shows a readable
            message instead of a preprocessing stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image of concrete first.")
    label, gradcam_img = predict_and_gradcam(model, image, layer_idx, typeCAM)
    return idx_to_lbl[label], gradcam_img
# HTML shown beneath the title of the Gradio interface.
# Fixed markup: <img> is a void element (no closing tag) and the width
# attribute takes a bare pixel count, not "200px".
description = (
    "<hr><center>Upload an image of concrete and get the predicted label "
    "along with the GradCAM heatmap. <br><br>\n"
    '<img src="https://www.huggingface.co/spaces/1rsh/concrete-crack-gradcam/resolve/main/header.jpeg"'
    ' width="200"></center>\n'
)
# CAM method names offered in the UI dropdown; must match the names compute_gradcam accepts.
typeCAMs = "GradCAM HiResCAM ScoreCAM GradCAMPlusPlus AblationCAM XGradCAM FullGrad".split()
# Gradio UI: image + CAM settings in, predicted label + heatmap out.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(),
        # 1-based index of the convolution to explain (conv1..conv3).
        gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Target conv layer"),
        gr.Dropdown(choices=typeCAMs, value="GradCAM", label="CAM method"),
    ],
    outputs=[gr.Textbox(label="Predicted Label"), gr.Image(label="GradCAM Heatmap")],
    title="Concrete Crack Detection with GradCAM",
    description=description,
    # Flagging mode is a string ("never" / "auto" / "manual"); the boolean
    # form is deprecated.
    allow_flagging="never",
    theme=gr.themes.Monochrome(font=gr.themes.GoogleFont("IBM Plex Mono")),
)

# Launch only when run as a script, so importing this module (e.g. from a
# test or another app) does not start a blocking server.
if __name__ == "__main__":
    iface.launch()