import gradio as gr
import torch
from PIL import Image
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from huggingface_hub import hf_hub_download

# Fetch the SAM2 (hiera-large) checkpoint from the Hugging Face Hub;
# the checkpoint file published in this repo is "sam2_hiera_large.pt".
model_path = hf_hub_download(repo_id="facebook/sam2-hiera-large", filename="sam2_hiera_large.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"

# build_sam2 needs the Hydra config name that matches the checkpoint
# (newer sam2 releases use "configs/sam2/sam2_hiera_l.yaml" instead).
model = build_sam2("sam2_hiera_l.yaml", model_path, device=device)
predictor = SAM2ImagePredictor(model)
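# Alternative sketch: recent sam2 releases can also pull the config and weights
# straight from the Hub, replacing the hf_hub_download + build_sam2 steps above:
#   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large", device=device)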


def segment_image(input_image, x, y):
    # Gradio passes a PIL image; the predictor expects an RGB array
    input_image = np.array(input_image.convert("RGB"))

    # Compute the image embedding for this image
    predictor.set_image(input_image)

    # A single foreground click at (x, y): label 1 marks a positive point
    input_point = np.array([[x, y]])
    input_label = np.array([1])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Take the single returned mask and convert it to an 8-bit PIL mask
    mask = masks[0]
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # Keep the segmented region and black out the rest of the image
    result = Image.composite(Image.fromarray(input_image), Image.new('RGB', mask_image.size, 'black'), mask_image)

    return result
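
# Quick sanity check outside the UI (assumes a hypothetical "example.jpg" next to this script):
#   print(segment_image(Image.open("example.jpg"), 320, 240).size)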


iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1000, label="X coordinate"),
        gr.Slider(0, 1000, label="Y coordinate"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust X and Y coordinates to refine the selection.",
)

iface.launch()