from ast import Interactive
from xml.sax.xmlreader import InputSource
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import os
import cv2
import pathlib
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from predictor import Predictor

display_height = 600

# Default (height, width) used to size the scribble state before an image is known
H = 256
W = 256

test_example_dir = pathlib.Path("./test_examples")
test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
default_example = test_examples[0]

exp_dir = pathlib.Path('./checkpoints')
default_model = 'ScribblePrompt-Unet'

model_dict = {
    'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt'
}

# -----------------------------------------------------------------------------
# Model initialization functions
# -----------------------------------------------------------------------------

def load_model(exp_key: str = default_model):
    """
    Build a Predictor from the checkpoint registered under `exp_key`.

    Returns a `(predictor, None)` pair; the second item is written to the
    `img_features` state by the model dropdown callback, which resets it.
    """
    fpath = exp_dir / model_dict.get(exp_key)
    exp = Predictor(fpath)
    return exp, None

# -----------------------------------------------------------------------------
# Vizualization functions
# -----------------------------------------------------------------------------

# RGB values for the named overlay colors accepted by _get_overlay
_OVERLAY_COLORS = {
    "blue": (0, 0, 255),
    "green": (0, 255, 0),
    "red": (255, 0, 0),
    "l_blue": (31, 119, 180),
    "orange": (255, 127, 14),
}


def _get_overlay(img, lay, const_color="l_blue"):
    """
    Helper function for preparing overlay

    Returns a new 3-channel image in which every pixel where `lay` is nonzero
    is painted with `const_color`. The input `img` is never modified: callers
    blend the result with the original image, which only works if the two are
    distinct arrays.

    Raises NotImplementedError for an unknown color name.
    """
    assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)

    if img.ndim == 2:
        img = np.repeat(img[...,None], 3, axis=-1)
    else:
        # Copy so that painting below does not leak into the caller's array
        img = img.copy()

    assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)

    if const_color not in _OVERLAY_COLORS:
        raise NotImplementedError
    color = _OVERLAY_COLORS[const_color]

    rows, cols = np.nonzero(lay)
    for channel in range(img.shape[-1]):
        img[rows, cols, channel] = color[channel]

    return img


def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
    """
    Overlay the ground truth mask and scribbles on the image if provided

    `img` and `mask` must be 2D. `scribbles` is indexed as `scribbles[0]`
    (drawn in green) and `scribbles[1]` (drawn in red). `alpha` is the blend
    weight of the scribble colors; the mask is blended with a fixed weight
    of 0.5, or drawn as a green outline when `contour` is set.
    """
    assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
    output = np.repeat(img[...,None], 3, axis=-1)

    if mask is not None:

        assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)

        if contour:
            contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(output, contours[0], -1, (0, 255, 0), 2)
        else:
            mask_overlay = _get_overlay(img, mask)
            mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
            output = cv2.convertScaleAbs(mask_overlay * mask2 + output * (1 - mask2))

    if scribbles is not None:
        # addWeighted writes the blend back into `output` (last argument)
        pos_scribble_overlay = _get_overlay(output, scribbles[0,...], const_color="green")
        cv2.addWeighted(pos_scribble_overlay, alpha, output, 1 - alpha, 0, output)
        neg_scribble_overlay = _get_overlay(output, scribbles[1,...], const_color="red")
        cv2.addWeighted(neg_scribble_overlay, alpha, output, 1 - alpha, 0, output)

    return output


def viz_pred_mask(img,
                  mask=None,
                  point_coords=None,
                  point_labels=None,
                  bbox_coords=None,
                  seperate_scribble_masks=None,
                  binary=True):
    """
    Visualize image with clicks, scribbles, predicted mask overlaid

    `mask` may be a numpy array or a torch tensor; with `binary` it is
    thresholded at 0.5 before drawing. Clicks are given as (col, row) pairs
    with label 1 drawn green and anything else drawn red. Bounding boxes are
    drawn from consecutive pairs of `bbox_coords`; a trailing unpaired corner
    is drawn as a dot. Returns a uint8 image.
    """
    assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))

    if mask is not None:
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        if binary:
            mask = 1*(mask > 0.5)

    out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)

    height, width = img.shape[:2]
    # Keep at least 1: OpenCV rejects a line thickness of 0 for rectangles
    marker_size = max(1, min(height, width)//100)

    if point_coords is not None:
        for (col, row), label in zip(point_coords, point_labels):
            color = (0,255,0) if label == 1 else (255,0,0)
            cv2.circle(out, (col, row), marker_size, color, -1)

    if bbox_coords is not None:
        for i in range(len(bbox_coords)//2):
            cv2.rectangle(out, tuple(bbox_coords[2*i]), tuple(bbox_coords[2*i+1]), (255,165,0), marker_size)
        if len(bbox_coords) % 2 == 1:
            # First corner of a box that is still waiting for its second click
            cv2.circle(out, tuple(bbox_coords[-1]), marker_size, (255,165,0), -1)

    return out.astype(np.uint8)

# -----------------------------------------------------------------------------
# Collect scribbles
# -----------------------------------------------------------------------------

def _first_scribble_layer(scribble_img):
    """
    Return the first drawing layer of an image editor value, or None when the
    editor value is None or has no layers.
    """
    if scribble_img is None:
        return None
    layers = scribble_img.get('layers')
    return layers[0] if layers else None


def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
    """
    Record scribbles

    Reads the first editor layer and rebuilds the scribble masks from it:
    pixels whose green channel exceeds 128 are positive, pixels whose red
    channel exceeds 128 are negative. When there is no layer to read, the
    incoming state is returned unchanged.
    """
    assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))

    # Only use first layer
    color_mask = _first_scribble_layer(scribble_img)

    if color_mask is not None:
        positive_scribbles = 1.0*(color_mask[...,1] > 128)
        negative_scribbles = 1.0*(color_mask[...,0] > 128)

        seperate_scribble_masks = np.stack([positive_scribbles, negative_scribbles], axis=0)
        last_scribble_mask = None

    return seperate_scribble_masks, last_scribble_mask


