|
import numpy as np |
|
from io import BytesIO |
|
import cv2 |
|
from captum.attr import IntegratedGradients |
|
from captum.attr import visualization as viz |
|
import io |
|
import torch |
|
import torchvision |
|
import torch.nn as nn |
|
from torchvision import transforms,models |
|
from torchvision.transforms import v2 |
|
import torch.nn.functional as F |
|
torchvision.disable_beta_transforms_warning() |
|
import json |
|
import base64 |
|
|
|
|
|
index_to_target = { |
|
0: 'अ', 1: 'अं', 2: 'अः', 3: 'आ', 4: 'इ', 5: 'ई', 6: 'उ', 7: 'ऊ', 8: 'ए', 9: 'ऐ', 10: 'ओ', 11: 'औ', 12: 'क', 13: 'क्ष', 14: 'ख', 15: 'ग', 16: 'घ', 17: 'ङ', 18: 'च', 19: 'छ', 20: 'ज', 21: 'ज्ञ', 22: 'झ', 23: 'ञ', 24: 'ट', 25: 'ठ', 26: 'ड', 27: 'ढ', 28: 'ण', 29: 'त', 30: 'त्र', 31: 'थ', 32: 'द', 33: 'ध', 34: 'न', 35: 'प', 36: 'फ', 37: 'ब', 38: 'भ', 39: 'म', 40: 'य', 41: 'र', 42: 'ल', 43: 'व', 44: 'श', 45: 'ष', 46: 'स', 47: 'ह', 48: '०', 49: '१', 50: '२', 51: '३', 52: '४', 53: '५', 54: '६', 55: '७', 56: '८', 57: '९'} |
|
target_to_index = {value:key for key,value in index_to_target.items()} |
|
|
|
|
|
device = 'cuda:0' if torch.cuda.is_available() else "cpu" |
|
|
|
img_transforms = v2.Compose([ |
|
transforms.ToTensor(), |
|
v2.ToDtype(torch.float32), |
|
v2.Normalize((0.5,),(0.5,)) |
|
]) |
|
|
|
|
|
def crop_characters(img) -> np.array: |
|
|
|
|
|
blur_img =cv2.GaussianBlur(img,(5,5),3) |
|
gray = cv2.cvtColor(blur_img,cv2.COLOR_BGR2GRAY) |
|
|
|
|
|
|
|
thres_value,thresh_img= cv2.threshold(gray,0,255,cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) |
|
|
|
contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
|
|
|
|
|
|
|
|
|
bounding_boxes = [] |
|
for contour in contours: |
|
area = cv2.contourArea(contour) |
|
|
|
|
|
if area<200: |
|
continue |
|
|
|
|
|
x, y, w, h = cv2.boundingRect(contour) |
|
|
|
|
|
|
|
|
|
|
|
bounding_boxes.append((x, y, x + w, y + h)) |
|
|
|
|
|
x_min = min(box[0] for box in bounding_boxes) |
|
y_min = min(box[1] for box in bounding_boxes) |
|
x_max = max(box[2] for box in bounding_boxes) |
|
y_max = max(box[3] for box in bounding_boxes) |
|
|
|
padding_left=3 |
|
padding_right =3 |
|
padding_bottom =3 |
|
padding_top =3 |
|
x_min -= padding_left |
|
y_min -= padding_bottom |
|
x_max += padding_top |
|
y_max += padding_right |
|
|
|
|
|
cropped_img = img[y_min:y_max, x_min:x_max] |
|
return cropped_img |
|
|
|
|
|
|
|
|
|
def predict(image_buffer:BytesIO)->str: |
|
device = 'cuda:0' if torch.cuda.is_available() else "cpu" |
|
|
|
model_path ="res97_state.pth" |
|
|
|
model = models.resnet101(weights=None).to(device) |
|
num_classes = 58 |
|
model.fc = nn.Linear(model.fc.in_features, num_classes).to(device) |
|
model.load_state_dict(torch.load(model_path,map_location=device)) |
|
|
|
image_buffer.seek(0) |
|
|
|
img= cv2.imdecode(np.frombuffer(image_buffer.read(),np.uint8),-1) |
|
|
|
if img is None: |
|
raise RuntimeError("Failed to decode image") |
|
|
|
img = crop_characters(img) |
|
|
|
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA) |
|
|
|
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) |
|
|
|
|
|
img_bin = cv2.adaptiveThreshold(img_gray,255,cv2.ADAPTIVE_THRESH_MEAN_C,cv2.THRESH_BINARY,11,6) |
|
thres_rgb= cv2.cvtColor(img_bin,cv2.COLOR_GRAY2BGR) |
|
|
|
model.eval() |
|
|
|
transformed_img = img_transforms(thres_rgb).unsqueeze(dim=0) |
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
transformed_img = transformed_img.to(device) |
|
outputs = model(transformed_img) |
|
_,predicted_index = torch.max(outputs.data,1) |
|
probabilities = F.softmax(outputs,1) |
|
|
|
|
|
|
|
def attribute_image_features(algorithm,image , **kwargs): |
|
model.zero_grad() |
|
tensor_attributions = algorithm.attribute(image, |
|
target=predicted_index, |
|
**kwargs |
|
) |
|
|
|
return tensor_attributions |
|
|
|
ig = IntegratedGradients(model) |
|
attr_ig, delta = attribute_image_features(ig, transformed_img, baselines=transformed_img * 0, return_convergence_delta=True) |
|
attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)) |
|
|
|
|
|
original_image = img |
|
|
|
|
|
|
|
|
|
img_attr,axes2= viz.visualize_image_attr(attr_ig, original_image, method="blended_heat_map",sign="all", |
|
title="Overlayed Integrated Gradients",use_pyplot=False,show_colorbar=True) |
|
|
|
img_attr_bytes = io.BytesIO() |
|
img_attr.savefig(img_attr_bytes,format="jpeg") |
|
|
|
|
|
|
|
top3,top3index= torch.topk(probabilities,3) |
|
|
|
top3Value= top3.tolist() |
|
top3Index= top3index.tolist() |
|
|
|
|
|
|
|
json_data_dict={ |
|
"prob":top3Value[0], |
|
"item":[ index_to_target[int(item)] for item in top3Index[0]], |
|
"ig":base64.b64encode(img_attr_bytes.getvalue()).decode('utf-8') |
|
|
|
} |
|
|
|
return json.dumps( json_data_dict) |
|
|