amaye15 committed on
Commit
e34d5e8
•
1 Parent(s): d17cea3

App - V3 - Fully Complete

Browse files
Files changed (1) hide show
  1. app.py +139 -98
app.py CHANGED
@@ -5,11 +5,12 @@ import numpy as np
5
  from sam2.sam2_image_predictor import SAM2ImagePredictor
6
  from uuid import uuid4
7
  import os
8
- from huggingface_hub import upload_folder, login
9
  from PIL import Image as PILImage
10
  from datasets import Dataset, Features, Array2D, Image
11
  import shutil
12
- import time
 
13
 
14
  MODEL = "facebook/sam2-hiera-large"
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -17,7 +18,7 @@ PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
17
 
18
  DESTINATION_DS = "amaye15/object-segmentation"
19
 
20
- login(os.getenv("TOKEN"))
21
 
22
  IMAGE = None
23
  MASKS = None
@@ -25,6 +26,21 @@ MASKED_IMAGES = None
25
  INDEX = None
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def prompter(prompts):
29
 
30
  image = np.array(prompts["image"]) # Convert the image to a numpy array
@@ -116,114 +132,139 @@ def save_selected_mask(image, mask, output_dir="output"):
116
 
117
  shutil.rmtree(folder_path)
118
 
119
- iframe_code = "Success - Check out the 'Results' tab."
 
 
 
 
 
 
 
 
 
120
 
121
  return iframe_code
122
 
123
- # time.sleep(5)
124
 
125
- # # Add a random query parameter to force reload
126
- # random_param = uuid4()
127
- # iframe_code = f"""
128
- # <iframe
129
- # src="https://huggingface.co/datasets/{DESTINATION_DS}/embed/viewer/default/train"
130
- # frameborder="0"
131
- # width="100%"
132
- # height="560px"
133
- # ></iframe>
134
- # """
135
 
136
 
137
  # Define the Gradio Blocks app
