import base64
import io
import json
from io import BytesIO

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
from torchvision import models, transforms
from torchvision.transforms import v2

# Silence the torchvision v2 "beta transforms" warning.
torchvision.disable_beta_transforms_warning()

# Model class index -> Devanagari character (58 classes: vowels, consonants
# and conjuncts, and the digits 0-9 in Devanagari numerals).
index_to_target = {
    0: 'अ', 1: 'अं', 2: 'अः', 3: 'आ', 4: 'इ', 5: 'ई', 6: 'उ', 7: 'ऊ',
    8: 'ए', 9: 'ऐ', 10: 'ओ', 11: 'औ', 12: 'क', 13: 'क्ष', 14: 'ख', 15: 'ग',
    16: 'घ', 17: 'ङ', 18: 'च', 19: 'छ', 20: 'ज', 21: 'ज्ञ', 22: 'झ', 23: 'ञ',
    24: 'ट', 25: 'ठ', 26: 'ड', 27: 'ढ', 28: 'ण', 29: 'त', 30: 'त्र', 31: 'थ',
    32: 'द', 33: 'ध', 34: 'न', 35: 'प', 36: 'फ', 37: 'ब', 38: 'भ', 39: 'म',
    40: 'य', 41: 'र', 42: 'ल', 43: 'व', 44: 'श', 45: 'ष', 46: 'स', 47: 'ह',
    48: '०', 49: '१', 50: '२', 51: '३', 52: '४', 53: '५', 54: '६', 55: '७',
    56: '८', 57: '९',
}
target_to_index = {value: key for key, value in index_to_target.items()}
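
# Quick round-trip example of the two lookups above:
#     index_to_target[12]    # 'क'
#     target_to_index['क']   # 12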


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Preprocessing for the classifier: float tensor in [0, 1], then normalized
# to roughly [-1, 1].
img_transforms = v2.Compose([
    transforms.ToTensor(),       # HWC uint8 -> CHW float32 in [0, 1]
    v2.ToDtype(torch.float32),   # no-op after ToTensor, kept for safety
    v2.Normalize((0.5,), (0.5,)),
])
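
# Shape/range sanity sketch for the pipeline above (dummy input, not part of
# the original code): a 28x28x3 uint8 image becomes a (3, 28, 28) float32
# tensor, scaled to [0, 1] by ToTensor and then to [-1, 1] by Normalize.
#
#     dummy = np.zeros((28, 28, 3), dtype=np.uint8)
#     out = img_transforms(dummy)   # torch.Size([3, 28, 28]), all values -1.0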

def crop_characters(img: np.ndarray) -> np.ndarray:
    """Crop `img` to the tightest box (plus a small margin) around the ink."""
    blur_img = cv2.GaussianBlur(img, (5, 5), 3)
    gray = cv2.cvtColor(blur_img, cv2.COLOR_BGR2GRAY)

    # Otsu thresholding, inverted so the ink becomes the white foreground.
    _, thresh_img = cv2.threshold(gray, 0, 255,
                                  cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)

    bounding_boxes = []
    for contour in contours:
        # Skip very small contours, which are usually noise.
        if cv2.contourArea(contour) < 200:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        bounding_boxes.append((x, y, x + w, y + h))

    # If every contour was filtered out as noise, fall back to the full image.
    if not bounding_boxes:
        return img

    # Minimum bounding rectangle enclosing all the per-contour boxes.
    x_min = min(box[0] for box in bounding_boxes)
    y_min = min(box[1] for box in bounding_boxes)
    x_max = max(box[2] for box in bounding_boxes)
    y_max = max(box[3] for box in bounding_boxes)

    # Add a 3 px margin on every side, clamped to the image bounds so the
    # slice below never sees negative indices.
    padding = 3
    img_h, img_w = img.shape[:2]
    x_min = max(x_min - padding, 0)
    y_min = max(y_min - padding, 0)
    x_max = min(x_max + padding, img_w)
    y_max = min(y_max + padding, img_h)

    return img[y_min:y_max, x_min:x_max]
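
# Usage sketch for crop_characters (hypothetical file names, not from the
# original code):
#
#     page = cv2.imread("sample.png")           # BGR uint8
#     glyph = crop_characters(page)             # tight crop with a 3 px margin
#     cv2.imwrite("sample_cropped.png", glyph)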




def predict(image_buffer: BytesIO) -> str:
    """Classify the character in `image_buffer`; return a JSON string with the
    top-3 predictions and an Integrated Gradients overlay (base64 JPEG)."""
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # ResNet-101 with the final layer resized to the 58 Devanagari classes.
    # Note: the model is rebuilt and its weights reloaded on every call; hoist
    # this to module level if predict() is called often.
    model_path = "res97_state.pth"
    model = models.resnet101(weights=None).to(device)
    num_classes = 58
    model.fc = nn.Linear(model.fc.in_features, num_classes).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    image_buffer.seek(0)

    # Decode as 3-channel BGR; IMREAD_UNCHANGED could yield a 1- or 4-channel
    # image that the BGR2GRAY conversions below would reject.
    img = cv2.imdecode(np.frombuffer(image_buffer.read(), np.uint8),
                       cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError("Failed to decode image")

    img = crop_characters(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

    # Binarize: grayscale -> adaptive threshold -> back to 3 channels, since
    # the ResNet expects 3-channel input.
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img_bin = cv2.adaptiveThreshold(img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                    cv2.THRESH_BINARY, 11, 6)
    thres_rgb = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2BGR)
    
    model.eval()

    transformed_img = img_transforms(thres_rgb).unsqueeze(dim=0)

    # Move to the device *before* entering inference_mode: tensors created
    # inside inference_mode are inference tensors and cannot take part in the
    # gradient-based attribution below.
    transformed_img = transformed_img.to(device)

    with torch.inference_mode():
        outputs = model(transformed_img)
        _, predicted_index = torch.max(outputs, dim=1)  # max value and its index
        probabilities = F.softmax(outputs, dim=1)

    

    def attribute_image_features(algorithm, image, **kwargs):
        model.zero_grad()
        return algorithm.attribute(image, target=predicted_index, **kwargs)

    # Integrated Gradients: attribute the predicted class to input pixels by
    # integrating gradients along the path from an all-zero baseline image.
    ig = IntegratedGradients(model)
    attr_ig, delta = attribute_image_features(
        ig,
        transformed_img,
        baselines=transformed_img * 0,
        return_convergence_delta=True,
    )
    # (1, C, H, W) tensor -> (H, W, C) numpy array for visualization.
    attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
    

    # matplotlib expects RGB, while OpenCV images are BGR.
    original_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img_attr, _ = viz.visualize_image_attr(
        attr_ig, original_image, method="blended_heat_map", sign="all",
        title="Overlayed Integrated Gradients", use_pyplot=False,
        show_colorbar=True,
    )

    img_attr_bytes = io.BytesIO()
    img_attr.savefig(img_attr_bytes, format="jpeg")
    

    
    # Top-3 class probabilities and their indices (batch size is 1).
    top3_values, top3_indices = torch.topk(probabilities, 3)

    json_data_dict = {
        "prob": top3_values[0].tolist(),
        "item": [index_to_target[int(i)] for i in top3_indices[0]],
        "ig": base64.b64encode(img_attr_bytes.getvalue()).decode('utf-8'),
    }

    return json.dumps(json_data_dict)
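

# Minimal end-to-end sketch. The image path is hypothetical; the weights file
# "res97_state.pth" must exist next to this script, as in predict() above.
if __name__ == "__main__":
    with open("sample.png", "rb") as f:          # hypothetical input image
        result = json.loads(predict(BytesIO(f.read())))
    print("top-3 characters:   ", result["item"])
    print("top-3 probabilities:", result["prob"])
    # result["ig"] holds the Integrated Gradients overlay as a base64 JPEG.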