File size: 6,390 Bytes
0dd537b
 
c4af616
0dd537b
 
 
 
 
 
 
c4af616
0dd537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4af616
 
 
 
0dd537b
 
 
 
 
 
 
 
 
38277a1
0dd537b
 
38277a1
 
c4af616
0dd537b
 
38277a1
 
0dd537b
c4af616
 
38277a1
c4af616
 
 
 
 
 
7ac4295
c4af616
 
 
0dd537b
38277a1
 
7ac4295
38277a1
0dd537b
 
38277a1
0dd537b
 
 
 
 
 
 
 
 
 
 
50a5f00
0dd537b
 
 
 
 
 
 
50a5f00
0dd537b
50a5f00
 
 
c4af616
50a5f00
cba1a87
0dd537b
c4af616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38277a1
 
 
 
 
 
c4af616
 
 
0dd537b
c4af616
 
7ac4295
c4af616
7ac4295
 
c4af616
7ac4295
c4af616
7ac4295
 
0dd537b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import app_configs as configs
from feedback import Feedback
import service
import gradio as gr
import numpy as np
import cv2 
from PIL import Image
import logging 
from huggingface_hub import hf_hub_download
import torch

# Logging: configure the root logger at INFO level and keep a handle to it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Shared SAM model handle; stays None until load_sam_instance() builds it.
sam = None

# Colour triples used for the point markers drawn on the input image.
red = (255, 0, 0)
blue = (0, 0, 255)

def load_sam_instance():
    """Return the shared SAM model, building it on the first call.

    On the first call this shows a toast, resolves the checkpoint (the local
    file at ``configs.model_ckpt_path`` when it exists, otherwise a download
    of that filename from the ``ybelkada/segment-anything`` Hub repo), picks
    a device, and stores the object returned by ``service.get_sam`` in the
    module-level ``sam``. Later calls return that cached object.
    """
    global sam
    if sam is not None:
        return sam

    gr.Info('Initialising SAM, hang in there...')

    ckpt_name = configs.model_ckpt_path
    if os.path.exists(ckpt_name):
        weights_path = ckpt_name
    else:
        weights_path = hf_hub_download("ybelkada/segment-anything", ckpt_name)

    # No device configured: use CUDA when torch reports it, otherwise CPU.
    target_device = configs.device
    if target_device is None:
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    sam = service.get_sam(configs.model_type, weights_path, target_device)
    return sam

