Spaces:
Sleeping
Sleeping
File size: 4,160 Bytes
fc055d6 2a96d71 fc055d6 2a96d71 fc055d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets
from PIL import Image
from tqdm.auto import tqdm
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, FullGrad
from matplotlib import colormaps
import numpy as np
import gradio as gr
class CNN(nn.Module):
    """Small three-stage convolutional classifier for 224x224, 3-channel input.

    Each stage is a 3x3 convolution (stride 1, padding 1), a ReLU and a shared
    2x2 max-pool (stride 2), so the spatial size is halved three times
    (224 -> 28) before two fully connected layers produce 2 logits.

    The attribute names (conv1, conv2, conv3, pool, fc1, fc2) are the keys the
    saved state_dict is loaded against, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: channel widths 3 -> 16 -> 32 -> 64.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # A single parameter-free pooling module reused after every conv.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Classifier head: three 2x poolings shrink each side to 224 // 8.
        side = 224 // 8
        self.fc1 = nn.Linear(64 * side * side, 64)
        self.fc2 = nn.Linear(64, 2)  # two output logits (binary task)

    def forward(self, x):
        """Return raw class logits of shape (N, 2) for a batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        hidden = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.fc2(hidden)
# Preprocessing applied to every uploaded image before it reaches the model:
# resize to the 224x224 size the CNN's fc1 layer is dimensioned for, convert
# to a float tensor, then normalise per channel with the standard ImageNet
# mean/std values.
# NOTE(review): assumes training used this exact pipeline — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = CNN()
# The checkpoint is a plain state_dict (it is fed to load_state_dict), so it
# can be loaded with weights_only=True, which refuses to unpickle arbitrary
# objects and therefore cannot execute code from a tampered file.
model.load_state_dict(
    torch.load(
        "trained-cnn-concrete-crack.model",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# This app only runs inference; switch the network to evaluation mode
# explicitly instead of leaving it in the default training mode.
model.eval()

# Matplotlib colormap used to turn the grayscale CAM into an RGBA image.
magmaify = colormaps['magma']
# CAM method name -> pytorch_grad_cam class. Built once at import time instead
# of on every request; the keys are the names the UI may pass as `typeCAM`.
_CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAMPlusPlus": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "FullGrad": FullGrad,
}


def compute_gradcam(img_tensor, layer_idx, typeCAM):
    """Compute a colour-mapped class-activation heatmap for an image batch.

    The CAM is always computed on the module-level ``model``.

    Args:
        img_tensor: preprocessed input batch passed straight to the CAM
            (e.g. ``transform(img).unsqueeze(0)``).
        layer_idx: 1-based index of the conv layer to explain
            (1 -> conv1, 2 -> conv2, 3 -> conv3). Integral floats such as
            2.0 are accepted.
        typeCAM: name of the CAM algorithm, a key of ``_CAM_METHODS``.

    Returns:
        The first batch item's CAM passed through the module-level
        ``magmaify`` colormap (an RGBA float array).

    Raises:
        ValueError: if ``layer_idx`` is not an integer in 1..3.
        KeyError: if ``typeCAM`` is not a supported method name.
    """
    conv_layers = [model.conv1, model.conv2, model.conv3]
    layer_pos = int(layer_idx)
    # Reject 0/negatives (which would silently wrap to conv3 via negative
    # indexing), out-of-range values, and non-integral floats.
    if layer_pos != layer_idx or not 1 <= layer_pos <= len(conv_layers):
        raise ValueError(
            f"layer_idx must be an integer from 1 to {len(conv_layers)}, got {layer_idx!r}"
        )
    cam_cls = _CAM_METHODS[typeCAM]
    # Using the CAM as a context manager removes the hooks it registers on
    # `model` as soon as we are done, rather than whenever it is collected.
    with cam_cls(model=model, target_layers=[conv_layers[layer_pos - 1]]) as cam:
        # targets=None: explain the highest-scoring class, i.e. the prediction.
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)
    # grayscale_cam is indexed by batch first; take the first image instead of
    # hard-coding reshape(224, 224), so other input sizes also work.
    return magmaify(grayscale_cam[0])
