import numpy as np |
from io import BytesIO |
import cv2 |
from captum.attr import IntegratedGradients |
from captum.attr import visualization as viz |
import io |
import torch |
import torchvision |
import torch.nn as nn |
from torchvision import transforms,models |
from torchvision.transforms import v2 |
import torch.nn.functional as F |
torchvision.disable_beta_transforms_warning() |
import json |
import base64 |
# Class index -> Devanagari label. Index i is the i-th entry of the tuple below.
index_to_target = dict(
    enumerate(
        (
            # 0-11: vowels, including अं and अः
            'अ', 'अं', 'अः', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ए', 'ऐ', 'ओ', 'औ',
            # 12-47: consonants, including the conjuncts क्ष, ज्ञ and त्र
            'क', 'क्ष', 'ख', 'ग', 'घ', 'ङ',
            'च', 'छ', 'ज', 'ज्ञ', 'झ', 'ञ',
            'ट', 'ठ', 'ड', 'ढ', 'ण',
            'त', 'त्र', 'थ', 'द', 'ध', 'न',
            'प', 'फ', 'ब', 'भ', 'म',
            'य', 'र', 'ल', 'व', 'श', 'ष', 'स', 'ह',
            # 48-57: digits ० to ९
            '०', '१', '२', '३', '४', '५', '६', '७', '८', '९',
        )
    )
)
# Reverse lookup: label -> class index.
target_to_index = dict(zip(index_to_target.values(), index_to_target.keys()))
# Run on the first CUDA device when torch can see one; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Preprocessing for the model input: array -> float tensor (ToTensor), cast to
# float32, then normalised with mean 0.5 and std 0.5 per channel.
img_transforms = v2.Compose(
    [
        transforms.ToTensor(),
        v2.ToDtype(torch.float32),
        v2.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def crop_characters(img, *, padding: int = 3, min_area: float = 200) -> np.ndarray:
    """Crop *img* to the padded bounding box enclosing all sizeable dark marks.

    The image is Gaussian-blurred, converted to grayscale and binarised with an
    inverted Otsu threshold (so pixels darker than the threshold become
    foreground). External contours whose area is below *min_area* are ignored as
    noise; the union of the remaining contours' bounding boxes is grown by
    *padding* pixels on every side and clamped to the image bounds.

    Args:
        img: Image array as produced by ``cv2.imdecode`` - 3/4-channel BGR(A)
            or a 2-D single-channel array.
        padding: Margin, in pixels, added on each side of the combined box.
        min_area: Contours with ``cv2.contourArea`` below this are discarded.

    Returns:
        The slice of *img* covering the padded box.

    Raises:
        ValueError: If no contour reaches *min_area* (for example a blank image).
    """
    blurred = cv2.GaussianBlur(img, (5, 5), 3)
    # Single-channel input is already grayscale; BGR2GRAY would reject it.
    gray = blurred if blurred.ndim == 2 else cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
    _, thresh_img = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # (x0, y0, x1, y1) corners of every contour large enough to count as ink.
    boxes = []
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((x, y, x + w, y + h))
    if not boxes:
        raise ValueError("No character-sized contours found in image")

    height, width = img.shape[:2]
    # Clamp after padding: a negative start index would make NumPy slicing wrap
    # around to the far edge and return a wrong (often empty) crop.
    x_min = max(min(box[0] for box in boxes) - padding, 0)
    y_min = max(min(box[1] for box in boxes) - padding, 0)
    x_max = min(max(box[2] for box in boxes) + padding, width)
    y_max = min(max(box[3] for box in boxes) + padding, height)
    return img[y_min:y_max, x_min:x_max]
# Classifiers already loaded, keyed by (weights path, device), so repeated
# predict() calls do not rebuild ResNet-101 and re-read the weight file.
_MODEL_CACHE = {}


def _load_model(model_path: str, model_device: str) -> nn.Module:
    """Return the ResNet-101 classifier stored at *model_path*, loading it at most once."""
    key = (model_path, model_device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = models.resnet101(weights=None)
        # One output per label in index_to_target.
        model.fc = nn.Linear(model.fc.in_features, len(index_to_target))
        model.to(model_device)
        # NOTE: torch.load unpickles the file - only point model_path at trusted weights.
        model.load_state_dict(torch.load(model_path, map_location=model_device))
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def _decode_and_preprocess(image_buffer: BytesIO):
    """Decode *image_buffer* and return ``(cropped 28x28 image, 1xCxHxW model input)``.

    Raises:
        RuntimeError: If OpenCV cannot decode the bytes as an image.
    """
    image_buffer.seek(0)
    img = cv2.imdecode(np.frombuffer(image_buffer.read(), np.uint8), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise RuntimeError("Failed to decode image")
    if img.ndim == 2:
        # Single-channel files decode to a 2-D array; the BGR conversions below need channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    img = crop_characters(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img_bin = cv2.adaptiveThreshold(
        img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 6
    )
    # The model consumes a 3-channel copy of the binarised crop.
    thres_rgb = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2BGR)
    return img, img_transforms(thres_rgb).unsqueeze(dim=0)


def _integrated_gradients_jpeg(model, model_input, target: int, original_image) -> bytes:
    """Render Integrated Gradients for class *target* over *original_image* as JPEG bytes."""
    model.zero_grad()
    ig = IntegratedGradients(model)
    # Baseline is an all-zero tensor in the model's (normalised) input space.
    attributions = ig.attribute(model_input, baselines=model_input * 0, target=target)
    # (1, C, H, W) -> (H, W, C), the layout captum's visualiser takes.
    attr_hwc = np.transpose(attributions[0].cpu().detach().numpy(), (1, 2, 0))
    fig, _ = viz.visualize_image_attr(
        attr_hwc,
        original_image,
        method="blended_heat_map",
        sign="all",
        title="Overlayed Integrated Gradients",
        use_pyplot=False,
        show_colorbar=True,
    )
    jpeg_buffer = io.BytesIO()
    fig.savefig(jpeg_buffer, format="jpeg")
    return jpeg_buffer.getvalue()


def predict(image_buffer: BytesIO, *, model_path: str = "res97_state.pth") -> str:
    """Classify the handwritten character in *image_buffer* and explain the prediction.

    Args:
        image_buffer: Binary stream holding an encoded image; it is rewound and read fully.
        model_path: Path of the saved state dict (keyword-only; defaults to the
            original ``"res97_state.pth"``).

    Returns:
        A JSON string with keys ``"prob"`` (top-3 softmax probabilities),
        ``"item"`` (the matching labels from ``index_to_target``) and ``"ig"``
        (base64 JPEG of the Integrated Gradients overlay for the top class).

    Raises:
        RuntimeError: If the image cannot be decoded.
        ValueError: Propagated from ``crop_characters`` when no ink is found.
    """
    model = _load_model(model_path, device)
    original_image, model_input = _decode_and_preprocess(image_buffer)
    # Move to the device outside inference_mode so captum later receives an
    # ordinary tensor rather than an inference tensor.
    model_input = model_input.to(device)

    with torch.inference_mode():
        outputs = model(model_input)
        probabilities = F.softmax(outputs, dim=1)
    predicted_index = int(outputs.argmax(dim=1).item())

    # The heat map is overlaid on the cropped colour image, not on the
    # binarised tensor the model actually saw.
    ig_jpeg = _integrated_gradients_jpeg(model, model_input, predicted_index, original_image)

    top3_values, top3_indices = torch.topk(probabilities, 3)
    json_data_dict = {
        "prob": top3_values[0].tolist(),
        "item": [index_to_target[int(item)] for item in top3_indices[0].tolist()],
        "ig": base64.b64encode(ig_jpeg).decode('utf-8'),
    }
    return json.dumps(json_data_dict)