138
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- with gr.Tab("Object Segmentation - Point Prompt"):
141
- gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
142
- gr.Markdown(
143
- "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
 
145
 
146
- with gr.Row():
147
- with gr.Column():
148
- # Input: ImagePrompter
149
- image_input = ImagePrompter(show_label=False)
150
- submit_button = gr.Button("Submit")
151
- with gr.Row():
152
- with gr.Column():
153
- # Outputs: Up to 3 overlay images
154
- image_output_1 = gr.Image(show_label=False)
155
- with gr.Column():
156
- image_output_2 = gr.Image(show_label=False)
157
- with gr.Column():
158
- image_output_3 = gr.Image(show_label=False)
159
-
160
- # Dropdown for selecting the correct mask
161
- with gr.Row():
162
- mask_selector = gr.Radio(
163
- label="Select the correct mask",
164
- choices=["Mask 1", "Mask 2", "Mask 3"],
165
- type="index",
166
- )
167
- # selected_mask_output = gr.Image(show_label=False)
168
-
169
- save_button = gr.Button("Save Selected Mask and Image")
170
- iframe_display = gr.Markdown()
171
-
172
- # Define the action triggered by the submit button
173
- submit_button.click(
174
- fn=prompter,
175
- inputs=image_input,
176
- outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
177
- show_progress=True,
178
- )
179
 
180
- # Define the action triggered by mask selection
181
- mask_selector.change(
182
- fn=select_mask,
183
- inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
184
- outputs=gr.State(),
185
- )
186
 
187
- # Define the action triggered by the save button
188
- save_button.click(
189
- fn=save_selected_mask,
190
- inputs=[gr.State(), gr.State()],
191
- outputs=iframe_display,
192
- show_progress=True,
193
- )
194
- with gr.Tab("Results"):
195
- with gr.Row():
196
- gr.HTML(
197
- f"""
198
- <iframe
199
- src="https://huggingface.co/datasets/{DESTINATION_DS}/embed/viewer/default/train"
200
- frameborder="0"
201
- width="100%"
202
- height="560px"
203
- ></iframe>
204
- """
205
- )
206
- # with gr.Column():
207
- # source = gr.Textbox(label="Source Dataset")
208
- # source_display = gr.Markdown()
209
- # iframe_display = gr.HTML()
210
-
211
- # source.change(
212
- # save_dataset_name,
213
- # inputs=(gr.State("source_dataset"), source),
214
- # outputs=(source_display, iframe_display),
215
- # )
216
-
217
- # with gr.Column():
218
-
219
- # destination = gr.Textbox(label="Destination Dataset")
220
- # destination_display = gr.Markdown()
221
-
222
- # destination.change(
223
- # save_dataset_name,
224
- # inputs=(gr.State("destination_dataset"), destination),
225
- # outputs=destination_display,
226
- # )
227
 
228
  # Launch the Gradio app
229
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from sam2.sam2_image_predictor import SAM2ImagePredictor
6
  from uuid import uuid4
7
  import os
8
+ from huggingface_hub import upload_folder
9
  from PIL import Image as PILImage
10
  from datasets import Dataset, Features, Array2D, Image
11
  import shutil
12
+ import random
13
+ from datasets import load_dataset
14
 
15
  MODEL = "facebook/sam2-hiera-large"
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
18
 
19
  DESTINATION_DS = "amaye15/object-segmentation"
20
 
21
+ # login(os.getenv("TOKEN"))
22
 
23
  IMAGE = None
24
  MASKS = None
 
26
  INDEX = None
27
 
28
 
29
+ ds_name = ["amaye15/product_labels"] # "amaye15/Products-10k", "amaye15/receipts"
30
+ choices = ["test", "train"]
31
+ max_len = None
32
+
33
+ ds_stream = load_dataset(random.choice(ds_name), streaming=True)
34
+
35
+
36
+ ds_split = ds_stream[random.choice(choices)]
37
+
38
+ ds_iter = ds_split.iter(batch_size=1)
39
+
40
+ for idx, val in enumerate(ds_iter):
41
+ max_len = idx
42
+
43
+
44
  def prompter(prompts):
45
 
46
  image = np.array(prompts["image"]) # Convert the image to a numpy array
 
132
 
133
  shutil.rmtree(folder_path)
134
 
135
+ iframe_code = """## Success! 🎉🤖✅
136
+
137
+ You've successfully contributed to the dataset.
138
+
139
+ Please note that because new data has been added to the dataset, it may take a couple of minutes to render.
140
+
141
+ Check it out here:
142
+
143
+ [Object Segmentation Dataset](https://huggingface.co/datasets/amaye15/object-segmentation)
144
+ """
145
 
146
  return iframe_code
147
 
 
148
 
149
+ def get_random_image():
150
+ """Get a random image from the dataset."""
151
+ global max_len
152
+ random_idx = random.choice(range(max_len))
153
+ image_data = list(ds_split.skip(random_idx).take(1))[0]["pixel_values"]
154
+ formatted_image = {
155
+ "image": np.array(image_data),
156
+ "points": [],
157
+ } # Create the correct format
158
+ return formatted_image
159
 
160
 
161
  # Define the Gradio Blocks app
162
  with gr.Blocks() as demo:
163
+ gr.Markdown("# Object Segmentation- Image Point Collector and Mask Overlay Tool")
164
+ gr.Markdown(
165
+ """
166
+ This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image or select a random image from a dataset and interactively generate segmentation masks based on multiple points you select on the image.
167
+
168
+ ### How It Works:
169
+ 1. **Upload or Select an Image**: You can either upload your own image or use a random image from the dataset.
170
+ 2. **Point Selection**: Click on the image to indicate points of interest. You can add multiple points, and these will be used collectively to generate segmentation masks using SAM2.
171
+ 3. **Mask Generation**: The app will generate up to three different segmentation masks for the selected points, each displayed separately with a red overlay.
172
+ 4. **Mask Selection**: Carefully review the generated masks and select the one that best fits your needs. **It's important to choose the correct mask, as your selection will be saved and used for further processing.**
173
+ 5. **Save and Contribute**: Save the selected mask along with the image to a dataset, contributing to a shared dataset on Hugging Face.
174
+
175
+ **Disclaimer**: All images and masks you work with will be collected and stored in a public dataset. Please ensure that you are comfortable with your selections and the data you provide before saving.
176
+
177
+ This tool is particularly useful for creating precise object segmentation masks for computer vision tasks, such as training models or generating labeled datasets.
178
+ """
179
+ )
180
 
181
+ with gr.Row():
182
+ with gr.Column():
183
+ image_input = gr.State()
184
+ # Input: ImagePrompter for uploaded image
185
+ upload_image_input = ImagePrompter(show_label=False)
186
+
187
+ random_image_button = gr.Button("Use Random Image")
188
+
189
+ submit_button = gr.Button("Submit")
190
+
191
+ with gr.Row():
192
+ with gr.Column():
193
+ # Outputs: Up to 3 overlay images
194
+ image_output_1 = gr.Image(show_label=False)
195
+ with gr.Column():
196
+ image_output_2 = gr.Image(show_label=False)
197
+ with gr.Column():
198
+ image_output_3 = gr.Image(show_label=False)
199
+
200
+ # Dropdown for selecting the correct mask
201
+ with gr.Row():
202
+ mask_selector = gr.Radio(
203
+ label="Select the correct mask",
204
+ choices=["Mask 1", "Mask 2", "Mask 3"],
205
+ type="index",
206
  )
207
+ # selected_mask_output = gr.Image(show_label=False)
208
 
209
+ save_button = gr.Button("Save Selected Mask and Image")
210
+ iframe_display = gr.Markdown()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ # Logic for the random image button
213
+ random_image_button.click(
214
+ fn=get_random_image,
215
+ inputs=None,
216
+ outputs=upload_image_input, # Pass the formatted random image to ImagePrompter
217
+ )
218
 
219
+ # Logic to use uploaded image
220
+ upload_image_input.change(
221
+ fn=lambda img: img, inputs=upload_image_input, outputs=image_input
222
+ )
223
+ # Define the action triggered by the submit button
224
+ submit_button.click(
225
+ fn=prompter,
226
+ inputs=upload_image_input, # The final image input (whether uploaded or random)
227
+ outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
228
+ show_progress=True,
229
+ )
230
+
231
+ # Define the action triggered by mask selection
232
+ mask_selector.change(
233
+ fn=select_mask,
234
+ inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
235
+ outputs=gr.State(),
236
+ )
237
+
238
+ # Define the action triggered by the save button
239
+ save_button.click(
240
+ fn=save_selected_mask,
241
+ inputs=[gr.State(), gr.State()],
242
+ outputs=iframe_display,
243
+ show_progress=True,
244
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  # Launch the Gradio app
247
  demo.launch()
248
+
249
+
250
+ # with gr.Column():
251
+ # source = gr.Textbox(label="Source Dataset")
252
+ # source_display = gr.Markdown()
253
+ # iframe_display = gr.HTML()
254
+
255
+ # source.change(
256
+ # save_dataset_name,
257
+ # inputs=(gr.State("source_dataset"), source),
258
+ # outputs=(source_display, iframe_display),
259
+ # )
260
+
261
+ # with gr.Column():
262
+
263
+ # destination = gr.Textbox(label="Destination Dataset")
264
+ # destination_display = gr.Markdown()
265
+
266
+ # destination.change(
267
+ # save_dataset_name,
268
+ # inputs=(gr.State("destination_dataset"), destination),
269
+ # outputs=destination_display,
270
+ # )