jhj0517 committed on
Commit 17abb6a
1 Parent(s): a81c70a

Add docstring

Files changed (1)
  1. modules/sam_inference.py +121 -4
modules/sam_inference.py CHANGED
@@ -1,7 +1,7 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2, build_sam2_video_predictor
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Any
 import torch
 import os
 from datetime import datetime
@@ -52,6 +52,13 @@ class SamInference:
     def load_model(self,
                    model_type: Optional[str] = None,
                    load_video_predictor: bool = False):
+        """
+        Load the model from the model directory. If the model is not found, download it from the URL.
+
+        Args:
+            model_type (str): The model type to load.
+            load_video_predictor (bool): Whether to also load the video predictor model.
+        """
         if model_type is None:
             model_type = DEFAULT_MODEL_TYPE
 
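A minimal usage sketch for the load_model behavior documented above. It assumes SamInference can be constructed with no arguments (the constructor is outside this diff):

    # Sketch only: no-arg construction is an assumption, not shown in this diff.
    from modules.sam_inference import SamInference

    inference = SamInference()
    inference.load_model()                           # model_type=None falls back to DEFAULT_MODEL_TYPE
    inference.load_model(load_video_predictor=True)  # additionally builds the video predictor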
@@ -90,6 +97,13 @@ class SamInference:
     def init_video_inference_state(self,
                                    vid_input: str,
                                    model_type: Optional[str] = None):
+        """
+        Initialize the video inference state for the video predictor.
+
+        Args:
+            vid_input (str): The video frames directory.
+            model_type (str): The model type to load.
+        """
         if model_type is None:
             model_type = self.current_model_type
 
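A hedged sketch of the video setup flow this docstring describes; the frames directory is a placeholder path:

    # Sketch: assumes the video was already split into frames in this directory.
    from modules.sam_inference import SamInference

    inference = SamInference()                       # hypothetical no-arg construction
    inference.load_model(load_video_predictor=True)  # the video predictor must exist first
    inference.init_video_inference_state(vid_input="outputs/frames/my_video")  # placeholder path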
@@ -113,7 +127,19 @@ class SamInference:
     def generate_mask(self,
                       image: np.ndarray,
                       model_type: str,
-                      **params):
+                      **params) -> List[Dict[str, Any]]:
+        """
+        Generate masks with automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml'.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            **params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[Dict[str, Any]]: The auto-generated mask data.
+        """
+
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
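A minimal sketch of calling generate_mask. The model_type string is an assumption, and the "segmentation" key follows the usual SAM2 automatic-mask-generator output format, which this diff does not show:

    import cv2
    from modules.sam_inference import SamInference

    inference = SamInference()                       # hypothetical construction
    image = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)
    masks = inference.generate_mask(image=image, model_type="sam2_hiera_large")  # assumed type name
    for m in masks:
        print(m["segmentation"].shape)               # each dict carries a binary mask array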
@@ -134,7 +160,23 @@ class SamInference:
                       box: Optional[np.ndarray] = None,
                       point_coords: Optional[np.ndarray] = None,
                       point_labels: Optional[np.ndarray] = None,
-                      **params):
+                      **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict the image with the given prompt data.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            box (np.ndarray): The box prompt data.
+            point_coords (np.ndarray): The point coordinates prompt data.
+            point_labels (np.ndarray): The point labels prompt data.
+            **params: The hyperparameters for the image predictor.
+
+        Returns:
+            np.ndarray: The predicted masks in CxHxW format.
+            np.ndarray: An array of scores for each mask.
+            np.ndarray: An array of logits in CxHxW format.
+        """
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
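A sketch of prompted prediction. The method name is not visible in this hunk, so predict_image below is an assumption taken from the docstring's wording:

    import numpy as np
    # "inference" and "image" as in the generate_mask sketch above.

    masks, scores, logits = inference.predict_image(  # method name assumed
        image=image,
        model_type="sam2_hiera_large",                # assumed model type string
        point_coords=np.array([[512, 384]]),
        point_labels=np.array([1]),                   # 1 = foreground point
    )
    best_mask = masks[np.argmax(scores)]              # masks come back CxHxW per the docstring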
@@ -159,7 +201,24 @@ class SamInference:
                                      inference_state: Optional[Dict] = None,
                                      points: Optional[np.ndarray] = None,
                                      labels: Optional[np.ndarray] = None,
-                                     box: Optional[np.ndarray] = None):
+                                     box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
+        """
+        Add a prediction to the current video inference state. The inference state must be initialized before calling this method.
+
+        Args:
+            frame_idx (int): The frame index of the video.
+            obj_id (int): The object id for the frame.
+            inference_state (Dict): The inference state for the video predictor.
+            points (np.ndarray): The point coordinates prompt data.
+            labels (np.ndarray): The point labels prompt data.
+            box (np.ndarray): The box prompt data.
+
+        Returns:
+            int: The frame index of the corresponding prediction.
+            int: The object id of the corresponding prediction.
+            torch.Tensor: The mask logits in CxHxW format.
+        """
+
         if (self.video_predictor is None or
                 inference_state is None and self.video_inference_state is None):
             logger.exception("Error while predicting frame from video, load video predictor first")
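A sketch of adding a tracked prediction to a frame. The method name sits outside this hunk, so add_prediction_to_frame is an assumption based on the docstring:

    import numpy as np
    # "inference" as in the earlier sketches, with the video inference state initialized.

    frame_idx, obj_id, mask_logits = inference.add_prediction_to_frame(  # name assumed
        frame_idx=0,
        obj_id=1,
        points=np.array([[300, 200]]),
        labels=np.array([1]),                         # 1 = foreground point
    )
    mask = (mask_logits[0] > 0.0).cpu().numpy()       # threshold logits into a boolean mask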
@@ -184,6 +243,18 @@ class SamInference:
 
     def propagate_in_video(self,
                            inference_state: Optional[Dict] = None,):
+        """
+        Propagate through the video with the tracked predictions for each frame. Currently only
+        single-frame tracking is supported.
+
+        Args:
+            inference_state (Dict): The inference state for the video predictor. Uses self.video_inference_state if None.
+
+        Returns:
+            Dict: The video segments with the image and mask data, keyed by frame index. Each entry holds
+            "image" and "mask" data: "image" contains the path of the original image file and "mask" contains
+            the np.ndarray mask output.
+        """
         if inference_state is None and self.video_inference_state is None:
             logger.exception("Error while propagating in video, load video predictor first")
 
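Since the return value is a frame-indexed dict, iterating it looks like this sketch:

    # "inference" as above, with a prediction already added to a frame.
    video_segments = inference.propagate_in_video()   # falls back to self.video_inference_state
    for idx, segment in video_segments.items():
        print(idx, segment["image"], segment["mask"].shape)  # path string and np.ndarray mask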
@@ -219,6 +290,20 @@ class SamInference:
                               pixel_size: Optional[int] = None,
                               color_hex: Optional[str] = None,
                               ):
+        """
+        Add a filter to the preview image with the prompt data. Made specifically for the Gradio app.
+        It adds prediction tracking to self.video_inference_state and returns the filtered image.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            np.ndarray: The filtered image output.
+        """
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise RuntimeError("Error while adding filter to preview")
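A sketch of the preview call. The method name and the prompt-data schema are not shown in this diff, so both are placeholders:

    # "inference" as above, with the video inference state initialized.
    prompt_data = ...  # produced by the Gradio UI; schema not shown in this diff

    preview = inference.add_filter_to_preview(        # method name assumed
        image_prompt_input_data=prompt_data,
        filter_mode="Pixelize",
        frame_idx=0,
        pixel_size=16,
    )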
@@ -262,6 +347,22 @@ class SamInference:
                           pixel_size: Optional[int] = None,
                           color_hex: Optional[str] = None
                           ):
+        """
+        Create the whole filtered video with video_inference_state. Currently only single-frame tracking is supported.
+        This requires FFmpeg to run. It returns the output path twice because of the Gradio app.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            str: The output video path.
+            str: The same output video path, duplicated for the Gradio app.
+        """
+
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise RuntimeError("Error while adding filter to preview")
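The video variant returns the same path twice for the Gradio outputs; a hedged sketch follows (method name assumed, FFmpeg required on the system):

    out_path, out_path_again = inference.create_filtered_video(  # method name assumed
        image_prompt_input_data=prompt_data,          # schema not shown in this diff
        filter_mode="Solid Color",
        frame_idx=0,
        color_hex="#00FF00",
    )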
@@ -321,6 +422,21 @@ class SamInference:
                      input_mode: str,
                      model_type: str,
                      *params):
+        """
+        Divide the layer with the given prompt data and save the result as a PSD file.
+
+        Args:
+            image_input (np.ndarray): The input image.
+            image_prompt_input_data (Dict): The image prompt data.
+            input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
+            model_type (str): The model type to load.
+            *params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[np.ndarray]: The list of images split out by the predicted masks.
+            str: The output path of the PSD file.
+        """
+
         timestamp = datetime.now().strftime("%m%d%H%M%S")
         output_file_name = f"result-{timestamp}.psd"
         output_path = os.path.join(self.output_dir, "psd", output_file_name)
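A sketch of the PSD export described above; divide_layer is an assumed method name, and the prompt-data schema stays a placeholder:

    layer_images, psd_path = inference.divide_layer(  # method name assumed
        image_input=image,
        image_prompt_input_data=prompt_data,          # schema not shown in this diff
        input_mode="Automatic",
        model_type="sam2_hiera_large",                # assumed model type string
    )
    print(psd_path)                                   # <output_dir>/psd/result-<timestamp>.psd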
@@ -378,6 +494,7 @@ class SamInference:
     def format_to_auto_result(
             masks: np.ndarray
     ):
+        """Format the masks into the auto-result format for convenience."""
         place_holder = 0
         if len(masks.shape) <= 3:
             masks = np.expand_dims(masks, axis=0)
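Since format_to_auto_result only expands the dimensions of masks with three or fewer axes, a quick sketch of that normalization (assuming it is a static method; any decorator is outside this hunk):

    import numpy as np
    from modules.sam_inference import SamInference

    masks = np.zeros((2, 512, 512), dtype=bool)           # a stack of two HxW masks
    result = SamInference.format_to_auto_result(masks)    # adds a leading axis when ndim <= 3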