def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                    low_res_mask, img_features, multimask_mode):
    """
    Make predictions

    Packs the image and all prompts into tensors on `device` and calls
    `predictor.predict`. A box prompt is only sent when exactly two corner
    clicks are recorded. Returns `(mask, img_features, low_res_mask)` as
    produced by the predictor.
    """
    box = None
    if len(bbox_coords) == 1:
        # Surface the hint to the user without aborting the prediction; the
        # prediction simply runs without a box prompt.
        gr.Warning("Please click a second time to define the bounding box")
    elif len(bbox_coords) == 2:
        box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device) # B x n x 4

    if seperate_scribble_masks is not None:
        scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)
    else:
        scribble = None

    prompts = dict(
        img=torch.from_numpy(input_img)[None,None,...].to(device)/255,
        point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None,
        point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None,
        scribble=scribble,
        mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
        box=box,
    )

    mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)

    return mask, img_features, low_res_mask


def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                        scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
    """
    Record new scribbles, run the model and rebuild every visualization.

    Returns, in order: the click-canvas image, the image editor value (a dict
    with "background", "layers" and "composite" RGBA arrays), the gallery
    items (overlay, input image, mask image), the new `best_mask`,
    `low_res_mask` and `img_features`, and the updated scribble state.
    """
    # Record any new scribbles
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img
    )

    # Make prediction
    best_mask, img_features, low_res_mask = get_predictions(
        predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode
    )

    # Update input visualizations
    mask_to_viz = best_mask.detach().cpu().numpy()
    click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)

    image_shape = input_img.shape[:2]
    empty_channel = np.zeros(image_shape, dtype=np.uint8)
    full_channel = np.full(image_shape, 255, dtype=np.uint8)

    # Background for the editor: prediction and clicks, but no scribbles
    # (those live in the editable layer)
    bg = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)

    # Alpha channel of the scribble layer: opaque wherever the user has drawn.
    # Cast to uint8 so all four stacked channels share one image dtype.
    old_scribbles = _first_scribble_layer(scribble_img)
    if old_scribbles is None:
        scribble_mask = empty_channel
    else:
        scribble_mask = (255*(old_scribbles > 0).any(-1)).astype(np.uint8)

    scribble_input_viz = {
        "background": np.stack([bg[...,i] for i in range(3)]+[full_channel], axis=-1),
        "layers": [np.stack([
            (255*seperate_scribble_masks[1]).astype(np.uint8),
            (255*seperate_scribble_masks[0]).astype(np.uint8),
            empty_channel,
            scribble_mask
        ], axis=-1)],
        "composite": np.stack([click_input_viz[...,i] for i in range(3)]+[empty_channel], axis=-1),
    }

    mask_rgb = mask_to_viz[...,None].repeat(axis=2, repeats=3)
    if binary_checkbox:
        mask_img = (255*(mask_rgb > 0.5)).astype(np.uint8)
    else:
        mask_img = mask_rgb

    out_viz = [
        viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
        input_img,
        mask_img,
    ]

    return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask


def _update_after_click(run_prediction, predictor, input_img, output_img, click_coords, click_labels, bbox_coords,
                        brush_label, scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
    """
    Shared tail of the click handlers: either re-run the model or only redraw
    the inputs on top of the previous prediction, then return the values for
    the handlers' eleven outputs.
    """
    if run_prediction:
        (click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features,
         seperate_scribble_masks, last_scribble_mask) = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
    else:
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        scribble_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
        )
        # Don't update output image if waiting for additional bounding box click
        output_viz = output_img

    return (click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features,
            click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask)


