File size: 4,918 Bytes
2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 b4e8f1d 2cd9d38 bb8da79 2cd9d38 ad09938 2cd9d38 bb8da79 2cd9d38 ad09938 2cd9d38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
import spaces
# Run on the GPU when one is visible, otherwise fall back to the CPU. Both models
# are placed on this same `device` so that `inputs.to(device)` in the inference
# functions always matches where the weights live (previously the models were
# hard-coded to "cuda", which fails outright on CPU-only hosts).
# NOTE(review): on ZeroGPU Spaces (`@spaces.GPU`) CUDA is expected to report as
# available at import time; confirm the models do not end up on CPU there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Full SAM checkpoint and its SlimSAM counterpart; every request runs both so the
# demo can show their outputs side by side.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
def get_processor_and_model(slim: bool):
    """Pick the processor/model pair for a request.

    Args:
        slim: When truthy, select SlimSAM; otherwise select the full SAM model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    return (slimsam_processor, slimsam_model) if slim else (sam_processor, sam_model)
@spaces.GPU
def sam_box_inference(image, x_min, y_min, x_max, y_max, *, slim=False):
    """Segment the region inside a single bounding box.

    Args:
        image: Input image as an array accepted by ``Image.fromarray``
            (presumably the numpy array produced by ``ImagePrompter``).
        x_min, y_min, x_max, y_max: Box corners in image pixel coordinates.
        slim: Use SlimSAM instead of the full SAM model when True.

    Returns:
        ``[(mask, "mask")]``: a one-entry annotation list for
        ``gr.AnnotatedImage``. ``mask`` is the first candidate mask predicted
        for the box, resized back to the original image size by
        ``post_process_masks``, with a leading axis added.
    """
    processor, model = get_processor_and_model(slim)
    inputs = processor(
        Image.fromarray(image),
        input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Same call as the point variant; SamProcessor forwards to its image processor.
    masks = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[image][prompt][candidate]: keep the first candidate for the only box.
    mask = masks[0][0][0].numpy()[np.newaxis, ...]
    return [(mask, "mask")]
@spaces.GPU
def sam_point_inference(image, x, y, *, slim=False):
    """Segment the object under a single point prompt.

    Args:
        image: Input image passed straight to the SAM processor (a PIL image
            when called from ``infer_point``).
        x, y: Point location in image pixel coordinates.
        slim: Use SlimSAM instead of the full SAM model when True.

    Returns:
        ``[(mask, "mask")]``: a one-entry annotation list for
        ``gr.AnnotatedImage``. ``mask`` is the first candidate mask predicted
        for the point, resized back to the original image size by
        ``post_process_masks``, with a leading axis added.
    """
    processor, model = get_processor_and_model(slim)
    inputs = processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[image][prompt][candidate]: keep the first candidate for the only point.
    mask = masks[0][0][0].numpy()[np.newaxis, ...]
    return [(mask, "mask")]
def infer_point(img):
    """Segment around a user-drawn point with both SlimSAM and SAM.

    Args:
        img: Value of the ``gr.ImageEditor`` (``type="pil"``): a dict with
            ``"background"`` (the uploaded image), ``"layers"`` (drawn layers;
            the first one holds the point prompt) and ``"composite"``.

    Returns:
        Two ``(image, annotations)`` tuples for the SlimSAM and SAM
        ``gr.AnnotatedImage`` outputs, in that order.

    Raises:
        gr.Error: If no image was uploaded or nothing was drawn on it.
    """
    # gr.Error only reaches the UI when *raised*; merely constructing it is a no-op.
    if img is None or img.get("background") is None:
        raise gr.Error("Please upload an image and select a point.")
    layers = img.get("layers") or []
    if not layers:
        raise gr.Error("Please select a point on top of the image.")

    image = img["background"].convert("RGB")
    point_arr = np.asarray(layers[0])
    # Collapse the channel axis so every painted pixel counts once in the centroid.
    painted = point_arr.any(axis=-1) if point_arr.ndim == 3 else point_arr != 0
    rows, cols = np.nonzero(painted)
    if rows.size == 0:
        # Without this guard the mean of an empty array is NaN and int() fails.
        raise gr.Error("Please select a point on top of the image.")
    center_x = int(cols.mean())
    center_y = int(rows.mean())

    slim_result = sam_point_inference(image, center_x, center_y, slim=True)
    full_result = sam_point_inference(image, center_x, center_y)
    print("Point inference returned.")
    return (image, slim_result), (image, full_result)
def infer_box(prompts):
    """Segment the region inside a user-drawn box with both SlimSAM and SAM.

    Args:
        prompts: Value of the ``ImagePrompter``: a dict with ``"image"`` and
            ``"points"``. Only the first prompt is used, read as
            ``[x1, y1, _, x2, y2, ...]`` (indices 0, 1, 3 and 4 are the
            box corners).

    Returns:
        Two ``(image, annotations)`` tuples for the SlimSAM and SAM
        ``gr.AnnotatedImage`` outputs, in that order.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn.
    """
    # gr.Error only reaches the UI when *raised*; merely constructing it is a no-op.
    if prompts is None or prompts.get("image") is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    # Check emptiness *before* indexing; the original indexed [0] first and
    # raised IndexError on an empty prompt list.
    all_points = prompts.get("points") or []
    if not all_points or all_points[0] is None or len(all_points[0]) < 5:
        raise gr.Error("Please draw a box before submitting.")
    points = all_points[0]

    # Order the corners so a box dragged in any direction is still valid for SAM.
    x_min, x_max = sorted((points[0], points[3]))
    y_min, y_max = sorted((points[1], points[4]))

    return (
        (image, sam_box_inference(image, x_min, y_min, x_max, y_max, slim=True)),
        (image, sam_box_inference(image, x_min, y_min, x_max, y_max)),
    )
# Two tabs comparing SlimSAM and SAM masks: one for box prompts, one for point prompts.
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    # Box prompting: draw a box with ImagePrompter, then press Submit.
    with gr.Tab("Box Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                # Instructions
                gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
        with gr.Row():
            with gr.Column():
                box_prompter = ImagePrompter()
                box_submit = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")
        box_submit.click(
            infer_box,
            inputs=box_prompter,
            outputs=[output_box_slimsam, output_box_sam],
        )

    # Point prompting: inference reruns on every change to the editor contents.
    with gr.Tab("Point Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                # Instructions
                gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
        with gr.Row():
            with gr.Column():
                point_editor = gr.ImageEditor(
                    type="pil",
                )
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")
        point_editor.change(
            infer_point,
            inputs=point_editor,
            outputs=[output_slimsam, output_sam],
        )

demo.launch(debug=True)