|
import gradio as gr |
|
from gradio_image_prompter import ImagePrompter |
|
import torch |
|
import numpy as np |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
from uuid import uuid4 |
|
import os |
|
from huggingface_hub import upload_folder, login |
|
from PIL import Image as PILImage |
|
from datasets import Dataset, Features, Array2D, Image |
|
import shutil |
|
import time |
|
|
|
# Hugging Face Hub checkpoint used to build the SAM 2 image predictor.
MODEL = "facebook/sam2-hiera-large"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Built once at import time; every callback below reuses this single predictor.
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# Hub dataset repository that saved samples are uploaded to.
DESTINATION_DS = "amaye15/object-segmentation"

# Module-level scratch state shared between the Gradio callbacks:
#   IMAGE         - last image array handed to `prompter`
#   MASKS         - masks predicted for IMAGE
#   MASKED_IMAGES - red-overlay renderings of each mask, as numpy arrays
#   INDEX         - index of the mask picked through `select_mask`
IMAGE = MASKS = MASKED_IMAGES = INDEX = None
|
|
|
|
|
def prompter(prompts):
    """Predict candidate SAM 2 masks for the clicked points and overlay each in red.

    Args:
        prompts: Value produced by ``ImagePrompter``. Its ``"image"`` entry is
            converted with ``np.array``. Its ``"points"`` entry is a list of
            point records, and only the first two numbers of each record are
            used, as (x, y). Every point is treated as a foreground click
            (label 1).

    Returns:
        ``(overlay_1, overlay_2, overlay_3, masks)``. Each overlay is a PIL
        image that blends the input 50/50 with a red rendering of one
        predicted mask. ``masks`` is the raw output of ``PREDICTOR.predict``.

    Raises:
        gr.Error: if no image was provided or no point was clicked. SAM 2
            needs at least one prompt, and the previous code failed deep
            inside the predictor with an unhelpful message.

    Side effects:
        Stores the image, the masks and the overlays (as numpy arrays) in the
        module-level ``IMAGE``, ``MASKS`` and ``MASKED_IMAGES`` so that
        ``save_selected_mask`` can read them later.
    """
    global IMAGE, MASKS, MASKED_IMAGES

    if not prompts or prompts["image"] is None:
        raise gr.Error("Upload an image before submitting.")
    points = prompts["points"] or []
    if not points:
        raise gr.Error("Click at least one point on the image before submitting.")

    image = np.array(prompts["image"])
    # The red overlay writes channel 0 and SAM 2 expects 3-channel RGB, so
    # expand grayscale input and drop an alpha channel if present.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        image = np.ascontiguousarray(image[..., :3])

    input_point = [[point[0], point[1]] for point in points]
    input_label = [1] * len(points)

    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # Loop-invariant: the same base image is blended with every mask.
    original_image = PILImage.fromarray(image)

    overlay_images = []
    for i, mask in enumerate(masks, start=1):
        print(f"Predicted Mask {i}:", mask.shape)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255
        overlay_images.append(
            PILImage.blend(original_image, PILImage.fromarray(red_mask), alpha=0.5)
        )

    IMAGE, MASKS = image, masks
    MASKED_IMAGES = [np.array(img) for img in overlay_images]

    # multimask_output=True yields three candidates, one per output image slot.
    return overlay_images[0], overlay_images[1], overlay_images[2], masks
