import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from uuid import uuid4
import os
from huggingface_hub import upload_folder, login
from PIL import Image as PILImage
from datasets import Dataset, Features, Array2D, Image
import shutil
import time

MODEL = "facebook/sam2-hiera-large"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
DESTINATION_DS = "amaye15/object-segmentation"

# login(os.getenv("TOKEN"))

# Module-level state shared between the prediction, selection, and save steps
IMAGE = None
MASKS = None
MASKED_IMAGES = None
INDEX = None


def prompter(prompts):
    image = np.array(prompts["image"])  # Convert the uploaded image to a numpy array
    points = prompts["points"]  # Point prompts collected by ImagePrompter

    # Run SAM2 inference with multimask_output=True to get three candidate masks
    with torch.inference_mode():
        PREDICTOR.set_image(image)
        input_point = [[point[0], point[1]] for point in points]
        input_label = [1] * len(points)  # Assume all points are foreground
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # Prepare individual images, each with its own red mask overlay
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i + 1}:", mask.shape)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Put the mask in the red channel
        red_mask = PILImage.fromarray(red_mask)

        # Blend the original image with the red mask
        original_image = PILImage.fromarray(image)
        blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
        overlay_images.append(blended_image)

    global IMAGE, MASKS, MASKED_IMAGES
    IMAGE, MASKS = image, masks
    MASKED_IMAGES = [np.array(img) for img in overlay_images]

    return overlay_images[0], overlay_images[1], overlay_images[2], masks


def select_mask(selected_mask_index, mask1, mask2, mask3):
    # Remember which of the three candidate masks the user picked
    masks = [mask1, mask2, mask3]
    global INDEX
    INDEX = selected_mask_index
    return masks[selected_mask_index]


def save_selected_mask(image, mask, output_dir="output"):
    output_dir = os.path.join(os.getcwd(), output_dir)
    os.makedirs(output_dir, exist_ok=True)

    folder_id = str(uuid4())
    folder_path = os.path.join(output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)
    data_path = os.path.join(folder_path, "data.parquet")

    data = {
        "image": IMAGE,
        "masked_image": MASKED_IMAGES[INDEX],
        "mask": MASKS[INDEX].astype(np.int64),  # Cast to match the Array2D dtype below
    }
    features = Features(
        {
            "image": Image(),
            "masked_image": Image(),
            "mask": Array2D(
                dtype="int64", shape=(MASKS[INDEX].shape[0], MASKS[INDEX].shape[1])
            ),
        }
    )

    ds = Dataset.from_list([data], features=features)
    ds.to_parquet(data_path)

    # Push the new sample to the destination dataset repo, then clean up locally
    upload_folder(
        folder_path=output_dir,
        repo_id=DESTINATION_DS,
        repo_type="dataset",
    )
    shutil.rmtree(folder_path)

    iframe_code = "Success - Check out the 'Results' tab."
    return iframe_code

    # time.sleep(5)
    # # Add a random query parameter to force reload
    # random_param = uuid4()
    # iframe_code = f"""
    # # """


# Define the Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Tab("Object Segmentation - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )

        with gr.Row():
            with gr.Column():
                # Input: ImagePrompter
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")

        with gr.Row():
            with gr.Column():
                # Outputs: up to 3 overlay images
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)

        # Radio selector for choosing the correct mask
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )
            # selected_mask_output = gr.Image(show_label=False)

        save_button = gr.Button("Save Selected Mask and Image")
        iframe_display = gr.Markdown()

        # Action triggered by the submit button
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
            show_progress=True,
        )

        # Action triggered by mask selection
        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )

        # Action triggered by the save button
        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=iframe_display,
            show_progress=True,
        )

    with gr.Tab("Results"):
        with gr.Row():
            gr.HTML(
                f"""
                """
            )
            # with gr.Column():
            #     source = gr.Textbox(label="Source Dataset")
            #     source_display = gr.Markdown()
            #     iframe_display = gr.HTML()
            #     source.change(
            #         save_dataset_name,
            #         inputs=(gr.State("source_dataset"), source),
            #         outputs=(source_display, iframe_display),
            #     )
            # with gr.Column():
            #     destination = gr.Textbox(label="Destination Dataset")
            #     destination_display = gr.Markdown()
            #     destination.change(
            #         save_dataset_name,
            #         inputs=(gr.State("destination_dataset"), destination),
            #         outputs=destination_display,
            #     )

# Launch the Gradio app
demo.launch()