import base64
import json
from io import BytesIO

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
from torchvision import models, transforms

# The beta warning is emitted when torchvision.transforms.v2 is imported,
# so it must be disabled before that import.
torchvision.disable_beta_transforms_warning()
from torchvision.transforms import v2
# 58 output classes: 12 Devanagari vowel characters, 36 consonants (including
# the conjuncts क्ष, त्र, ज्ञ), and the digits ०-९.
index_to_target = {
    0: 'अ', 1: 'अं', 2: 'अः', 3: 'आ', 4: 'इ', 5: 'ई', 6: 'उ', 7: 'ऊ',
    8: 'ए', 9: 'ऐ', 10: 'ओ', 11: 'औ',
    12: 'क', 13: 'क्ष', 14: 'ख', 15: 'ग', 16: 'घ', 17: 'ङ', 18: 'च', 19: 'छ',
    20: 'ज', 21: 'ज्ञ', 22: 'झ', 23: 'ञ', 24: 'ट', 25: 'ठ', 26: 'ड', 27: 'ढ',
    28: 'ण', 29: 'त', 30: 'त्र', 31: 'थ', 32: 'द', 33: 'ध', 34: 'न', 35: 'प',
    36: 'फ', 37: 'ब', 38: 'भ', 39: 'म', 40: 'य', 41: 'र', 42: 'ल', 43: 'व',
    44: 'श', 45: 'ष', 46: 'स', 47: 'ह',
    48: '०', 49: '१', 50: '२', 51: '३', 52: '४', 53: '५', 54: '६', 55: '७',
    56: '८', 57: '९',
}
target_to_index = {value: key for key, value in index_to_target.items()}
device = "cuda:0" if torch.cuda.is_available() else "cpu"
img_transforms = v2.Compose([
    transforms.ToTensor(),           # HWC uint8 -> CHW float32 in [0, 1]
    v2.ToDtype(torch.float32),
    v2.Normalize((0.5,), (0.5,)),    # [0, 1] -> [-1, 1]
])
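# Assumption: the res97_state.pth checkpoint was trained with this same
# ToTensor + Normalize((0.5,), (0.5,)) pipeline; mismatched statistics at
# inference time would degrade accuracy.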
def crop_characters(img) -> np.ndarray:
    """Crop the image to the minimal box that encloses every detected character."""
    blur_img = cv2.GaussianBlur(img, (5, 5), 3)
    gray = cv2.cvtColor(blur_img, cv2.COLOR_BGR2GRAY)
    # Otsu picks the threshold automatically; INV makes the ink white on a
    # black background, which is what findContours expects.
    _, thresh_img = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bounding_boxes = []
    for contour in contours:
        # Neglect very small contours, which are usually noise.
        if cv2.contourArea(contour) < 200:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        bounding_boxes.append((x, y, x + w, y + h))
    if not bounding_boxes:
        # Only noise was found; fall back to the full image.
        return img
    # Minimum bounding rectangle that encloses all the per-contour rectangles.
    x_min = min(box[0] for box in bounding_boxes)
    y_min = min(box[1] for box in bounding_boxes)
    x_max = max(box[2] for box in bounding_boxes)
    y_max = max(box[3] for box in bounding_boxes)
    # Pad by 3 px on every side, clamped so the slice stays inside the image.
    padding = 3
    x_min = max(x_min - padding, 0)
    y_min = max(y_min - padding, 0)
    x_max = min(x_max + padding, img.shape[1])
    y_max = min(y_max + padding, img.shape[0])
    return img[y_min:y_max, x_min:x_max]
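
# predict() is the service entry point: it decodes the uploaded image, crops
# and binarizes the character, classifies it with a fine-tuned ResNet-101,
# and returns a JSON payload with the top-3 classes and an Integrated
# Gradients heatmap (base64-encoded JPEG).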
def predict(image_buffer: BytesIO) -> str:
    # Note: the model is rebuilt and its weights reloaded on every call;
    # caching it at module scope would avoid the repeated disk reads.
    model_path = "res97_state.pth"
    num_classes = 58
    model = models.resnet101(weights=None).to(device)
    model.fc = nn.Linear(model.fc.in_features, num_classes).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    image_buffer.seek(0)
    # IMREAD_COLOR guarantees a 3-channel BGR image even for RGBA uploads.
    img = cv2.imdecode(np.frombuffer(image_buffer.read(), np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError("Failed to decode image")
    img = crop_characters(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Adaptive thresholding binarizes the character; converting back to BGR
    # restores the three channels ResNet expects.
    img_bin = cv2.adaptiveThreshold(img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                    cv2.THRESH_BINARY, 11, 6)
    thres_rgb = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2BGR)
    transformed_img = img_transforms(thres_rgb).unsqueeze(dim=0).to(device)
    with torch.inference_mode():
        outputs = model(transformed_img)
        _, predicted_index = torch.max(outputs, 1)  # max value and argmax index
        probabilities = F.softmax(outputs, 1)
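
    # Integrated Gradients attributes the predicted class score to individual
    # input pixels by accumulating gradients along a straight-line path from a
    # baseline image to the actual input.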
    def attribute_image_features(algorithm, image, **kwargs):
        model.zero_grad()
        # int() unwraps the batch-of-one prediction; it also keeps Captum's
        # autograd pass clear of tensors created under inference_mode.
        return algorithm.attribute(image, target=int(predicted_index), **kwargs)

    ig = IntegratedGradients(model)
    # An all-zero tensor serves as the baseline; delta measures how far the
    # attributions are from satisfying the completeness axiom (smaller is better).
    attr_ig, delta = attribute_image_features(ig, transformed_img,
                                              baselines=transformed_img * 0,
                                              return_convergence_delta=True)
    # CHW -> HWC for Captum's matplotlib-based visualizer.
    attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
    # Overlay the attributions on the cropped input as a blended heatmap and
    # serialize the Matplotlib figure to JPEG.
    img_attr, _ = viz.visualize_image_attr(attr_ig, img, method="blended_heat_map",
                                           sign="all", title="Overlayed Integrated Gradients",
                                           use_pyplot=False, show_colorbar=True)
    img_attr_bytes = BytesIO()
    img_attr.savefig(img_attr_bytes, format="jpeg")
    # Top-3 classes with their softmax probabilities.
    top3, top3index = torch.topk(probabilities, 3)
    top3Value = top3.tolist()
    top3Index = top3index.tolist()
    json_data_dict = {
        "prob": top3Value[0],
        "item": [index_to_target[int(item)] for item in top3Index[0]],
        "ig": base64.b64encode(img_attr_bytes.getvalue()).decode('utf-8'),
    }
    return json.dumps(json_data_dict)
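
# --- Minimal usage sketch. Assumptions: the script is run from a directory
# containing res97_state.pth, and "sample.png" is a hypothetical test image,
# not part of the deployed service. ---
if __name__ == "__main__":
    with open("sample.png", "rb") as f:
        result = predict(BytesIO(f.read()))
    print(result)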