import gradio as gr
import torch
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Download the model weights from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="facebook/sam2-hiera-large",
    filename="sam2_hiera_large.pt",
)

# Initialize the SAM2 model on CPU
# ("sam2_hiera_l.yaml" is the config that ships with the sam2 package for the large checkpoint)
device = "cpu"
model = build_sam2("sam2_hiera_l.yaml", ckpt_path=model_path, device=device)
predictor = SAM2ImagePredictor(model)


def segment_image(input_image, x, y):
    # Convert the Gradio PIL image to an RGB numpy array
    input_image = np.array(input_image.convert("RGB"))

    # Compute the image embedding for the model
    predictor.set_image(input_image)

    # Prepare the prompt: a single foreground point at (x, y)
    input_point = np.array([[x, y]])
    input_label = np.array([1])  # 1 = foreground

    # Generate the mask
    with torch.inference_mode():
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )

    # Convert the single returned mask to a grayscale PIL image
    mask = masks[0]
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # Apply the mask to the original image: keep the segment, black out the rest
    result = Image.composite(
        Image.fromarray(input_image),
        Image.new("RGB", mask_image.size, "black"),
        mask_image,
    )
    return result


# Create the Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1000, step=1, label="X coordinate"),
        gr.Slider(0, 1000, step=1, label="Y coordinate"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust the X and Y coordinates to refine the selection.",
)

# Launch the app
iface.launch()
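
# Note: an alternative loading path, assuming the installed sam2 version includes
# the built-in Hugging Face integration. SAM2ImagePredictor.from_pretrained resolves
# the checkpoint and its matching config directly from the Hub repo, replacing the
# manual hf_hub_download + build_sam2 steps above. Whether extra kwargs such as
# device are forwarded to build_sam2 depends on the package version, so treat the
# device argument here as an assumption:
#
#   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large", device="cpu")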