# Top-level Gradio container; every component and event below is declared
# inside its context.
block = gr.Blocks()
with block:
    # states
    # Factories for the per-point state: each call returns a fresh empty list.
    # They are also called by the reset and upload handlers to clear the state.
    def point_coords_empty():
        return []
    def point_labels_empty():
        return []
    # Hidden image component: written by the upload handler and read by the
    # reset, run and cutout handlers as the unmarked source image.
    raw_image = gr.Image(type='pil', visible=False)
    # Clicked [x, y] positions and the radio value chosen for each click.
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    # Masks produced by the last run; no initial value.
    masks = gr.State()
    # Indices of the masks the user has selected in the segment view.
    cutout_idx = gr.State(set())
    # Feedback recorder created through a factory callable.
    feedback = gr.State(lambda : Feedback())

    # UI
    with gr.Column():
        # Top row: input image, annotated segments, gallery of cutouts.
        with gr.Row():
            input_image = gr.Image(label='Input', height=512, type='pil')
            masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
            cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
        # Bottom row: point-label / run controls on the left, feedback on the right.
        with gr.Row():
            with gr.Column(scale=1):
                # Value is stored per click and passed on as the point label;
                # presumably 1 = include / 0 = exclude for SAM -- TODO confirm.
                point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
                reset_btn = gr.Button('Reset')
                run_btn = gr.Button('Run', variant = 'primary')
            with gr.Column(scale=2):
                with gr.Accordion('Provide Feedback'):
                    feedback_textbox = gr.Textbox(lines=3, show_label=False, info="Comments (Leave blank to vote without any comments)")
                    with gr.Row():
                        upvote_button = gr.Button('Upvote')
                        downvote_button = gr.Button('Downvote')
    # components
    # Set of components passed as `inputs` to the run, cutout and vote
    # handlers; those handlers index their single argument by component,
    # e.g. inputs[raw_image].
    components = {
        point_coords, point_labels, raw_image, masks, cutout_idx,
        feedback, upvote_button, downvote_button, feedback_textbox,
        input_image, point_label_radio, reset_btn, run_btn, masks_annotated_image}
    # event - init coords
    def on_reset_btn_click(raw_image):
        """Restore the unmarked image and clear every point, segment and cutout.

        Returns one value per registered output, in order: the stored raw
        image for the input view, empty coordinate and label lists, ``None``
        to clear the segment view, and ``[]`` to empty the cutout gallery.
        """
        return raw_image, point_coords_empty(), point_labels_empty(), None, []
    # The handler has always returned five values; the last two were silently
    # dropped because only three outputs were registered. List all five so
    # the segment view and gallery are actually reset.
    reset_btn.click(
        on_reset_btn_click,
        [raw_image],
        [input_image, point_coords, point_labels, masks_annotated_image, cutout_galary],
        queue=False)

    def on_input_image_upload(input_image):
        """Store a newly uploaded image and clear state tied to the old one.

        Returns one value per registered output, in order: the uploaded image
        (kept in the hidden ``raw_image``), empty coordinate and label lists,
        and ``None`` to clear the segment view.
        """
        return input_image, point_coords_empty(), point_labels_empty(), None
    # The fourth return value used to be dropped because only three outputs
    # were registered. Routing it to the segment view stops segments computed
    # for the previous image from staying clickable against the new one.
    input_image.upload(
        on_input_image_upload,
        [input_image],
        [raw_image, point_coords, point_labels, masks_annotated_image],
        queue=False)

    # event - set coords
    def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
        """Record a clicked point and draw its marker on the input image.

        The click position comes from ``evt.index``. A filled circle of
        radius 5 is drawn at that position: red when the radio value is 0,
        blue otherwise. The position and radio value are appended to the
        state lists, which are returned together with the marked image.
        """
        click_x, click_y = evt.index
        if point_label_radio == 0:
            marker_colour = red
        else:
            marker_colour = blue

        # Draw on an array copy of the image, then convert back to PIL.
        canvas = np.array(input_image)
        cv2.circle(canvas, (click_x, click_y), 5, marker_colour, -1)
        marked_image = Image.fromarray(canvas)

        point_coords.append([click_x, click_y])
        point_labels.append(point_label_radio)
        return marked_image, point_coords, point_labels
    input_image.select(
        on_input_image_select,
        [input_image, point_coords, point_labels, point_label_radio],
        [input_image, point_coords, point_labels],
        queue=False)

    # event - inference
    def on_run_btn_click(inputs):
        """Run SAM on the stored image and publish the resulting masks.

        ``inputs`` is indexed by component. With at least one recorded point
        the prediction is conditioned on the points and their labels; with
        none, the whole image is segmented if ``configs.enable_segment_all``
        is set. The inference is recorded on the session's feedback object.

        Returns a dict keyed by output component: the annotated image, the
        raw masks, an empty selection set, and the feedback object.

        Raises:
            gr.Error: if no image has been uploaded, or if no points are set
                while segment-all is disabled.
        """
        image = inputs[raw_image]
        coords = inputs[point_coords]
        labels = inputs[point_labels]

        # Validate the request before loading the model, so an invalid
        # request does not trigger model initialisation.
        if image is None:
            raise gr.Error('Upload an image before running')
        if not coords and not configs.enable_segment_all:
            raise gr.Error('Segment-all disabled, set point label(s) before running')

        sam = load_sam_instance()
        if coords:
            generated_masks, _ = service.predict_conditioned(
                sam,
                image,
                point_coords=np.array(coords),
                point_labels=np.array(labels))
        else:
            generated_masks, _ = service.predict_all(sam, image)

        # AnnotatedImage value: (base image, [(mask, label), ...]).
        annotated = (image, [(mask, f'Mask {i}') for i, mask in enumerate(generated_masks)])
        inputs[feedback].save_inference(
            pt_coords=coords,
            pt_labels=labels,
            image=image,
            mask=generated_masks,
            )
        return {
            masks_annotated_image: annotated,
            masks: generated_masks,
            cutout_idx: set(),
            feedback: inputs[feedback],
            }
    run_btn.click(on_run_btn_click, components, [masks_annotated_image, masks, cutout_idx, feedback], queue=True)

    # event - get cutout 
    def on_masks_annotated_image_select(inputs, evt:gr.SelectData):
        """Add the clicked segment to the selection and rebuild the cutouts.

        ``evt.index`` is added to the selection set, a cutout of the raw
        image is produced for every selected mask and then cropped with
        ``service.crop_empty``, and the click is recorded on the feedback
        object. Returns the selection set, the cropped cutouts and the
        feedback object, in that order.
        """
        selected = inputs[cutout_idx]
        selected.add(evt.index)

        raw_cutouts = []
        for idx in selected:
            raw_cutouts.append(service.cutout(inputs[raw_image], inputs[masks][idx]))
        tight_cutouts = list(map(service.crop_empty, raw_cutouts))

        session_feedback = inputs[feedback]
        session_feedback.save_feedback(cutout_idx=evt.index)
        return selected, tight_cutouts, session_feedback
    masks_annotated_image.select(
        on_masks_annotated_image_select,
        components,
        [cutout_idx, cutout_galary, feedback],
        queue=False)

    # event - feedback
    def on_upvote_button_click(inputs):
        """Record an upvote (like=1) with the optional comment, then clear it.

        Returns a dict that writes the feedback object back to its state and
        sets the comment textbox to ``None``.
        """
        session_feedback = inputs[feedback]
        comment = inputs[feedback_textbox]
        session_feedback.save_feedback(like=1, feedback_str=comment)
        # NOTE(review): this event is registered with queue=False; confirm the
        # toast is still shown to the user on the Gradio version in use.
        gr.Info('Thanks for your feedback')
        return {
            feedback: session_feedback,
            feedback_textbox: None,
        }
    upvote_button.click(
        on_upvote_button_click,
        components,
        [feedback, feedback_textbox],
        queue=False)
    def on_downvote_button_click(inputs):
        """Record a downvote (like=-1) with the optional comment, then clear it.

        Returns a dict that writes the feedback object back to its state and
        sets the comment textbox to ``None``.
        """
        session_feedback = inputs[feedback]
        comment = inputs[feedback_textbox]
        session_feedback.save_feedback(like=-1, feedback_str=comment)
        # NOTE(review): this event is registered with queue=False; confirm the
        # toast is still shown to the user on the Gradio version in use.
        gr.Info('Thanks for your feedback')
        return {
            feedback: session_feedback,
            feedback_textbox: None,
        }
    downvote_button.click(
        on_downvote_button_click,
        components,
        [feedback, feedback_textbox],
        queue=False)
# Script entry point: enable the request queue, then start the app.
if __name__ == '__main__':
    block.queue().launch()