jhj0517 commited on
Commit
ed7d6af
2 Parent(s): 76d1f6d 84188d4

Merge branch 'master' into huggingface

Browse files
app.py CHANGED
@@ -35,7 +35,7 @@ class App:
35
  self.image_modes = [AUTOMATIC_MODE, BOX_PROMPT_MODE]
36
  self.default_mode = BOX_PROMPT_MODE
37
  self.filter_modes = [PIXELIZE_FILTER, COLOR_FILTER]
38
- self.default_filter = PIXELIZE_FILTER
39
  self.default_color = DEFAULT_COLOR
40
  self.default_pixel_size = DEFAULT_PIXEL_SIZE
41
  default_hparam_config_path = os.path.join(SAM2_CONFIGS_DIR, "default_hparams.yaml")
@@ -132,6 +132,7 @@ class App:
132
  nb_pixel_size = gr.Number(label="Pixel Size", interactive=True, minimum=1,
133
  visible=self.default_filter == PIXELIZE_FILTER,
134
  value=self.default_pixel_size)
 
135
  btn_generate_preview = gr.Button("GENERATE PREVIEW")
136
 
137
  with gr.Row():
@@ -157,7 +158,7 @@ class App:
157
  nb_pixel_size])
158
 
159
  preview_params = [vid_frame_prompter, dd_filter_mode, sld_frame_selector, nb_pixel_size,
160
- cp_color_picker]
161
  btn_generate_preview.click(fn=self.sam_inf.add_filter_to_preview,
162
  inputs=preview_params,
163
  outputs=[img_preview])
@@ -180,6 +181,7 @@ class App:
180
  choices=self.image_modes)
181
  dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE,
182
  choices=self.sam_inf.available_models)
 
183
 
184
  with gr.Accordion("Mask Parameters", open=False, visible=self.default_mode == AUTOMATIC_MODE) as acc_mask_hparams:
185
  mask_hparams_component = self.mask_generation_parameters(_mask_hparams)
@@ -194,10 +196,9 @@ class App:
194
  output_file = gr.File(label="Generated psd file", scale=9)
195
  btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
196
 
197
- sources = [img_input, img_input_prompter, dd_input_modes]
198
- model_params = [dd_models]
199
  mask_hparams = mask_hparams_component + [cb_multimask_output]
200
- input_params = sources + model_params + mask_hparams
201
 
202
  btn_generate.click(fn=self.sam_inf.divide_layer,
203
  inputs=input_params, outputs=[gallery_output, output_file])
 
35
  self.image_modes = [AUTOMATIC_MODE, BOX_PROMPT_MODE]
36
  self.default_mode = BOX_PROMPT_MODE
37
  self.filter_modes = [PIXELIZE_FILTER, COLOR_FILTER]
38
+ self.default_filter = COLOR_FILTER
39
  self.default_color = DEFAULT_COLOR
40
  self.default_pixel_size = DEFAULT_PIXEL_SIZE
41
  default_hparam_config_path = os.path.join(SAM2_CONFIGS_DIR, "default_hparams.yaml")
 
132
  nb_pixel_size = gr.Number(label="Pixel Size", interactive=True, minimum=1,
133
  visible=self.default_filter == PIXELIZE_FILTER,
134
  value=self.default_pixel_size)
135
+ cb_invert_mask = gr.Checkbox(label="invert mask", value=_mask_hparams["invert_mask"])
136
  btn_generate_preview = gr.Button("GENERATE PREVIEW")
137
 
138
  with gr.Row():
 
158
  nb_pixel_size])
159
 
160
  preview_params = [vid_frame_prompter, dd_filter_mode, sld_frame_selector, nb_pixel_size,
161
+ cp_color_picker, cb_invert_mask]
162
  btn_generate_preview.click(fn=self.sam_inf.add_filter_to_preview,
163
  inputs=preview_params,
164
  outputs=[img_preview])
 
181
  choices=self.image_modes)
182
  dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE,
183
  choices=self.sam_inf.available_models)
184
+ cb_invert_mask = gr.Checkbox(label="invert mask", value=_mask_hparams["invert_mask"])
185
 
