File size: 1,677 Bytes
d25c63d
 
 
 
bde346e
d25c63d
 
 
 
 
bde346e
 
 
 
d25c63d
 
bde346e
 
d25c63d
 
 
 
 
 
 
 
 
bde346e
 
 
 
 
d25c63d
 
bde346e
d25c63d
 
 
bde346e
d25c63d
 
 
 
 
 
 
bde346e
d25c63d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import torch
from PIL import Image
import numpy as np
from sam2 import build_sam2, SamPredictor
from huggingface_hub import hf_hub_download

# Everything runs on the CPU; the string is handed to the model's .to() below.
device = "cpu"

# Fetch the checkpoint file from the Hugging Face Hub. The returned value is
# passed straight to build_sam2 as the checkpoint location.
# NOTE(review): confirm the filename and the `sam2` import names against the
# installed package — published SAM 2 checkpoints appear to use a ".pt"
# extension and a different module layout; verify before deploying.
model_path = hf_hub_download(
    repo_id="facebook/sam2-hiera-large",
    filename="sam2_hiera_large.pth",
)

# Build the network once at import time and wrap it in the predictor that
# segment_image() uses for every request.
model = build_sam2(checkpoint=model_path).to(device)
predictor = SamPredictor(model)

def segment_image(input_image, x, y):
    """Segment the object at a prompt point and black out everything else.

    Args:
        input_image: The uploaded picture, either a PIL image (what the
            Gradio input is configured to deliver) or an array that
            ``Image.fromarray`` accepts. ``None`` means nothing was uploaded.
        x: X coordinate of the foreground prompt point, in pixels
            (treated as the horizontal axis; out-of-range values are
            clamped to the image).
        y: Y coordinate of the foreground prompt point, in pixels
            (treated as the vertical axis; clamped likewise).

    Returns:
        An RGB PIL image in which pixels selected by the predicted mask keep
        their original colour and all remaining pixels are black.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when the image slot is empty; fail with a readable
    # message instead of an opaque error from deep inside the predictor.
    if input_image is None:
        raise gr.Error("Please upload an image before segmenting.")

    # Normalise to a 3-channel RGB image. Image.composite below requires both
    # images to share a mode, and the background is created as 'RGB', so an
    # RGBA / greyscale / palette upload would otherwise raise.
    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(np.asarray(input_image))
    rgb_image = input_image.convert("RGB")
    image_array = np.array(rgb_image)

    # Prepare the image for the model
    predictor.set_image(image_array)

    # The UI sliders always span 0-1000 regardless of the picture's size, so
    # pull the point back inside the image bounds before prompting.
    width, height = rgb_image.size
    x = min(max(x, 0), width - 1)
    y = min(max(y, 0), height - 1)

    # Prepare the prompt: a single point labelled as foreground
    input_point = np.array([[x, y]])
    input_label = np.array([1])  # 1 for foreground

    # Generate a single mask for the prompt
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Scale the mask to 0/255 so it can serve as the composite's alpha mask
    mask = masks[0]
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # Keep the original pixels where the mask is set, black elsewhere
    background = Image.new('RGB', mask_image.size, 'black')
    return Image.composite(rgb_image, background, mask_image)

# Input widgets: the picture to segment plus one slider per prompt coordinate.
image_input = gr.Image(type="pil")
x_slider = gr.Slider(0, 1000, label="X coordinate")
y_slider = gr.Slider(0, 1000, label="Y coordinate")

# Output widget showing the masked result returned by segment_image.
result_view = gr.Image(type="pil")

# Wire the widgets to the segmentation function.
iface = gr.Interface(
    fn=segment_image,
    inputs=[image_input, x_slider, y_slider],
    outputs=result_view,
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust X and Y coordinates to refine the selection.",
)

# Start serving the app
iface.launch()