def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
                      click_coords, click_labels, bbox_coords,
                      seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                      output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
    """
    Record user click and update the prediction

    The click position (`evt.index`) is stored as a box corner when
    `bbox_label` is set, otherwise as a positive/negative point according to
    `brush_label`. Raises TypeError for an unknown brush label.
    """
    # Record click coordinates
    if bbox_label:
        bbox_coords.append(evt.index)
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        click_coords.append(evt.index)
        click_labels.append(1 if brush_label=='Positive (green)' else 0)
    else:
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click
    run_prediction = (len(bbox_coords) % 2 == 0) and autopredict_checkbox

    return _update_after_click(
        run_prediction, predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
        scribble_img, seperate_scribble_masks, last_scribble_mask,
        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
    )


def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
               seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
               output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
    """
    Remove last click and then update the prediction

    Pops the last box corner when `bbox_label` is set, otherwise the last
    point click and its label. Does nothing to the lists when they are
    already empty. Raises TypeError for an unknown brush label.
    """
    if bbox_label:
        if len(bbox_coords) > 0:
            bbox_coords.pop()
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        if len(click_coords) > 0:
            click_coords.pop()
            click_labels.pop()
    else:
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click
    run_prediction = (len(bbox_coords)==0 or len(bbox_coords)==2) and autopredict_checkbox

    return _update_after_click(
        run_prediction, predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
        scribble_img, seperate_scribble_masks, last_scribble_mask,
        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
    )

# --------------------------------------------------