186
  with gr.Accordion("Mask Parameters", open=False, visible=self.default_mode == AUTOMATIC_MODE) as acc_mask_hparams:
187
  mask_hparams_component = self.mask_generation_parameters(_mask_hparams)
 
196
  output_file = gr.File(label="Generated psd file", scale=9)
197
  btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)
198
 
199
+ input_params = [img_input, img_input_prompter, dd_input_modes, dd_models, cb_invert_mask]
 
200
  mask_hparams = mask_hparams_component + [cb_multimask_output]
201
+ input_params += mask_hparams
202
 
203
  btn_generate.click(fn=self.sam_inf.divide_layer,
204
  inputs=input_params, outputs=[gallery_output, output_file])
configs/default_hparams.yaml CHANGED
@@ -10,3 +10,4 @@ mask_hparams:
10
  min_mask_region_area: 25.0
11
  use_m2m: true
12
  multimask_output: true
 
 
10
  min_mask_region_area: 25.0
11
  use_m2m: true
12
  multimask_output: true
13
+ invert_mask: false
modules/mask_utils.py CHANGED
@@ -17,6 +17,12 @@ def decode_to_mask(seg: np.ndarray[np.bool_] | np.ndarray[np.uint8]) -> np.ndarr
17
  return seg.astype(np.uint8)
18
 
19
 
 
 
 
 
 
 
20
  def generate_random_color() -> Tuple[int, int, int]:
21
  """Generate random color in RGB format"""
22
  h = np.random.randint(0, 360)
