Nekshay committed on
Commit 477ab1d
1 Parent(s): 588994d

Create car_damage_detection.py

Files changed (1)
  1. car_damage_detection.py +100 -0
car_damage_detection.py ADDED
@@ -0,0 +1,100 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+
+ import detectron2
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer, ColorMode
+ from detectron2.data import MetadataCatalog
+
+ # Path to the fine-tuned Mask R-CNN weights.
+ model_path = 'model_final.pth'
+
+ # Start from the COCO Mask R-CNN (R50-FPN, 3x) config and point it at our weights.
+ cfg = get_cfg()
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6  # keep only detections scoring above 0.6
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1          # single class: "damage"
+ cfg.MODEL.WEIGHTS = model_path
+
+ # Fall back to CPU inference when no GPU is available.
+ if not torch.cuda.is_available():
+     cfg.MODEL.DEVICE = 'cpu'
+
+ predictor = DefaultPredictor(cfg)
+ my_metadata = MetadataCatalog.get("car_dataset_val")
+ my_metadata.thing_classes = ["damage"]
+
+ def merge_segment(pred_segm):
+     # Group overlapping predicted masks: merge_dict[i] lists the later masks that overlap mask i.
+     merge_dict = {}
+     for i in range(len(pred_segm)):
+         merge_dict[i] = []
+         for j in range(i + 1, len(pred_segm)):
+             if torch.sum(pred_segm[i] * pred_segm[j]) > 0:
+                 merge_dict[i].append(j)
+
+     # Masks that will be folded into an earlier mask are marked for deletion.
+     to_delete = []
+     for key in merge_dict:
+         for element in merge_dict[key]:
+             to_delete.append(element)
+
+     for element in to_delete:
+         merge_dict.pop(element, None)
+
+     # Drop entries that have nothing to merge.
+     empty_delete = []
+     for key in merge_dict:
+         if merge_dict[key] == []:
+             empty_delete.append(key)
+
+     for element in empty_delete:
+         merge_dict.pop(element, None)
+
+     # Fold each overlapping mask into its group's first mask.
+     for key in merge_dict:
+         for element in merge_dict[key]:
+             pred_segm[key] += pred_segm[element]
+
+     # Keep only the masks that were not merged away.
+     except_elem = list(set(to_delete))
+     new_indexes = list(range(len(pred_segm)))
+     for elem in except_elem:
+         new_indexes.remove(elem)
+
+     return pred_segm[new_indexes]
+
+ def inference(image):
+     # Convert the PIL image to a NumPy array for the Detectron2 predictor.
+     img = np.array(image)
+     outputs = predictor(img)
+     out_dict = outputs["instances"].to("cpu").get_fields()
+
+     # Merge overlapping damage masks and wrap them in a fresh Instances object.
+     new_inst = detectron2.structures.Instances((1024, 1024))
+     new_inst.set('pred_masks', merge_segment(out_dict['pred_masks']))
+
+     v = Visualizer(img[:, :, ::-1],
+                    metadata=my_metadata,
+                    scale=0.5,
+                    instance_mode=ColorMode.SEGMENTATION  # hide the colors of unsegmented pixels (segmentation models only)
+                    )
+     out = v.draw_instance_predictions(new_inst)
+
+     return out.get_image()[:, :, ::-1]
+
+ title = "Detectron2 Car Damage Detection"
+ description = "This demo provides an interactive playground for our trained Detectron2 model."
+
+ # Gradio UI: a PIL image in, the annotated NumPy image out.
+ gr.Interface(
+     inference,
+     [gr.inputs.Image(type="pil", label="Input")],
+     gr.outputs.Image(type="numpy", label="Output"),
+     title=title,
+     description=description,
+     examples=[]).launch()
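
A quick way to sanity-check merge_segment is to call it on a few hand-made boolean masks. This is a minimal sketch, assuming the function above is in scope; the 4x4 toy masks and their shapes are invented for illustration and are not part of the demo. Masks that overlap are folded into one, and disjoint masks pass through unchanged.

import torch

# Three hypothetical 4x4 boolean masks: m0 and m1 overlap, m2 is disjoint.
m0 = torch.zeros(4, 4, dtype=torch.bool); m0[0:2, 0:2] = True
m1 = torch.zeros(4, 4, dtype=torch.bool); m1[1:3, 1:3] = True
m2 = torch.zeros(4, 4, dtype=torch.bool); m2[3:4, 3:4] = True
masks = torch.stack([m0, m1, m2])

# merge_segment modifies its argument in place, so pass a clone to keep `masks` intact.
merged = merge_segment(masks.clone())
print(merged.shape)  # torch.Size([2, 4, 4]) -- m0 and m1 collapse into a single mask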