import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from uuid import uuid4
import os
from huggingface_hub import upload_folder, login
from PIL import Image as PILImage
from datasets import Dataset, Features, Array2D, Image
import shutil
import time
# Model / dataset configuration.
MODEL = "facebook/sam2-hiera-large"
DESTINATION_DS = "amaye15/object-segmentation"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Loaded once at import time and reused by every request.
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# NOTE(review): explicit login is disabled, so uploads rely on whatever
# credentials huggingface_hub discovers on its own — confirm for deployment.
# login(os.getenv("TOKEN"))

# Module-level scratch state: written by prompter / select_mask and read by
# save_selected_mask. These are process globals, so every session served by
# this process shares them.
IMAGE = MASKS = MASKED_IMAGES = INDEX = None
def _as_rgb_array(raw):
    """Return ``raw`` as an (H, W, 3) array, converting via PIL when needed.

    Arrays that already have three channels are returned as a copy,
    unchanged. Anything else (e.g. grayscale or RGBA) is converted to RGB so
    that SAM2 and the red overlay both see three channels.
    """
    image = np.array(raw)
    if image.ndim != 3 or image.shape[-1] != 3:
        image = np.array(PILImage.fromarray(image).convert("RGB"))
    return image


def _red_overlay(image, base, mask):
    """Blend a red rendering of ``mask`` 50/50 over ``base``.

    ``image`` is the RGB array that ``base`` was built from; it only fixes the
    shape/dtype of the red canvas. Mask entries that cast to 1 under
    ``astype(np.uint8)`` become full-intensity red.
    """
    red_mask = np.zeros_like(image)
    red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # red channel only
    return PILImage.blend(base, PILImage.fromarray(red_mask), alpha=0.5)


def prompter(prompts):
    """Run SAM2 on the clicked points and return one red overlay per mask.

    Args:
        prompts: Value of the ``ImagePrompter`` component: a dict with an
            ``"image"`` entry and a ``"points"`` list. The first two values of
            each point entry are used as (x, y) pixel coordinates.

    Returns:
        ``(overlay_1, overlay_2, overlay_3, masks)``. The first three items are
        PIL images for the first three predicted masks. ``masks`` is the raw
        mask array returned by ``PREDICTOR.predict``.

    Raises:
        gr.Error: if no image was uploaded or no point was clicked.

    Side effects:
        Stores the RGB image, the masks and the overlay arrays in the module
        globals ``IMAGE``, ``MASKS`` and ``MASKED_IMAGES``.
    """
    global IMAGE, MASKS, MASKED_IMAGES

    # Fail with a readable message instead of an opaque numpy/torch error.
    if not prompts or prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    points = prompts.get("points") or []
    if not points:
        raise gr.Error("Please click on the image to add at least one point.")

    image = _as_rgb_array(prompts["image"])

    # NOTE(review): only the first two values of each entry are used. If
    # ImagePrompter also reports box prompts in this list, they are treated
    # as a single point — confirm.
    input_point = [[point[0], point[1]] for point in points]
    input_label = [1] * len(points)  # every click is treated as foreground

    # multimask_output=True asks SAM2 for several candidate masks.
    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # The base image is identical for every mask, so build it once.
    base = PILImage.fromarray(image)
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask.shape)
        overlay_images.append(_red_overlay(image, base, mask))

    IMAGE, MASKS = image, masks
    MASKED_IMAGES = [np.array(img) for img in overlay_images]
    return overlay_images[0], overlay_images[1], overlay_images[2], masks
