from io import BytesIO
from typing import Dict, Tuple, Union

import numpy as np
import PIL
import torch
import torchvision

from icevision import *
from icevision.models import ross
from icevision.models.checkpoint import model_from_checkpoint

from classifier import transform_image

# Detection architecture family used by every model in this module.
MODEL_TYPE = ross.efficientdet


def predict(det_model: torch.nn.Module,
            image: Union[str, BytesIO],
            detection_threshold: float) -> Dict:
    """
    Make a prediction with the detection model.

    Args:
        det_model (torch.nn.Module): Detection model.
        image (Union[str, BytesIO]): Image filepath if the image is one of the
            example images and BytesIO if the image is a custom image uploaded
            by the user.
        detection_threshold (float): Detection threshold.

    Returns:
        Dict: Prediction dictionary.
    """
    img = PIL.Image.open(image)

    # The detector was trained with a single generic 'Waste' class;
    # inference-time transforms must match training (resize/pad to 512 + normalize).
    class_map = ClassMap(classes=['Waste'])
    transforms = tfms.A.Adapter([
        *tfms.A.resize_and_pad(512),
        tfms.A.Normalize()
    ])

    # Single-image end-to-end prediction; keep the raw image in the result
    # (return_img=True) so downstream classification can crop from it.
    pred_dict = MODEL_TYPE.end2end_detect(
        img,
        transforms,
        det_model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        return_as_pil_img=False,
        return_img=True,
        display_bbox=False,
        display_score=False,
        display_label=False,
    )
    return pred_dict


def prepare_prediction(pred_dict: Dict,
                       nms_threshold: float) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Get the predictions in a right format.

    Args:
        pred_dict (Dict): Prediction dictionary.
        nms_threshold (float): IoU threshold for the NMS postprocess.

    Returns:
        Tuple: Tuple containing the following:
            - (torch.Tensor): Bounding boxes, shape (n, 4).
            - (np.ndarray): Image data.
    """
    image = np.array(pred_dict['img'])

    raw_boxes = pred_dict['detection']['bboxes']
    # Guard: with zero detections, torch.stack([]) would raise.
    # Return an empty (0, 4) tensor so callers can iterate safely.
    if not raw_boxes:
        return torch.empty((0, 4)), image

    # Convert each box to a tensor and stack them into a single (n, 4) tensor.
    boxes = torch.stack([box.to_tensor() for box in raw_boxes])
    scores = torch.as_tensor(pred_dict['detection']['scores'])
    labels = torch.as_tensor(pred_dict['detection']['label_ids'])

    # Per-class NMS: suppress overlapping boxes, then keep only survivors.
    keep = torchvision.ops.batched_nms(boxes, scores, labels, nms_threshold)
    return boxes[keep, :], image


def predict_class(classifier: torch.nn.Module,
                  image: np.ndarray,
                  bboxes: torch.Tensor) -> np.ndarray:
    """
    Predict the class of each detected object.

    Args:
        classifier (torch.nn.Module): Classifier model.
        image (np.ndarray): Image data.
        bboxes (torch.Tensor): Bounding boxes, shape (n, 4).

    Returns:
        np.ndarray: Array containing the predicted class id for each object.
    """
    # Guard: np.concatenate([]) raises on an empty list.
    if len(bboxes) == 0:
        return np.empty((0,), dtype=np.int64)

    # Build the PIL image once; it is invariant across crops.
    pil_img = PIL.Image.fromarray(image)

    preds = []
    # Inference only: no_grad avoids building an autograd graph per crop.
    with torch.no_grad():
        for bbox in bboxes:
            # PIL.Image.crop expects a (left, upper, right, lower) tuple of ints.
            box = tuple(int(coord) for coord in bbox)
            cropped_img = np.array(pil_img.crop(box))

            # Apply the classifier's preprocessing, then HWC -> CHW and batch dim.
            tran_image = transform_image(cropped_img, 224)
            tran_image = tran_image.transpose(2, 0, 1)
            tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)

            y_preds = classifier(tran_image)
            preds.append(y_preds.softmax(1).numpy())

    # Concatenate per-crop probability rows and take the argmax class per object.
    return np.concatenate(preds).argmax(1)