File size: 4,613 Bytes
850cda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# GSL

import os
import torch
import numpy as np
from PIL import Image, ImageChops, ImageEnhance
import cv2
from simple_lama_inpainting import SimpleLama
from segment_anything import build_sam, SamPredictor
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict
from huggingface_hub import hf_hub_download

# Run the models on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model and load its weights from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        ckpt_config_filename: ``SLConfig`` config file name inside ``repo_id``.
        device: Stored on the config as ``args.device`` and used as ``map_location``
            when loading the checkpoint.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    # The file-level imports do not provide ``build_model``; import it from the same
    # GroundingDINO package the rest of this file uses so the call below resolves.
    from GroundingDINO.groundingdino.models import build_model

    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(review): newer torch releases default ``torch.load`` to ``weights_only=True``;
    # confirm this checkpoint still loads under the installed torch version.
    checkpoint = torch.load(cache_file, map_location=device)
    # strict=False: keys missing from / unexpected in the checkpoint are not reported.
    model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model

# All models are created once, at import time, and shared by every call to
# gsl_process_image below.

# GroundingDINO: SwinB config/checkpoint pair downloaded from the Hugging Face Hub.
groundingdino_model = load_model_hf(
    repo_id="ShilongLiu/GroundingDINO",
    ckpt_config_filename="GroundingDINO_SwinB.cfg.py",
    filename="groundingdino_swinb_cogcoor.pth",
    device=device,
)

# Segment Anything predictor; the checkpoint path is relative to the working directory.
sam_predictor = SamPredictor(
    build_sam(checkpoint='sam_vit_h_4b8939.pth').to(device)
)

# Inpainting model used to erase the segmented regions.
simple_lama = SimpleLama()

def detect(image, model, text_prompt='insect . flower . cloud', box_threshold=0.15,
           text_threshold=0.15, image_source=None):
    """Run GroundingDINO on ``image`` for the phrases in ``text_prompt``.

    Args:
        image: Model-input image tensor (the second value returned by ``load_image``).
        model: GroundingDINO model.
        text_prompt: Phrases separated by `` . ``, forwarded to ``predict`` as the caption.
        box_threshold: Forwarded to ``predict``.
        text_threshold: Forwarded to ``predict``.
        image_source: Optional original image array (the first value returned by
            ``load_image``). ``annotate`` draws on this array, so it is needed to get
            an annotated frame; the model-input tensor cannot be used for that.

    Returns:
        ``(annotated_frame, boxes, phrases)``. ``annotated_frame`` is the annotated image
        with its channels reversed (BGR to RGB), or ``None`` when ``image_source`` is not
        given. ``boxes`` and ``phrases`` are returned unchanged from ``predict``.
    """
    boxes, logits, phrases = predict(
        image=image,
        model=model,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # Pass the module-level device explicitly. NOTE(review): upstream ``predict``
        # defaults to "cuda", which fails on CPU-only hosts; confirm for the installed
        # GroundingDINO version.
        device=device,
    )
    annotated_frame = None
    if image_source is not None:
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes, phrases

def segment(image, sam_model, boxes):
    """Predict SAM masks for each box on ``image`` and return them on the CPU.

    ``image`` must be a 3-D (H, W, C) array; it is registered with the predictor via
    ``set_image``. ``boxes`` are normalised (cx, cy, w, h) rows. They are converted to
    corners, scaled to pixels, and mapped into SAM's input frame before querying
    ``predict_torch`` with ``multimask_output=True``.
    """
    sam_model.set_image(image)
    height, width, _channels = image.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    corner_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale
    sam_boxes = sam_model.transform.apply_boxes_torch(
        corner_boxes.to(device), image.shape[:2]
    )
    predicted = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=True,
    )
    masks = predicted[0]
    return masks.cpu()

def draw_mask(mask, image, random_color=True):
    """Tint the pixels selected by ``mask`` and alpha-composite them over ``image``.

    Args:
        mask: Binary mask whose last two dimensions are (H, W). Accepts a NumPy array
            or a CPU torch tensor.
        image: Image array accepted by ``Image.fromarray``; presumably uint8 RGB with
            the same H x W as ``mask`` (``alpha_composite`` needs equal sizes).
        random_color: Use a random RGB colour at alpha 0.8 instead of a fixed blue at
            alpha 0.6.

    Returns:
        NumPy array of the composited RGBA image.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # np.asarray lets the function accept both NumPy masks and CPU torch tensors; the
    # product below is then a NumPy array, so no ``.numpy()`` call is needed.
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def dilate_mask(mask, dilate_factor=15):
    """Grow the non-zero region of ``mask`` using a square kernel of ones.

    The mask is cast to uint8 and dilated once with a ``dilate_factor`` x
    ``dilate_factor`` kernel; the dilated uint8 array is returned.
    """
    mask_u8 = mask.astype(np.uint8)
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask_u8, kernel, iterations=1)

def gsl_process_image(local_image_path, *, diff_threshold=7, background_color=(255, 236, 10)):
    """Cut the insects out of an image and place them on a flat background.

    Steps:
        1. GroundingDINO proposes boxes.
        2. Boxes whose phrase contains 'insect' are segmented with SAM.
        3. The union of those masks is dilated, and that region is inpainted with LaMa.
        4. Pixels the inpainting changed by more than ``diff_threshold`` keep their
           original colour; every other pixel gets ``background_color``.

    Args:
        local_image_path: Path handed to ``load_image``.
        diff_threshold: Threshold (0-255) on the grey-level (``'L'``) difference between
            the inpainted and original image. Defaults to the previous hard-coded 7.
        background_color: RGB fill for non-insect pixels. Defaults to the previous
            hard-coded yellow ``(255, 236, 10)``.

    Returns:
        RGB ``PIL.Image`` the size of the input image. When no 'insect' phrase is
        detected, this is the plain background.
    """
    # Load image: ``image_source`` is the original array, ``image`` the model input.
    image_source, image = load_image(local_image_path)
    source_pil = Image.fromarray(image_source)
    background = Image.new('RGB', source_pil.size, background_color)

    # Detect insects (the annotated frame is not needed here).
    _, detected_boxes, phrases = detect(image, model=groundingdino_model)
    indices = [i for i, s in enumerate(phrases) if 'insect' in s]
    if not indices:
        # Nothing to cut out: skip SAM/LaMa instead of running them with no boxes.
        return background

    # Segment insects
    segmented_frame_masks = segment(image_source, sam_predictor, detected_boxes[indices])

    # Combine masks: union of the first multimask output of every box. Unlike a
    # pairwise loop, this also works when exactly one insect was detected.
    final_mask = segmented_frame_masks[:, 0].any(dim=0)

    # Dilate mask
    mask = final_mask.numpy().astype(np.uint8) * 255
    mask = dilate_mask(mask)
    dilated_image_mask_pil = Image.fromarray(mask)

    # Inpainting
    result = simple_lama(image_source, dilated_image_mask_pil)

    # Difference and composite: pixels that inpainting changed noticeably belong to
    # an insect and are copied from the original; the rest come from the background.
    diff = ImageChops.difference(result, source_pil)
    diff_mask = diff.convert('L').point(lambda p: 255 if p > diff_threshold else 0).convert('1')
    return Image.composite(source_pil, background, diff_mask)