@@ -47,7 +53,6 @@ def create_mask_layers(
47
  List of RGBA images
48
  """
49
  layer_list = []
50
-
51
  sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
52
 
53
  for info in sorted_masks:
 
17
  return seg.astype(np.uint8)
18
 
19
 
20
+ def invert_masks(masks: List[Dict]) -> List[Dict]:
21
+ """Invert the masks. Used for background masking"""
22
+ inverted = 1 - masks
23
+ return inverted
24
+
25
+
26
  def generate_random_color() -> Tuple[int, int, int]:
27
  """Generate random color in RGB format"""
28
  h = np.random.randint(0, 360)
 
53
  List of RGBA images
54
  """
55
  layer_list = []
 
56
  sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
57
 
58
  for info in sorted_masks:
modules/sam_inference.py CHANGED
@@ -16,6 +16,7 @@ from modules.model_downloader import (
16
  from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
17
  from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
18
  from modules.mask_utils import (
 
19
  save_psd_with_masks,
20
  create_mask_combined_images,
21
  create_mask_gallery,
@@ -133,6 +134,7 @@ class SamInference:
133
  def generate_mask(self,
134
  image: np.ndarray,
135
  model_type: str,
 
136
  **params) -> List[Dict[str, Any]]:
137
  """
138
  Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
@@ -140,6 +142,7 @@ class SamInference:
140
  Args:
141
  image (np.ndarray): The input image.
142
  model_type (str): The model type to load.
 
143
  **params: The hyperparameters for the mask generator.
144
 
145
  Returns:
@@ -158,6 +161,11 @@ class SamInference:
158
  except Exception as e:
159
  logger.exception(f"Error while auto generating masks : {e}")
160
  raise RuntimeError(f"Failed to generate masks") from e
 
 
 
 
 
161
  return generated_masks
162
 
163
  def predict_image(self,
@@ -166,6 +174,7 @@ class SamInference:
166
  box: Optional[np.ndarray] = None,
167
  point_coords: Optional[np.ndarray] = None,
168
  point_labels: Optional[np.ndarray] = None,
 
169
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
170
  """
171
  Predict image with prompt data.
@@ -176,6 +185,7 @@ class SamInference:
176
  box (np.ndarray): The box prompt data.
177
  point_coords (np.ndarray): The point coordinates prompt data.
178
  point_labels (np.ndarray): The point labels prompt data.
 
179
  **params: The hyperparameters for the mask generator.
180
 
181
  Returns:
@@ -199,6 +209,10 @@ class SamInference:
199
  except Exception as e:
200
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
201
  raise RuntimeError(f"Failed to predict image with prompt") from e
 
 
 
 
202
  return masks, scores, logits
203
 
204
  def add_prediction_to_frame(self,
@@ -295,6 +309,7 @@ class SamInference:
295
  frame_idx: int,
296
  pixel_size: Optional[int] = None,
297
  color_hex: Optional[str] = None,
 
298
  ):
299
  """
300
  Add filter to the preview image with the prompt data. Specially made for gradio app.
@@ -306,6 +321,7 @@ class SamInference:
306
  frame_idx (int): The frame index of the video.
307
  pixel_size (int): The pixel size for the pixelize filter.
308
  color_hex (str): The color hex code for the solid color filter.
 
309
 
310
  Returns:
311
  np.ndarray: The filtered image output.
@@ -336,6 +352,9 @@ class SamInference:
336
  box=box
337
  )
338
  masks = (logits[0] > 0.0).cpu().numpy()
 
 
 
339
  generated_masks = self.format_to_auto_result(masks)
340
 
341
  if filter_mode == COLOR_FILTER:
@@ -351,7 +370,8 @@ class SamInference:
351
  filter_mode: str,
352
  frame_idx: int,
353
  pixel_size: Optional[int] = None,
354
- color_hex: Optional[str] = None
 
355
  ):
356
  """
357
  Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
@@ -363,6 +383,7 @@ class SamInference:
363
  frame_idx (int): The frame index of the video.
364
  pixel_size (int): The pixel size for the pixelize filter.
365
  color_hex (str): The color hex code for the solid color filter.
 
366
 
367
  Returns:
368
  str: The output video path.
@@ -394,12 +415,14 @@ class SamInference:
394
  inference_state=self.video_inference_state,
395
  points=point_coords,
396
  labels=point_labels,
397
- box=box
398
  )
399
 
400
  video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
401
  for frame_index, info in video_segments.items():
402
  orig_image, masks = info["image"], info["mask"]
 
 
403
  masks = self.format_to_auto_result(masks)
404
 
405
  if filter_mode == COLOR_FILTER:
@@ -427,6 +450,7 @@ class SamInference:
427
  image_prompt_input_data: Dict,
428
  input_mode: str,
429
  model_type: str,
 
430
  *params):
431
  """
432
  Divide the layer with the given prompt data and save psd file.
@@ -436,6 +460,7 @@ class SamInference:
436
  image_prompt_input_data (Dict): The image prompt data.
437
  input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
438
  model_type (str): The model type to load.
 
439
  *params: The hyperparameters for the mask generator.
440
 
441
  Returns:
@@ -467,6 +492,7 @@ class SamInference:
467
  generated_masks = self.generate_mask(
468
  image=image,
469
  model_type=model_type,
 
470
  **hparams
471
  )
472
 
@@ -485,7 +511,8 @@ class SamInference:
485
  box=box,
486
  point_coords=point_coords,
487
  point_labels=point_labels,
488
- multimask_output=hparams["multimask_output"]
 
489
  )
490
  generated_masks = self.format_to_auto_result(predicted_masks)
491
 
 
16
  from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
17
  from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
18
  from modules.mask_utils import (
19
+ invert_masks,
20
  save_psd_with_masks,
21
  create_mask_combined_images,
22
  create_mask_gallery,
 
134
  def generate_mask(self,
135
  image: np.ndarray,
136
  model_type: str,
137
+ invert_mask: bool = False,
138
  **params) -> List[Dict[str, Any]]:
139
  """
140
  Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
 
142
  Args:
143
  image (np.ndarray): The input image.
144
  model_type (str): The model type to load.
145
+ invert_mask (bool): Invert the mask output - used for background masking.
146
  **params: The hyperparameters for the mask generator.
147
 
148
  Returns:
 
161
  except Exception as e:
162
  logger.exception(f"Error while auto generating masks : {e}")
163
  raise RuntimeError(f"Failed to generate masks") from e
164
+
165
+ if invert_mask:
166
+ generated_masks = [{'segmentation': invert_masks(mask['segmentation']),
167
+ 'area': mask['area']} for mask in generated_masks]
168
+
169
  return generated_masks
170
 
171
  def predict_image(self,
 
174
  box: Optional[np.ndarray] = None,
175
  point_coords: Optional[np.ndarray] = None,
176
  point_labels: Optional[np.ndarray] = None,
177
+ invert_mask: bool = False,
178
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
179
  """
180
  Predict image with prompt data.
 
185
  box (np.ndarray): The box prompt data.
186
  point_coords (np.ndarray): The point coordinates prompt data.
187
  point_labels (np.ndarray): The point labels prompt data.
188
+ invert_mask (bool): Invert the mask output - used for background masking.
189
  **params: The hyperparameters for the mask generator.
190
 
191
  Returns:
 
209
  except Exception as e:
210
  logger.exception(f"Error while predicting image with prompt: {str(e)}")
211
  raise RuntimeError(f"Failed to predict image with prompt") from e
212
+
213
+ if invert_mask:
214
+ masks = invert_masks(masks)
215
+
216
  return masks, scores, logits
217
 
218
  def add_prediction_to_frame(self,
 
309
  frame_idx: int,
310
  pixel_size: Optional[int] = None,
311
  color_hex: Optional[str] = None,
312
+ invert_mask: bool = False
313
  ):
314
  """
315
  Add filter to the preview image with the prompt data. Specially made for gradio app.
 
321
  frame_idx (int): The frame index of the video.
322
  pixel_size (int): The pixel size for the pixelize filter.
323
  color_hex (str): The color hex code for the solid color filter.
324
+ invert_mask (bool): Invert the mask output - used for background masking.
325
 
326
  Returns:
327
  np.ndarray: The filtered image output.
 
352
  box=box
353
  )
354
  masks = (logits[0] > 0.0).cpu().numpy()
355
+ if invert_mask:
356
+ masks = invert_masks(masks)
357
+
358
  generated_masks = self.format_to_auto_result(masks)
359
 
360
  if filter_mode == COLOR_FILTER:
 
370
  filter_mode: str,
371
  frame_idx: int,
372
  pixel_size: Optional[int] = None,
373
+ color_hex: Optional[str] = None,
374
+ invert_mask: bool = False
375
  ):