|
|
|
|
|
def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Record which of the three candidate masks the user picked.

    Args:
        selected_mask_index: Position of the chosen candidate (0, 1 or 2).
        mask1, mask2, mask3: The three candidate values to choose from.

    Returns:
        The candidate at ``selected_mask_index``.

    Side effects:
        Writes ``selected_mask_index`` to the module-level ``INDEX``, which
        ``save_selected_mask`` reads.
    """
    global INDEX
    INDEX = selected_mask_index
    candidates = (mask1, mask2, mask3)
    return candidates[selected_mask_index]
|
|
|
|
|
def save_selected_mask(image, mask, output_dir="output"):
    """Write the selected sample to a one-row parquet file and upload it to the Hub.

    The ``image`` and ``mask`` arguments are accepted only so the Gradio
    wiring keeps working; they are not read. The sample comes from the
    module-level ``IMAGE``, ``MASKS``, ``MASKED_IMAGES`` and ``INDEX``.

    Args:
        image: Unused.
        mask: Unused.
        output_dir: Local staging directory, joined onto the current working
            directory. A fresh ``<uuid4>`` subfolder is created inside it for
            this upload.

    Returns:
        A status message to show in the Markdown output.
    """
    # Guard against saving before a prediction or a mask selection. Before
    # this check, those cases raised TypeError on None indexing.
    if IMAGE is None or MASKS is None or MASKED_IMAGES is None:
        return "Nothing to save - submit an image and click at least one point first."
    if INDEX is None:
        return "Nothing to save - select one of the masks first."

    # The feature below declares int64, so cast explicitly rather than rely
    # on implicit casting of whatever dtype the predictor returned.
    selected_mask = np.asarray(MASKS[INDEX]).astype(np.int64)

    output_dir = os.path.join(os.getcwd(), output_dir)
    folder_id = str(uuid4())
    folder_path = os.path.join(output_dir, folder_id)
    # makedirs also creates output_dir when it does not exist yet.
    os.makedirs(folder_path, exist_ok=True)

    try:
        data = {
            "image": IMAGE,
            "masked_image": MASKED_IMAGES[INDEX],
            "mask": selected_mask,
        }
        features = Features(
            {
                "image": Image(),
                "masked_image": Image(),
                "mask": Array2D(
                    dtype="int64",
                    shape=(selected_mask.shape[0], selected_mask.shape[1]),
                ),
            }
        )
        ds = Dataset.from_list([data], features=features)
        ds.to_parquet(os.path.join(folder_path, "data.parquet"))

        # Upload only this submission's folder. Uploading all of output_dir
        # would also push leftovers from failed runs or another session's
        # in-progress folder. The repo layout stays <uuid>/data.parquet.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id=DESTINATION_DS,
            repo_type="dataset",
        )
    finally:
        # Always remove the staging folder, even when writing or uploading fails.
        shutil.rmtree(folder_path, ignore_errors=True)

    message = "Success - Check out the 'Results' tab."
    return message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI with two tabs: point-prompt segmentation, and an embedded Hub
# dataset viewer showing the samples uploaded to DESTINATION_DS.
with gr.Blocks() as demo:

    with gr.Tab("Object Segmentation - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )

        # Input area: image with clickable point prompts and the button that runs `prompter`.
        with gr.Row():
            with gr.Column():
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")
        # One column per candidate overlay returned by `prompter`.
        with gr.Row():
            with gr.Column():
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)

        # type="index" makes the Radio pass 0, 1 or 2 to `select_mask`
        # rather than the choice label.
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )

        save_button = gr.Button("Save Selected Mask and Image")
        # Receives the status string returned by `save_selected_mask`.
        iframe_display = gr.Markdown()

        # NOTE(review): every gr.State() below is created inline, so each is
        # a separate component that nothing else reads or writes. The two
        # inputs of `save_button.click` therefore never carry the prediction.
        # The callbacks pass data through module globals
        # (IMAGE/MASKS/MASKED_IMAGES/INDEX) instead. Those globals look
        # process-wide, i.e. shared by all concurrent users. Consider named,
        # per-session gr.State components; confirm before relying on
        # multi-user behavior.
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
            show_progress=True,
        )

        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )

        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=iframe_display,
            show_progress=True,
        )

    # Embeds the Hub dataset viewer for the destination repo.
    with gr.Tab("Results"):
        with gr.Row():
            gr.HTML(
                f"""
                <iframe
                    src="https://huggingface.co/datasets/{DESTINATION_DS}/embed/viewer/default/train"
                    frameborder="0"
                    width="100%"
                    height="560px"
                ></iframe>
                """
            )

demo.launch()
|
|