import cv2
import numpy as np
import torch
from mmdet.registry import VISUALIZERS


class SegMaskHelper:
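    """Post-processing helpers for MMDetection instance segmentation results.

    Provides utilities to resize/pad predicted masks to the input image,
    crop masked regions onto a white background, render predictions with the
    model's visualizer, and translate crop-local polygon coordinates back
    into full-image coordinates.
    """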
    def __init__(self):
        pass

    # Resizes and pads the predicted masks to the image size
    # (works around masks coming back smaller than the input image;
    # possibly a bug in the RTMDet config?)
    # @timer_func
    def align_masks_with_image(self, result, img):
        masks = list()
        img_h, img_w = img.shape[:2]

        for mask in result.pred_instances.masks:
            numpy_mask = mask.cpu().numpy()

            # Resize the mask to the image size; nearest-neighbour
            # interpolation keeps it binary
            resized_mask = cv2.resize(
                numpy_mask.astype(np.uint8),
                (img_w, img_h),
                interpolation=cv2.INTER_NEAREST,
            )

            # Pad the mask to match the size of the image
            padded_mask = np.zeros((img_h, img_w), dtype=np.uint8)
            padded_mask[: resized_mask.shape[0], : resized_mask.shape[1]] = resized_mask
            masks.append(torch.from_numpy(padded_mask))

        # Replace the predicted masks with the image-aligned ones
        result.pred_instances.masks = torch.stack(masks)

        return result

    # Crops each masked region out of the image, places it on a white
    # background, and returns the crops together with simplified mask polygons
    # @timer_func
    def crop_masks(self, result, img):
        cropped_imgs = list()
        polygons = list()

        for mask in result.pred_instances.masks:
            mask_u8 = mask.cpu().numpy().astype(np.uint8)

            # Keep only the largest external contour of the mask
            contours, _ = cv2.findContours(
                mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            largest_contour = max(contours, key=cv2.contourArea)

            # Simplify the contour into a polygon (tolerance ~0.3% of the perimeter)
            epsilon = 0.003 * cv2.arcLength(largest_contour, True)
            approx_poly = cv2.approxPolyDP(largest_contour, epsilon, True)
            polygons.append(np.squeeze(approx_poly).tolist())

            x, y, w, h = cv2.boundingRect(largest_contour)
            cropped_mask = mask_u8[y : y + h, x : x + w]

            # Crop the masked region and composite it onto a white background
            masked_region = img[y : y + h, x : x + w]
            white_background = np.full_like(masked_region, 255)
            masked_region_on_white = cv2.bitwise_and(
                white_background, masked_region, mask=cropped_mask
            )

            # Zero out the white background inside the mask, then add the
            # masked pixels back so everything outside the mask stays white
            cv2.bitwise_not(white_background, white_background, mask=cropped_mask)
            res = white_background + masked_region_on_white

            cropped_imgs.append(res)

        return cropped_imgs, polygons

    # Renders the predicted instances on top of the image using the model's
    # visualizer config and returns the drawn image
    def visualize_result(self, result, img, model_visualizer):
        visualizer = VISUALIZERS.build(model_visualizer)
        visualizer.add_datasample("result", img, data_sample=result, draw_gt=False)

        return visualizer.get_image()

    # Translates line polygons from region-crop coordinates back into
    # full-image coordinates using the region mask's bounding-box offset
    def _translate_line_coords(self, region_mask, line_polygons):
        region_mask = region_mask.cpu().numpy()
        region_mask_binary = (region_mask > 0).astype(np.uint8)

        x, y, _, _ = cv2.boundingRect(region_mask_binary)
        translated_line_polygons = [
            [[a + x, b + y] for [a, b] in poly] for poly in line_polygons
        ]

        return translated_line_polygons


if __name__ == "__main__":
    pass
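
    # Minimal usage sketch: the config, checkpoint, and image paths below are
    # placeholder assumptions and must point to a real RTMDet-Ins setup
    # before this will run.
    from mmdet.apis import inference_detector, init_detector

    model = init_detector(
        "configs/rtmdet-ins_s.py",  # hypothetical config path
        "checkpoints/rtmdet-ins_s.pth",  # hypothetical checkpoint path
        device="cuda:0",  # or "cpu"
    )
    image = cv2.imread("demo.jpg")  # hypothetical input image
    result = inference_detector(model, image)

    helper = SegMaskHelper()
    result = helper.align_masks_with_image(result, image)
    cropped_imgs, polygons = helper.crop_masks(result, image)

    visualized = helper.visualize_result(result, image, model.cfg.visualizer)
    cv2.imwrite("visualized.jpg", visualized)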