jhj0517 committed
Commit 17abb6a • Parent(s): a81c70a

Add docstring
modules/sam_inference.py CHANGED (+121 -4)
@@ -1,7 +1,7 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2, build_sam2_video_predictor
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Any
 import torch
 import os
 from datetime import datetime
@@ -52,6 +52,13 @@ class SamInference:
     def load_model(self,
                    model_type: Optional[str] = None,
                    load_video_predictor: bool = False):
+        """
+        Load the model from the model directory. If the model is not found, download it from the URL.
+
+        Args:
+            model_type (str): The model type to load.
+            load_video_predictor (bool): Whether to load the video predictor model.
+        """
         if model_type is None:
             model_type = DEFAULT_MODEL_TYPE
 
@@ -90,6 +97,13 @@ class SamInference:
     def init_video_inference_state(self,
                                    vid_input: str,
                                    model_type: Optional[str] = None):
+        """
+        Initialize the video inference state for the video predictor.
+
+        Args:
+            vid_input (str): The video frames directory.
+            model_type (str): The model type to load.
+        """
         if model_type is None:
             model_type = self.current_model_type
 
@@ -113,7 +127,19 @@ class SamInference:
     def generate_mask(self,
                       image: np.ndarray,
                       model_type: str,
-                      **params):
+                      **params) -> List[Dict[str, Any]]:
+        """
+        Generate masks with automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml'.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            **params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[Dict[str, Any]]: The auto-generated mask data.
+        """
+
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
@@ -134,7 +160,23 @@ class SamInference:
                       box: Optional[np.ndarray] = None,
                       point_coords: Optional[np.ndarray] = None,
                       point_labels: Optional[np.ndarray] = None,
-                      **params):
+                      **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict image with prompt data.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            box (np.ndarray): The box prompt data.
+            point_coords (np.ndarray): The point coordinates prompt data.
+            point_labels (np.ndarray): The point labels prompt data.
+            **params: The hyperparameters for the mask generator.
+
+        Returns:
+            np.ndarray: The predicted masks output in CxHxW format.
+            np.ndarray: Array of scores for each mask.
+            np.ndarray: Array of logits in CxHxW format.
+        """
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
@@ -159,7 +201,24 @@ class SamInference:
                          inference_state: Optional[Dict] = None,
                          points: Optional[np.ndarray] = None,
                          labels: Optional[np.ndarray] = None,
-                         box: Optional[np.ndarray] = None):
+                         box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
+        """
+        Add a prediction to the current video inference state. The inference state must be initialized before calling this method.
+
+        Args:
+            frame_idx (int): The frame index of the video.
+            obj_id (int): The object id for the frame.
+            inference_state (Dict): The inference state for the video predictor.
+            points (np.ndarray): The point coordinates prompt data.
+            labels (np.ndarray): The point labels prompt data.
+            box (np.ndarray): The box prompt data.
+
+        Returns:
+            int: The frame index of the corresponding prediction.
+            int: The object id of the corresponding prediction.
+            torch.Tensor: The mask logits output in CxHxW format.
+        """
+
         if (self.video_predictor is None or
                 inference_state is None and self.video_inference_state is None):
             logger.exception("Error while predicting frame from video, load video predictor first")
@@ -184,6 +243,18 @@ class SamInference:
 
     def propagate_in_video(self,
                            inference_state: Optional[Dict] = None,):
+        """
+        Propagate in the video with the tracked predictions for each frame. Currently only supports
+        single-frame tracking.
+
+        Args:
+            inference_state (Dict): The inference state for the video predictor. Uses self.video_inference_state if None.
+
+        Returns:
+            Dict: The video segments with the image and mask data. It is keyed by frame index, and each entry has
+            "image" and "mask" data. The "image" key contains the path of the original image file and the "mask" key contains
+            the np.ndarray mask output.
+        """
         if inference_state is None and self.video_inference_state is None:
             logger.exception("Error while propagating in video, load video predictor first")
 
@@ -219,6 +290,20 @@ class SamInference:
                               pixel_size: Optional[int] = None,
                               color_hex: Optional[str] = None,
                               ):
+        """
+        Add a filter to the preview image with the prompt data. Made specifically for the gradio app.
+        It adds prediction tracking to self.video_inference_state and returns the filtered image.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            np.ndarray: The filtered image output.
+        """
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise f"Error while adding filter to preview"
@@ -262,6 +347,22 @@ class SamInference:
                             pixel_size: Optional[int] = None,
                             color_hex: Optional[str] = None
                             ):
+        """
+        Create a whole filtered video with video_inference_state. Currently only single-frame tracking is supported.
+        This requires FFmpeg to run. Returns the output path twice because of the gradio app.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            str: The output video path.
+            str: The output video path.
+        """
+
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise RuntimeError("Error while adding filter to preview")
@@ -321,6 +422,21 @@ class SamInference:
                      input_mode: str,
                      model_type: str,
                      *params):
+        """
+        Divide the layer with the given prompt data and save the psd file.
+
+        Args:
+            image_input (np.ndarray): The input image.
+            image_prompt_input_data (Dict): The image prompt data.
+            input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
+            model_type (str): The model type to load.
+            *params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[np.ndarray]: List of images by predicted masks.
+            str: The output path of the psd file.
+        """
+
         timestamp = datetime.now().strftime("%m%d%H%M%S")
         output_file_name = f"result-{timestamp}.psd"
         output_path = os.path.join(self.output_dir, "psd", output_file_name)
@@ -378,6 +494,7 @@ class SamInference:
     def format_to_auto_result(
             masks: np.ndarray
     ):
+        """Format the masks to the auto result format for convenience."""
         place_holder = 0
         if len(masks.shape) <= 3:
             masks = np.expand_dims(masks, axis=0)