amaye15
commited on
Commit
•
933c40c
1
Parent(s):
7e76bff
Sam 2 point prompt
Browse files- app.py +170 -9
- requirements.txt +3 -1
app.py
CHANGED
@@ -1,21 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from gradio_image_prompter import ImagePrompter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Define the Gradio interface
|
5 |
demo = gr.Interface(
|
6 |
-
fn=
|
7 |
-
prompts["image"],
|
8 |
-
prompts["points"],
|
9 |
-
), # Extract image and points from the ImagePrompter
|
10 |
inputs=ImagePrompter(
|
11 |
show_label=False
|
12 |
), # ImagePrompter for image input and point selection
|
13 |
outputs=[
|
14 |
-
gr.Image(show_label=False)
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
description="Upload an image, click on it, and get the coordinates of the clicked points.",
|
19 |
)
|
20 |
|
21 |
# Launch the Gradio app
|
|
|
1 |
+
# import gradio as gr
|
2 |
+
# from gradio_image_prompter import ImagePrompter
|
3 |
+
|
4 |
+
# import os
|
5 |
+
# import torch
|
6 |
+
|
7 |
+
|
8 |
+
# def prompter(prompts):
|
9 |
+
# image = prompts["image"] # Get the image from prompts
|
10 |
+
# points = prompts["points"] # Get the points from prompts
|
11 |
+
|
12 |
+
# # Print the collected inputs for debugging or logging
|
13 |
+
# print("Image received:", image)
|
14 |
+
# print("Points received:", points)
|
15 |
+
|
16 |
+
# import torch
|
17 |
+
# from sam2.sam2_image_predictor import SAM2ImagePredictor
|
18 |
+
|
19 |
+
# device = torch.device("cpu")
|
20 |
+
|
21 |
+
# predictor = SAM2ImagePredictor.from_pretrained(
|
22 |
+
# "facebook/sam2-hiera-base-plus", device=device
|
23 |
+
# )
|
24 |
+
|
25 |
+
# with torch.inference_mode():
|
26 |
+
# predictor.set_image(image)
|
27 |
+
# # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
|
28 |
+
# input_point = [[point[0], point[1]] for point in points]
|
29 |
+
# input_label = [1]
|
30 |
+
# masks, _, _ = predictor.predict(
|
31 |
+
# point_coords=input_point, point_labels=input_label
|
32 |
+
# )
|
33 |
+
# print("Predicted Mask:", masks)
|
34 |
+
|
35 |
+
# return image, points
|
36 |
+
|
37 |
+
|
38 |
+
# # Define the Gradio interface
|
39 |
+
# demo = gr.Interface(
|
40 |
+
# fn=prompter, # Use the custom prompter function
|
41 |
+
# inputs=ImagePrompter(
|
42 |
+
# show_label=False
|
43 |
+
# ), # ImagePrompter for image input and point selection
|
44 |
+
# outputs=[
|
45 |
+
# gr.Image(show_label=False), # Display the image
|
46 |
+
# gr.Dataframe(label="Points"), # Display the points in a DataFrame
|
47 |
+
# ],
|
48 |
+
# title="Image Point Collector",
|
49 |
+
# description="Upload an image, click on it, and get the coordinates of the clicked points.",
|
50 |
+
# )
|
51 |
+
|
52 |
+
# # Launch the Gradio app
|
53 |
+
# demo.launch()
|
54 |
+
|
55 |
+
|
56 |
+
# import gradio as gr
|
57 |
+
# from gradio_image_prompter import ImagePrompter
|
58 |
+
# import torch
|
59 |
+
# from sam2.sam2_image_predictor import SAM2ImagePredictor
|
60 |
+
|
61 |
+
|
62 |
+
# def prompter(prompts):
|
63 |
+
# image = prompts["image"] # Get the image from prompts
|
64 |
+
# points = prompts["points"] # Get the points from prompts
|
65 |
+
|
66 |
+
# # Print the collected inputs for debugging or logging
|
67 |
+
# print("Image received:", image)
|
68 |
+
# print("Points received:", points)
|
69 |
+
|
70 |
+
# device = torch.device("cpu")
|
71 |
+
|
72 |
+
# # Load the SAM2ImagePredictor model
|
73 |
+
# predictor = SAM2ImagePredictor.from_pretrained(
|
74 |
+
# "facebook/sam2-hiera-base-plus", device=device
|
75 |
+
# )
|
76 |
+
|
77 |
+
# # Perform inference
|
78 |
+
# with torch.inference_mode():
|
79 |
+
# predictor.set_image(image)
|
80 |
+
# input_point = [[point[0], point[1]] for point in points]
|
81 |
+
# input_label = [1] * len(points) # Assuming all points are foreground
|
82 |
+
# masks, _, _ = predictor.predict(
|
83 |
+
# point_coords=input_point, point_labels=input_label
|
84 |
+
# )
|
85 |
+
|
86 |
+
# # The masks are returned as a list of numpy arrays
|
87 |
+
# print("Predicted Mask:", masks)
|
88 |
+
|
89 |
+
# # Assuming there's only one mask returned, you can adjust if there are multiple
|
90 |
+
# predicted_mask = masks[0]
|
91 |
+
|
92 |
+
# print(len(image))
|
93 |
+
|
94 |
+
# print(len(predicted_mask))
|
95 |
+
|
96 |
+
# # Create annotations for AnnotatedImage
|
97 |
+
# annotations = [(predicted_mask, "Predicted Mask")]
|
98 |
+
|
99 |
+
# return image, annotations
|
100 |
+
|
101 |
+
|
102 |
+
# # Define the Gradio interface
|
103 |
+
# demo = gr.Interface(
|
104 |
+
# fn=prompter, # Use the custom prompter function
|
105 |
+
# inputs=ImagePrompter(
|
106 |
+
# show_label=False
|
107 |
+
# ), # ImagePrompter for image input and point selection
|
108 |
+
# outputs=gr.AnnotatedImage(), # Display the image with the predicted mask
|
109 |
+
# title="Image Point Collector with Mask Overlay",
|
110 |
+
# description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
|
111 |
+
# )
|
112 |
+
|
113 |
+
# # Launch the Gradio app
|
114 |
+
# demo.launch()
|
115 |
+
|
116 |
+
|
117 |
import gradio as gr
|
118 |
from gradio_image_prompter import ImagePrompter
|
119 |
+
import torch
|
120 |
+
import numpy as np
|
121 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
122 |
+
from PIL import Image
|
123 |
+
|
124 |
+
|
125 |
+
def prompter(prompts):
|
126 |
+
image = np.array(prompts["image"]) # Convert the image to a numpy array
|
127 |
+
points = prompts["points"] # Get the points from prompts
|
128 |
+
|
129 |
+
# Print the collected inputs for debugging or logging
|
130 |
+
print("Image received:", image)
|
131 |
+
print("Points received:", points)
|
132 |
+
|
133 |
+
device = torch.device("cpu")
|
134 |
+
|
135 |
+
# Load the SAM2ImagePredictor model
|
136 |
+
predictor = SAM2ImagePredictor.from_pretrained(
|
137 |
+
"facebook/sam2-hiera-base-plus", device=device
|
138 |
+
)
|
139 |
+
|
140 |
+
# Perform inference with multimask_output=True
|
141 |
+
with torch.inference_mode():
|
142 |
+
predictor.set_image(image)
|
143 |
+
input_point = [[point[0], point[1]] for point in points]
|
144 |
+
input_label = [1] * len(points) # Assuming all points are foreground
|
145 |
+
masks, _, _ = predictor.predict(
|
146 |
+
point_coords=input_point, point_labels=input_label, multimask_output=True
|
147 |
+
)
|
148 |
+
|
149 |
+
# Prepare individual images with separate overlays
|
150 |
+
overlay_images = []
|
151 |
+
for i, mask in enumerate(masks):
|
152 |
+
print(f"Predicted Mask {i+1}:", mask)
|
153 |
+
red_mask = np.zeros_like(image)
|
154 |
+
red_mask[:, :, 0] = mask.astype(np.uint8) * 255 # Apply the red channel
|
155 |
+
red_mask = Image.fromarray(red_mask)
|
156 |
+
|
157 |
+
# Convert the original image to a PIL image
|
158 |
+
original_image = Image.fromarray(image)
|
159 |
+
|
160 |
+
# Blend the original image with the red mask
|
161 |
+
blended_image = Image.blend(original_image, red_mask, alpha=0.5)
|
162 |
+
|
163 |
+
# Add the blended image to the list
|
164 |
+
overlay_images.append(blended_image)
|
165 |
+
|
166 |
+
return overlay_images
|
167 |
+
|
168 |
|
169 |
# Define the Gradio interface
|
170 |
demo = gr.Interface(
|
171 |
+
fn=prompter, # Use the custom prompter function
|
|
|
|
|
|
|
172 |
inputs=ImagePrompter(
|
173 |
show_label=False
|
174 |
), # ImagePrompter for image input and point selection
|
175 |
outputs=[
|
176 |
+
gr.Image(show_label=False) for _ in range(3)
|
177 |
+
], # Display up to 3 overlay images
|
178 |
+
title="Image Point Collector with Multiple Separate Mask Overlays",
|
179 |
+
description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
|
|
|
180 |
)
|
181 |
|
182 |
# Launch the Gradio app
|
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
gradio
|
2 |
gradio-image-prompter
|
3 |
-
Pillow
|
|
|
|
|
|
1 |
gradio
|
2 |
gradio-image-prompter
|
3 |
+
Pillow
|
4 |
+
opencv-python
|
5 |
+
git+https://github.com/facebookresearch/segment-anything-2.git
|