# NOTE: Hugging Face file-viewer scrape residue removed here (build-status lines,
# commit-hash blame gutter, and line-number gutter were not part of the source).
import os
import app_configs as configs
from feedback import Feedback
import service
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import logging
from huggingface_hub import hf_hub_download
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# Lazily-initialised SAM model; populated once by load_sam_instance() and
# cached at module level so the checkpoint is only loaded on first use.
sam = None #service.get_sam(configs.model_type, configs.model_ckpt_path, configs.device)
# RGB marker colours for clicked prompt points (drawn via cv2.circle on an
# RGB array from PIL): red = negative point (label 0), blue = positive (label 1).
red = (255,0,0)
blue = (0,0,255)
def load_sam_instance():
    """Return the module-level SAM model, creating it on first call.

    The checkpoint is downloaded from the Hugging Face Hub when it is not
    available locally. When ``configs.device`` is unset, CUDA is used if
    available, otherwise CPU.
    """
    global sam
    if sam is not None:
        return sam
    gr.Info('Initialising SAM, hang in there...')
    # Prefer a local checkpoint; fall back to fetching it from the Hub.
    if os.path.exists(configs.model_ckpt_path):
        ckpt = configs.model_ckpt_path
    else:
        ckpt = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
    device = configs.device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sam = service.get_sam(configs.model_type, ckpt, device)
    return sam
block = gr.Blocks()
with block:
    # ---- per-session state -------------------------------------------------
    def point_coords_empty():
        # Fresh list per session: clicked (x, y) prompt points.
        return []

    def point_labels_empty():
        # Fresh list per session: 1/0 labels matching point_coords.
        return []

    raw_image = gr.Image(type='pil', visible=False)  # unmarked original upload
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State()            # masks produced by the most recent run
    cutout_idx = gr.State(set())  # indices of masks already cut out
    feedback = gr.State(lambda: Feedback())  # per-session feedback recorder
# UI
with gr.Column():
with gr.Row():
input_image = gr.Image(label='Input', height=512, type='pil')
masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
with gr.Row():
with gr.Column(scale=1):
point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
reset_btn = gr.Button('Reset')
run_btn = gr.Button('Run', variant = 'primary')
with gr.Column(scale=2):
with gr.Accordion('Provide Feedback'):
feedback_textbox = gr.Textbox(lines=3, show_label=False, info="Comments (Leave blank to vote without any comments)")
with gr.Row():
upvote_button = gr.Button('Upvote')
downvote_button = gr.Button('Downvote')
# components
components = {
point_coords, point_labels, raw_image, masks, cutout_idx,
feedback, upvote_button, downvote_button, feedback_textbox,
input_image, point_label_radio, reset_btn, run_btn, masks_annotated_image}
# event - init coords
def on_reset_btn_click(raw_image):
return raw_image, point_coords_empty(), point_labels_empty(), None, []
reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
def on_input_image_upload(input_image):
return input_image, point_coords_empty(), point_labels_empty(), None
input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
# event - set coords
def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
x, y = evt.index
color = red if point_label_radio == 0 else blue
img = np.array(input_image)
cv2.circle(img, (x, y), 5, color, -1)
img = Image.fromarray(img)
point_coords.append([x,y])
point_labels.append(point_label_radio)
return img, point_coords, point_labels
input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
# event - inference
def on_run_btn_click(inputs):
sam = load_sam_instance()
image = inputs[raw_image]
if len(inputs[point_coords]) == 0:
if configs.enable_segment_all:
generated_masks, _ = service.predict_all(sam, image)
else:
raise gr.Error('Segment-all disabled, set point label(s) before running')
else:
generated_masks, _ = service.predict_conditioned(sam,
image,
point_coords=np.array(inputs[point_coords]),
point_labels=np.array(inputs[point_labels]))
annotated = (image, [(generated_masks[i], f'Mask {i}') for i in range(len(generated_masks))])
inputs[feedback].save_inference(
pt_coords=inputs[point_coords],
pt_labels=inputs[point_labels],
image=inputs[raw_image],
mask=generated_masks,
)
return {
masks_annotated_image:annotated,
masks: generated_masks,
cutout_idx: set(),
feedback: inputs[feedback],
}
run_btn.click(on_run_btn_click, components, [masks_annotated_image, masks, cutout_idx, feedback], queue=True)
# event - get cutout
def on_masks_annotated_image_select(inputs, evt:gr.SelectData):
inputs[cutout_idx].add(evt.index)
cutouts = [service.cutout(inputs[raw_image], inputs[masks][idx]) for idx in list(inputs[cutout_idx])]
tight_cutouts = [service.crop_empty(cutout) for cutout in cutouts]
inputs[feedback].save_feedback(cutout_idx=evt.index)
return inputs[cutout_idx], tight_cutouts, inputs[feedback]
masks_annotated_image.select(on_masks_annotated_image_select, components, [cutout_idx, cutout_galary, feedback], queue=False)
# event - feedback
def on_upvote_button_click(inputs):
inputs[feedback].save_feedback(like=1, feedback_str=inputs[feedback_textbox])
gr.Info('Thanks for your feedback')
return {feedback:inputs[feedback],feedback_textbox:None}
upvote_button.click(on_upvote_button_click,components,[feedback, feedback_textbox], queue=False)
def on_downvote_button_click(inputs):
inputs[feedback].save_feedback(like=-1, feedback_str=inputs[feedback_textbox])
gr.Info('Thanks for your feedback')
return {feedback:inputs[feedback],feedback_textbox:None}
downvote_button.click(on_downvote_button_click,components,[feedback, feedback_textbox], queue=False)
# Serve only when executed as a script (not on import).
if __name__ == '__main__':
    block.queue()   # queuing is required for gr.Info and long-running inference
    block.launch()