def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Record which of the three candidate masks the user picked.

    Args:
        selected_mask_index: 0-based index from the ``type="index"`` radio,
            or ``None`` when nothing is selected.
        mask1, mask2, mask3: the three overlay previews currently displayed.

    Returns:
        The preview at ``selected_mask_index``, or ``None`` when nothing is
        selected.

    Side effects:
        Stores the index in the module global ``INDEX``.
    """
    global INDEX
    if selected_mask_index is None:
        # Nothing selected: forget any earlier choice so that a later save
        # cannot silently reuse a stale index.
        INDEX = None
        return None
    INDEX = selected_mask_index
    return (mask1, mask2, mask3)[selected_mask_index]
def save_selected_mask(image, mask, output_dir="output"):
    """Write the selected image/mask pair to parquet and upload it to the Hub.

    The data are read from module globals: ``IMAGE``, ``MASKS`` and
    ``MASKED_IMAGES`` (set by ``prompter``) and ``INDEX`` (set by
    ``select_mask``). The ``image`` and ``mask`` arguments exist for the event
    wiring and are not used.

    Args:
        image: unused.
        mask: unused.
        output_dir: local staging directory, resolved against the current
            working directory.

    Returns:
        A Markdown status message for the UI.
    """
    # Snapshot the shared globals once so all reads below are consistent.
    source_image, all_masks, overlays, index = IMAGE, MASKS, MASKED_IMAGES, INDEX
    if source_image is None or all_masks is None or overlays is None:
        return "Nothing to save - submit an image and points first."
    if index is None:
        return "Please select the correct mask before saving."

    output_dir = os.path.join(os.getcwd(), output_dir)
    folder_id = str(uuid4())
    folder_path = os.path.join(output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)  # also creates output_dir
    data_path = os.path.join(folder_path, "data.parquet")

    # Cast explicitly so the data match the declared Array2D dtype.
    selected_mask = np.asarray(all_masks[index]).astype(np.int64)
    data = {
        "image": source_image,
        "masked_image": overlays[index],
        "mask": selected_mask,
    }
    features = Features(
        {
            "image": Image(),
            "masked_image": Image(),
            "mask": Array2D(
                dtype="int64", shape=(selected_mask.shape[0], selected_mask.shape[1])
            ),
        }
    )

    try:
        Dataset.from_list([data], features=features).to_parquet(data_path)
        # Upload only this submission's folder. Uploading all of output_dir
        # would also push other sessions' in-progress or leftover folders.
        # path_in_repo keeps the "<uuid>/data.parquet" layout in the repo.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id=DESTINATION_DS,
            repo_type="dataset",
        )
    finally:
        # Remove the staging folder even when writing or uploading fails.
        shutil.rmtree(folder_path, ignore_errors=True)

    return "Success - Check out the 'Results' tab."
# time.sleep(5)
# # Add a random query parameter to force reload
# random_param = uuid4()
# iframe_code = f"""
#
# """
# Define the Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Tab("Object Segmentation - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )
        with gr.Row():
            with gr.Column():
                # Input: ImagePrompter
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")
        with gr.Row():
            with gr.Column():
                # Outputs: Up to 3 overlay images
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)
        # Radio buttons for selecting the correct mask
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )
            # selected_mask_output = gr.Image(show_label=False)
        save_button = gr.Button("Save Selected Mask and Image")
        iframe_display = gr.Markdown()

        # Named per-session states. Writing gr.State() inline in each
        # inputs/outputs list creates a separate, disconnected component
        # every time, so the save handler previously always got (None, None).
        masks_state = gr.State()
        selected_mask_state = gr.State()

        # Define the action triggered by the submit button
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, masks_state],
            show_progress=True,
        )
        # Define the action triggered by mask selection
        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=selected_mask_state,
        )
        # Define the action triggered by the save button
        save_button.click(
            fn=save_selected_mask,
            inputs=[image_input, selected_mask_state],
            outputs=iframe_display,
            show_progress=True,
        )
    with gr.Tab("Results"):
        with gr.Row():
            gr.HTML(
                """
                """
            )
            # with gr.Column():
            #     source = gr.Textbox(label="Source Dataset")
            #     source_display = gr.Markdown()
            #     iframe_display = gr.HTML()
            #     source.change(
            #         save_dataset_name,
            #         inputs=(gr.State("source_dataset"), source),
            #         outputs=(source_display, iframe_display),
            #     )
            # with gr.Column():
            #     destination = gr.Textbox(label="Destination Dataset")
            #     destination_display = gr.Markdown()
            #     destination.change(
            #         save_dataset_name,
            #         inputs=(gr.State("destination_dataset"), destination),
            #         outputs=destination_display,
            #     )
# Launch the Gradio app
demo.launch()