with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:

    # State variables
    seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
    last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))

    click_coords = gr.State([])
    click_labels = gr.State([])
    bbox_coords = gr.State([])

    # Load default model
    predictor = gr.State(load_model()[0])
    img_features = gr.State(None) # For SAM models
    best_mask = gr.State(None)
    low_res_mask = gr.State(None)

    gr.HTML("""\

ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image

ScribblePrompt is an interactive segmentation tool designed to help users segment new structures in medical images using scribbles, clicks and bounding boxes. [paper | website | code]

""")

    with gr.Accordion("Open for instructions!", open=False):
        gr.Markdown(
            """
            * Select an input image from the examples below or upload your own image through the 'Input Image' tab.
            * Use the 'Scribbles' tab to draw positive or negative scribbles.
                - Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size
                - Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
            * Use the 'Clicks/Boxes' tab to draw positive or negative clicks and bounding boxes by placing two clicks.
            * The 'Output' tab will show the model's prediction based on your current inputs and the previous prediction.
            * The 'Clear Input Mask' button will clear the latest prediction (which is used as an input to the model).
            * The 'Clear All Inputs' button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
            """
        )

    # Interface ------------------------------------

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Model",
            choices = list(model_dict.keys()),
            value=default_model,
            multiselect=False,
            interactive=False,
            visible=False
        )

    with gr.Row():
        with gr.Column(scale=1):
            brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
                                   value="Positive (green)", label="Scribble/Click Label")
            bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
        with gr.Column(scale=1):
            binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
            autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
            with gr.Accordion("Troubleshooting tips", open=False):
                gr.Markdown("If you encounter an error try clicking 'Clear All Inputs'.")
            multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)

    with gr.Row():

        green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=3)
        red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=3)

        with gr.Column(scale=1):
            with gr.Tab("Scribbles"):
                scribble_img = gr.ImageEditor(
                    label="Input",
                    image_mode="RGB",
                    brush=green_brush,
                    type='numpy',
                    value=default_example,
                    transforms=(),
                    sources=(),
                    container=True,
                    show_download_button=True,
                    height=display_height+60
                )

            with gr.Tab("Clicks/Boxes") as click_tab:
                click_img = gr.Image(
                    label="Input",
                    type='numpy',
                    value=default_example,
                    show_download_button=True,
                    container=True,
                    height=display_height
                )
                with gr.Row():
                    undo_click_button = gr.Button("Undo Last Click")
                    clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")

            with gr.Tab("Input Image"):
                input_img = gr.Image(
                    label="Input",
                    image_mode="L",
                    value=default_example,
                    show_download_button=True,
                    container=True,
                    height=display_height
                )
                gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")

        with gr.Column(scale=1):
            with gr.Tab("Output"):
                output_img = gr.Gallery(
                    label='Output',
                    columns=1,
                    elem_id="gallery",
                    preview=True,
                    object_fit="scale-down",
                    height=display_height+60,
                    container=True
                )

            submit_button = gr.Button("Refresh Prediction", variant='primary')
            clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
            clear_mask_button = gr.Button("Clear Input Mask")

    # ----------------------------------------------
    # Loading Models
    # ----------------------------------------------

    model_dropdown.change(fn=load_model,
                          inputs=[model_dropdown],
                          outputs=[predictor, img_features]
    )

    # ----------------------------------------------
    # Loading Examples
    # ----------------------------------------------

    gr.Examples(examples=test_examples,
                inputs=[input_img],
                examples_per_page=12,
                label='Examples from datasets unseen during training'
    )

    # with gr.Accordion():
    #     height_scribble = gr.Number(label="Scribble Panel Height",
    #                             value=display_height, Interactive=True)
    #     height_scribble.change(
    #         fn=lambda x: gr.update(height=x),
    #         inputs=[scribble_img],
    #         outputs=[scribble_img],
    #     )

    # When clear clicks button is clicked
    def clear_click_history(input_img):
        """Reset both canvases to the raw image and drop clicks, boxes and the last prediction."""
        return input_img, input_img, [], [], [], None, None

    clear_click_button.click(clear_click_history,
                             inputs=[input_img],
                             outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])

    # When clear all button is clicked
    def clear_all_history(input_img):
        """
        Reset every input and cached result. The scribble state is re-created
        with the spatial size of `input_img`, or (H, W) when no image is loaded.
        """
        if input_img is not None:
            input_shape = input_img.shape[:2]
        else:
            input_shape = (H, W)
        return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None

    # def clear_history_and_pad_input(input_img):
    #     if input_img is not None:
    #         h,w = input_img.shape[:2]
    #         if h != w:
    #             # Pad to square
    #             pad = abs(h-w)
    #             if h > w:
    #                 padding = [(0,0), (math.ceil(pad/2),math.floor(pad/2))]
    #             else:
    #                 padding = [(math.ceil(pad/2),math.floor(pad/2)), (0,0)]
    #             input_img = np.pad(input_img, padding, mode='constant', constant_values=0)
    #     return clear_all_history(input_img)

    input_img.change(clear_all_history,
                     inputs=[input_img],
                     outputs=[click_img, scribble_img,
                              output_img, click_coords, click_labels, bbox_coords,
                              seperate_scribble_masks, last_scribble_mask,
                              best_mask, low_res_mask, img_features
                     ])

    clear_all_button.click(clear_all_history,
                           inputs=[input_img],
                           outputs=[click_img, scribble_img,
                                    output_img, click_coords, click_labels, bbox_coords,
                                    seperate_scribble_masks, last_scribble_mask,
                                    best_mask, low_res_mask, img_features
                           ])

    # clear previous prediction mask
    def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
        """Drop the stored prediction and redraw both canvases without a mask."""
        click_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
        )
        scribble_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, None
        )

        return None, None, click_input_viz, scribble_input_viz

    clear_mask_button.click(
        clear_best_mask,
        inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
        outputs=[best_mask, low_res_mask, click_img, scribble_img],
    )

    # ----------------------------------------------
    # Clicks
    # ----------------------------------------------

    click_img.select(get_select_coords,
                     inputs=[
                         predictor,
                         input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
                         seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                         output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                     ],
                     outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                              click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                     api_name = "get_select_coords"
    )

    submit_button.click(fn=refresh_predictions,
                        inputs=[
                            predictor, input_img, output_img,
                            click_coords, click_labels, bbox_coords, brush_label,
                            scribble_img, seperate_scribble_masks, last_scribble_mask,
                            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
                        ],
                        outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                 seperate_scribble_masks, last_scribble_mask],
                        api_name="refresh_predictions"
    )

    undo_click_button.click(fn=undo_click,
                            inputs=[
                                predictor,
                                input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
                                click_labels, bbox_coords,
                                seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                                output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                            ],
                            outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                     click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                            api_name="undo_click"
    )

    def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                         binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask):
        """
        Draw scribbles in the click canvas
        """
        seperate_scribble_masks, last_scribble_mask = get_scribbles(
            seperate_scribble_masks, last_scribble_mask, scribble_img
        )
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        return click_input_viz, seperate_scribble_masks, last_scribble_mask

    click_tab.select(fn=update_click_img,
                     inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                             binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
                     outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
                     api_name="update_click_img"
    )

    # ----------------------------------------------
    # Scribbles
    # ----------------------------------------------

    def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
        """
        Switch the editor brush to match the selected label.

        The scribble state is passed through unchanged. Raises TypeError for
        an unknown label.
        """
        if label == "Negative (red)":
            brush_update = gr.update(brush=red_brush)
        elif label == "Positive (green)":
            brush_update = gr.update(brush=green_brush)
        else:
            raise TypeError("Invalid brush color")

        return seperate_scribble_masks, last_scribble_mask, brush_update

    brush_label.change(fn=change_brush_color,
                       inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
                       outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
                       api_name="change_brush_color"
    )


if __name__ == "__main__":

    demo.queue(api_open=False).launch(show_api=False)