# NepaliCNN / model.py
# pujan paudel
# first commit
# 1a11305
# raw
# history blame
# 5.7 kB
import numpy as np
from io import BytesIO
import cv2
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
import io
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms,models
from torchvision.transforms import v2
import torch.nn.functional as F
torchvision.disable_beta_transforms_warning()
import json
import base64
# Class labels listed in classifier-output order: position i in this tuple is
# the character predicted for output index i (Devanagari vowels, consonants
# and digits, 58 entries in total).
_LABELS = (
    'अ', 'अं', 'अः', 'आ', 'इ', 'ई', 'उ', 'ऊ', 'ए', 'ऐ', 'ओ', 'औ',
    'क', 'क्ष', 'ख', 'ग', 'घ', 'ङ', 'च', 'छ', 'ज', 'ज्ञ', 'झ', 'ञ',
    'ट', 'ठ', 'ड', 'ढ', 'ण', 'त', 'त्र', 'थ', 'द', 'ध', 'न', 'प',
    'फ', 'ब', 'भ', 'म', 'य', 'र', 'ल', 'व', 'श', 'ष', 'स', 'ह',
    '०', '१', '२', '३', '४', '५', '६', '७', '८', '९',
)

# Forward lookup: output index -> character.
index_to_target = dict(enumerate(_LABELS))

# Reverse lookup: character -> output index.
target_to_index = {label: idx for idx, label in index_to_target.items()}

# Prefer the first CUDA device when one is available, otherwise run on CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = "cpu"
# Preprocessing pipeline applied to the image array right before inference.
img_transforms = v2.Compose(
    [
        # Array -> tensor conversion.
        transforms.ToTensor(),
        # Make the float32 dtype explicit.
        v2.ToDtype(dtype=torch.float32),
        # Normalise with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
        v2.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
#
def crop_characters(
    img: np.ndarray,
    *,
    min_contour_area: int = 200,
    padding: int = 3,
) -> np.ndarray:
    """Crop *img* to the region that contains the drawn character(s).

    The image is blurred, binarised with an inverted Otsu threshold, and its
    external contours are extracted. Contours smaller than
    ``min_contour_area`` are treated as noise. The crop is the smallest
    rectangle enclosing every remaining contour, grown by ``padding`` pixels
    on each side and clamped to the image borders.

    Args:
        img: Image array as produced by OpenCV; either a multi-channel BGR
            image or an already single-channel (2-D) image.
        min_contour_area: Contours with an area below this are ignored.
        padding: Number of pixels added around the enclosing rectangle.

    Returns:
        The cropped region (a slice of ``img``). If no contour reaches
        ``min_contour_area`` there is nothing to crop to and ``img`` itself
        is returned unchanged.
    """
    blurred = cv2.GaussianBlur(img, (5, 5), 3)
    # A 2-D array is already single-channel; BGR2GRAY would reject it.
    if blurred.ndim == 2:
        gray = blurred
    else:
        gray = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)

    # Inverted threshold: pixels at or below the Otsu level become foreground.
    _, thresh_img = cv2.threshold(
        gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    contours, _ = cv2.findContours(
        thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Bounding boxes (x, y, w, h) of every contour large enough to matter.
    boxes = [
        cv2.boundingRect(contour)
        for contour in contours
        if cv2.contourArea(contour) >= min_contour_area
    ]
    if not boxes:
        # Nothing but noise was found; min()/max() below would fail on an
        # empty sequence, so fall back to the uncropped image.
        return img

    height, width = img.shape[:2]
    # Clamp to the image: a negative start index would otherwise be read by
    # NumPy as "count from the end" and yield an empty or wrong crop.
    x_min = max(min(x for x, _, _, _ in boxes) - padding, 0)
    y_min = max(min(y for _, y, _, _ in boxes) - padding, 0)
    x_max = min(max(x + w for x, _, w, _ in boxes) + padding, width)
    y_max = min(max(y + h for _, y, _, h in boxes) + padding, height)

    return img[y_min:y_max, x_min:x_max]
# Loaded classifiers keyed by (weights path, device) so that the network is
# built and its weights are read from disk only once per process.
_MODEL_CACHE = {}


def _load_model(device: str, model_path: str = "res97_state.pth"):
    """Return the ResNet-101 classifier in eval mode, loading it on first use."""
    cache_key = (model_path, device)
    model = _MODEL_CACHE.get(cache_key)
    if model is None:
        model = models.resnet101(weights=None)
        # Replace the final layer so there is one output per known label.
        model.fc = nn.Linear(model.fc.in_features, len(index_to_target))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[cache_key] = model
    return model


def predict(image_buffer: BytesIO) -> str:
    """Classify the character in an encoded image and explain the prediction.

    The image is decoded, cropped to the character with ``crop_characters``,
    resized to 28x28, binarised with an adaptive threshold and passed through
    the ResNet-101 classifier. Integrated Gradients attributions for the
    top-scoring class are rendered as a heat map blended over the resized
    image.

    Args:
        image_buffer: Buffer holding the encoded image bytes. It is rewound
            to the start before reading.

    Returns:
        A JSON string with three keys: ``"prob"`` (the three highest softmax
        probabilities), ``"item"`` (the matching characters, in the same
        order) and ``"ig"`` (the attribution heat map as a base64-encoded
        JPEG).

    Raises:
        RuntimeError: If the buffer cannot be decoded as an image.
    """
    device = 'cuda:0' if torch.cuda.is_available() else "cpu"
    model = _load_model(device)

    image_buffer.seek(0)
    img = cv2.imdecode(np.frombuffer(image_buffer.read(), np.uint8), -1)
    if img is None:
        raise RuntimeError("Failed to decode image")
    if img.ndim == 2:
        # Flag -1 keeps grayscale files single-channel, but everything below
        # expects a BGR-style image, so promote it to three channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = crop_characters(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img_bin = cv2.adaptiveThreshold(
        img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 6
    )
    # Back to three channels before the tensor transforms are applied.
    thres_rgb = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2BGR)

    # Build the input outside inference_mode: tensors created inside it are
    # inference tensors, which autograd restricts, and the same input is
    # reused for the gradient-based attribution below.
    transformed_img = img_transforms(thres_rgb).unsqueeze(dim=0).to(device)

    with torch.inference_mode():
        outputs = model(transformed_img)
        _, predicted_index = torch.max(outputs, 1)  # max value and its index
        probabilities = F.softmax(outputs, 1)

    # Integrated Gradients for the predicted class, with an all-zero baseline.
    model.zero_grad()
    ig = IntegratedGradients(model)
    attr_ig = ig.attribute(
        transformed_img,
        target=int(predicted_index.item()),
        baselines=transformed_img * 0,
    )
    # (1, C, H, W) -> (H, W, C), the layout the visualiser expects.
    attr_ig = np.transpose(attr_ig.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    img_attr, _ = viz.visualize_image_attr(
        attr_ig,
        img,
        method="blended_heat_map",
        sign="all",
        title="Overlayed Integrated Gradients",
        use_pyplot=False,
        show_colorbar=True,
    )
    img_attr_bytes = io.BytesIO()
    img_attr.savefig(img_attr_bytes, format="jpeg")

    top3_probs, top3_indices = torch.topk(probabilities, 3)
    json_data_dict = {
        "prob": top3_probs[0].tolist(),
        "item": [index_to_target[int(idx)] for idx in top3_indices[0].tolist()],
        "ig": base64.b64encode(img_attr_bytes.getvalue()).decode('utf-8'),
    }
    return json.dumps(json_data_dict)