def predict_and_gradcam(model, img, layer_idx, typeCAM):
    """Classify one image and compute its class-activation heatmap.

    Args:
        model: classifier used for the prediction.
            NOTE(review): the heatmap comes from ``compute_gradcam``, which
            uses the module-level ``model``, not this argument.
        img: the uploaded image, either a NumPy array or a PIL image. Any
            colour mode (grayscale, RGBA, ...) is converted to RGB.
        layer_idx: conv layer selector, forwarded to ``compute_gradcam``.
        typeCAM: CAM method name, forwarded to ``compute_gradcam``.

    Returns:
        ``(predicted_label, gradcam)``: the argmax class index as a string
        ("0", "1", ...) and the heatmap returned by ``compute_gradcam``.

    Raises:
        ValueError: if ``img`` is None (nothing was uploaded).
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")
    if isinstance(img, np.ndarray):
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA);
        # forcing 'RGB' mis-reads grayscale or 4-channel arrays.
        img = Image.fromarray(img.astype('uint8'))
    # The 3-channel Normalize and conv1 require exactly three channels.
    img = img.convert('RGB')
    img_tensor = transform(img).unsqueeze(0)

    # Prediction only needs a forward pass; no autograd graph is required.
    with torch.no_grad():
        logits = model(img_tensor)
    predicted_label = str(logits.argmax(dim=1).item())

    gradcam = compute_gradcam(img_tensor, layer_idx, typeCAM)
    return predicted_label, gradcam
# Class index (the string produced by predict_and_gradcam) -> display name.
# NOTE(review): presumably matches the class order used at training time;
# confirm.
idx_to_lbl = dict(zip(("0", "1"), ("Cracked", "Uncracked")))
# Callback wired into the Gradio interface.
def classify_image(image, layer_idx, typeCAM):
    """Return the human-readable label and the CAM heatmap for ``image``.

    Runs ``predict_and_gradcam`` with the module-level ``model`` and maps the
    string class index it returns through ``idx_to_lbl``.
    """
    class_idx, heatmap = predict_and_gradcam(model, image, layer_idx, typeCAM)
    readable_label = idx_to_lbl[class_idx]
    return readable_label, heatmap
# HTML shown under the interface title: a short usage note and a header image.
# <img> is a void element (no closing tag), and its width attribute takes a
# bare integer number of pixels.
description = """\
<hr><center>Upload an image of concrete and get the predicted label along with the GradCAM heatmap. <br><br>
<img src="https://www.huggingface.co/spaces/1rsh/concrete-crack-gradcam/resolve/main/header.jpeg" width="200" alt="Concrete crack detection"></center>
"""
# CAM method names offered in the UI dropdown, in display order.
typeCAMs = "GradCAM HiResCAM ScoreCAM GradCAMPlusPlus AblationCAM XGradCAM FullGrad".split()
# Define the Gradio interface. Inputs:
# - an image upload;
# - a slider choosing which conv layer (1-3) to explain;
# - a dropdown choosing the CAM algorithm.
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(),
        gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Conv layer"),
        gr.Dropdown(choices=typeCAMs, value="GradCAM", label="CAM method"),
    ],
    outputs=[gr.Textbox(label="Predicted Label"), gr.Image(label="GradCAM Heatmap")],
    title="Concrete Crack Detection with GradCAM",
    description=description,
    # Gradio documents this parameter as a string ("never" / "auto" /
    # "manual"). A boolean False is a deprecated spelling that only triggers
    # a warning before being mapped to "never".
    allow_flagging="never",
    theme=gr.themes.Monochrome(font=gr.themes.GoogleFont("IBM Plex Mono")),
)
# Launch the interface
iface.launch()