376
  """
377
  Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
 
383
  frame_idx (int): The frame index of the video.
384
  pixel_size (int): The pixel size for the pixelize filter.
385
  color_hex (str): The color hex code for the solid color filter.
386
+ invert_mask (bool): Invert the mask output - used for background masking.
387
 
388
  Returns:
389
  str: The output video path.
 
415
  inference_state=self.video_inference_state,
416
  points=point_coords,
417
  labels=point_labels,
418
+ box=box,
419
  )
420
 
421
  video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
422
  for frame_index, info in video_segments.items():
423
  orig_image, masks = info["image"], info["mask"]
424
+ if invert_mask:
425
+ masks = invert_masks(masks)
426
  masks = self.format_to_auto_result(masks)
427
 
428
  if filter_mode == COLOR_FILTER:
 
450
  image_prompt_input_data: Dict,
451
  input_mode: str,
452
  model_type: str,
453
+ invert_mask: bool = False,
454
  *params):
455
  """
456
  Divide the layer with the given prompt data and save psd file.
 
460
  image_prompt_input_data (Dict): The image prompt data.
461
  input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
462
  model_type (str): The model type to load.
463
+ invert_mask (bool): Invert the mask output.
464
  *params: The hyperparameters for the mask generator.
465
 
466
  Returns:
 
492
  generated_masks = self.generate_mask(
493
  image=image,
494
  model_type=model_type,
495
+ invert_mask=invert_mask,
496
  **hparams
497
  )
498
 
 
511
  box=box,
512
  point_coords=point_coords,
513
  point_labels=point_labels,
514
+ multimask_output=hparams["multimask_output"],
515
+ invert_mask=invert_mask
516
  )
517
  generated_masks = self.format_to_auto_result(predicted_masks)
518