Hector Lopez committed
Commit 3fa54be
1 Parent(s): bc3d4e9

Multiple refactors

Files changed (4)
  1. app.py +1 -2
  2. model.py +60 -36
  3. requirements.txt +1 -0
  4. utils.py +10 -2
app.py CHANGED
@@ -3,8 +3,7 @@ import PIL
 import torch
 
 from utils import plot_img_no_mask, get_models
-from classifier import CustomEfficientNet, CustomViT
-from model import get_model, predict, prepare_prediction, predict_class
+from model import predict, prepare_prediction, predict_class
 
 DET_CKPT = 'efficientDet_icevision.ckpt'
 CLASS_CKPT = 'class_ViT_taco_7_class.pth'
model.py CHANGED
@@ -1,11 +1,10 @@
 from io import BytesIO
-from typing import Union
+from typing import Dict, Tuple, Union
 from icevision import *
 from icevision.models.checkpoint import model_from_checkpoint
 from classifier import transform_image
 from icevision.models import ross
 
-import collections
 import PIL
 import torch
 import numpy as np
@@ -13,44 +12,34 @@ import torchvision
 
 MODEL_TYPE = ross.efficientdet
 
-def get_model(checkpoint_path : str):
-    checkpoint_and_model = model_from_checkpoint(
-        checkpoint_path,
-        model_name='ross.efficientdet',
-        backbone_name='d0',
-        img_size=512,
-        classes=['Waste'],
-        revise_keys=[(r'^model\.', '')])
-
-    model = checkpoint_and_model['model']
-
-    return model
-
-def get_checkpoint(checkpoint_path : str):
-    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-
-    fixed_state_dict = collections.OrderedDict()
-
-    for k, v in ckpt['state_dict'].items():
-        new_k = k[6:]
-        fixed_state_dict[new_k] = v
-
-    return fixed_state_dict
-
-def predict(model : object, image : Union[str, BytesIO], detection_threshold : float):
+def predict(det_model : torch.nn.Module, image : Union[str, BytesIO],
+            detection_threshold : float) -> Dict:
+    """
+    Make a prediction with the detection model.
+
+    Args:
+        det_model (torch.nn.Module): Detection model.
+        image (Union[str, BytesIO]): Image filepath if the image is one of
+            the example images and BytesIO if the image is a custom image
+            uploaded by the user.
+        detection_threshold (float): Detection threshold.
+
+    Returns:
+        Dict: Prediction dictionary.
+    """
     img = PIL.Image.open(image)
-    #img = PIL.Image.open(BytesIO(image))
-    img = np.array(img)
-    img = PIL.Image.fromarray(img)
+
+    # Class map and transforms
     class_map = ClassMap(classes=['Waste'])
     transforms = tfms.A.Adapter([
         *tfms.A.resize_and_pad(512),
         tfms.A.Normalize()
     ])
-
+
+    # Single prediction
     pred_dict = MODEL_TYPE.end2end_detect(img,
                                           transforms,
-                                          model,
+                                          det_model,
                                           class_map=class_map,
                                           detection_threshold=detection_threshold,
                                           return_as_pil_img=False,
@@ -61,32 +50,67 @@ def predict(model : object, image : Union[str, BytesIO], detection_threshold : f
 
     return pred_dict
 
-def prepare_prediction(pred_dict, threshold):
+def prepare_prediction(pred_dict : Dict,
+                       nms_threshold : float) -> Tuple[torch.Tensor, np.ndarray]:
+    """
+    Get the predictions in the right format.
+
+    Args:
+        pred_dict (Dict): Prediction dictionary.
+        nms_threshold (float): Threshold for the NMS postprocess.
+
+    Returns:
+        Tuple: Tuple containing the following:
+            - (torch.Tensor): Bounding boxes
+            - (np.ndarray): Image data
+    """
+    # Convert each box to a tensor and stack them into a single tensor
     boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
     boxes = torch.stack(boxes)
 
+    # Get the scores and labels as tensors
     scores = torch.as_tensor(pred_dict['detection']['scores'])
     labels = torch.as_tensor(pred_dict['detection']['label_ids'])
+
     image = np.array(pred_dict['img'])
 
-    fixed_boxes = torchvision.ops.batched_nms(boxes, scores, labels, threshold)
+    # Apply NMS to postprocess the bounding boxes
+    fixed_boxes = torchvision.ops.batched_nms(boxes, scores,
+                                              labels, nms_threshold)
     boxes = boxes[fixed_boxes, :]
 
     return boxes, image
 
-def predict_class(classifier, image, bboxes):
+def predict_class(classifier : torch.nn.Module, image : np.ndarray,
+                  bboxes : torch.Tensor) -> np.ndarray:
+    """
+    Predict the class of each detected object.
+
+    Args:
+        classifier (torch.nn.Module): Classifier model.
+        image (np.ndarray): Image data.
+        bboxes (torch.Tensor): Bounding boxes.
+
+    Returns:
+        np.ndarray: Array containing the predicted class for each object.
+    """
     preds = []
 
     for bbox in bboxes:
         img = image.copy()
         bbox = np.array(bbox).astype(int)
+
+        # Get the bounding box content
         cropped_img = PIL.Image.fromarray(img).crop(bbox)
         cropped_img = np.array(cropped_img)
 
+        # Apply transformations to the cropped image
         tran_image = transform_image(cropped_img, 224)
+        # Channels first
         tran_image = tran_image.transpose(2, 0, 1)
         tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
-        print(tran_image.shape)
+
+        # Make prediction
        y_preds = classifier(tran_image)
         preds.append(y_preds.softmax(1).detach().numpy())
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 icevision[all]
 matplotlib
 effdet
+mmcv-full
 Pillow==8.4.0
utils.py CHANGED
@@ -4,8 +4,8 @@ import numpy as np
 import cv2
 import torch
 
+from icevision.models.checkpoint import model_from_checkpoint
 from classifier import CustomViT
-from model import get_model
 
 def plot_img_no_mask(image : np.ndarray, boxes : torch.Tensor, labels):
     colors = {
@@ -67,7 +67,15 @@ def get_models(
         - (torch.nn.Module): Classifier model
     """
     print('Loading the detection model')
-    det_model = get_model(detection_ckpt)
+    checkpoint_and_model = model_from_checkpoint(
+        detection_ckpt,
+        model_name='ross.efficientdet',
+        backbone_name='d0',
+        img_size=512,
+        classes=['Waste'],
+        revise_keys=[(r'^model\.', '')])
+
+    det_model = checkpoint_and_model['model']
    det_model.eval()
 
     print('Loading the classifier model')
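With get_model removed from model.py, checkpoint loading now lives entirely inside get_models, which inlines the same model_from_checkpoint call. As a rough sketch of how app.py can wire loading and inference together: the argument and return order of get_models are assumed here (the full signature is not visible in this diff), and the image path and thresholds are illustrative.

# Assumed get_models signature: (detection checkpoint, classifier checkpoint)
# returning (detection model, classifier model); not confirmed by this diff.
from utils import get_models
from model import predict, prepare_prediction, predict_class

DET_CKPT = 'efficientDet_icevision.ckpt'
CLASS_CKPT = 'class_ViT_taco_7_class.pth'

det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)
pred_dict = predict(det_model, 'example.jpg', detection_threshold=0.5)
boxes, image = prepare_prediction(pred_dict, nms_threshold=0.5)
preds = predict_class(classifier, image, boxes)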