Spaces:
Sleeping
Sleeping
jhj0517
commited on
Commit
•
5929ef8
1
Parent(s):
5c7c82a
Refactor to `clean_files_with_extension()`
Browse files- modules/sam_inference.py +4 -4
- modules/video_utils.py +5 -16
modules/sam_inference.py
CHANGED
@@ -15,7 +15,7 @@ from modules.model_downloader import (
|
|
15 |
download_sam_model_url
|
16 |
)
|
17 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
|
18 |
-
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
|
19 |
from modules.mask_utils import (
|
20 |
save_psd_with_masks,
|
21 |
create_mask_combined_images,
|
@@ -24,7 +24,7 @@ from modules.mask_utils import (
|
|
24 |
create_solid_color_mask_image
|
25 |
)
|
26 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
27 |
-
extract_sound, clean_temp_dir,
|
28 |
from modules.utils import save_image
|
29 |
from modules.logger_util import get_logger
|
30 |
|
@@ -277,14 +277,14 @@ class SamInference:
|
|
277 |
logger.error(error_message)
|
278 |
raise gr.Error(error_message, duration=20)
|
279 |
|
280 |
-
|
|
|
281 |
|
282 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
283 |
|
284 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
285 |
obj_id = frame_idx
|
286 |
|
287 |
-
self.video_predictor.reset_state(self.video_inference_state)
|
288 |
idx, scores, logits = self.add_prediction_to_frame(
|
289 |
frame_idx=frame_idx,
|
290 |
obj_id=obj_id,
|
|
|
15 |
download_sam_model_url
|
16 |
)
|
17 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
|
18 |
+
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT
|
19 |
from modules.mask_utils import (
|
20 |
save_psd_with_masks,
|
21 |
create_mask_combined_images,
|
|
|
24 |
create_solid_color_mask_image
|
25 |
)
|
26 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
27 |
+
extract_sound, clean_temp_dir, clean_files_with_extension)
|
28 |
from modules.utils import save_image
|
29 |
from modules.logger_util import get_logger
|
30 |
|
|
|
277 |
logger.error(error_message)
|
278 |
raise gr.Error(error_message, duration=20)
|
279 |
|
280 |
+
clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
|
281 |
+
self.video_predictor.reset_state(self.video_inference_state)
|
282 |
|
283 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
284 |
|
285 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
286 |
obj_id = frame_idx
|
287 |
|
|
|
288 |
idx, scores, logits = self.add_prediction_to_frame(
|
289 |
frame_idx=frame_idx,
|
290 |
obj_id=obj_id,
|
modules/video_utils.py
CHANGED
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|
7 |
import re
|
8 |
|
9 |
from modules.logger_util import get_logger
|
|
|
10 |
from modules.paths import TEMP_DIR, TEMP_OUT_DIR
|
11 |
|
12 |
logger = get_logger()
|
@@ -222,24 +223,12 @@ def clean_temp_dir(temp_dir: Optional[str] = None):
|
|
222 |
else:
|
223 |
temp_out_dir = os.path.join(temp_dir, "out")
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
|
229 |
|
230 |
-
def
|
231 |
-
"""Removes all sound files from the directory."""
|
232 |
-
sound_extensions = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma']
|
233 |
-
_clean_files_with_extension(sound_dir, sound_extensions)
|
234 |
-
|
235 |
-
|
236 |
-
def clean_image_files(image_dir: str):
|
237 |
-
"""Removes all image files from the dir"""
|
238 |
-
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
|
239 |
-
_clean_files_with_extension(image_dir, image_extensions)
|
240 |
-
|
241 |
-
|
242 |
-
def _clean_files_with_extension(dir_path: str, extensions: List):
|
243 |
for filename in os.listdir(dir_path):
|
244 |
if filename.lower().endswith(tuple(extensions)):
|
245 |
file_path = os.path.join(dir_path, filename)
|
|
|
7 |
import re
|
8 |
|
9 |
from modules.logger_util import get_logger
|
10 |
+
from modules.constants import SOUND_FILE_EXT, VIDEO_FILE_EXT, IMAGE_FILE_EXT
|
11 |
from modules.paths import TEMP_DIR, TEMP_OUT_DIR
|
12 |
|
13 |
logger = get_logger()
|
|
|
223 |
else:
|
224 |
temp_out_dir = os.path.join(temp_dir, "out")
|
225 |
|
226 |
+
clean_files_with_extension(temp_dir, SOUND_FILE_EXT)
|
227 |
+
clean_files_with_extension(temp_dir, IMAGE_FILE_EXT)
|
228 |
+
clean_files_with_extension(temp_out_dir, IMAGE_FILE_EXT)
|
229 |
|
230 |
|
231 |
+
def clean_files_with_extension(dir_path: str, extensions: List):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
for filename in os.listdir(dir_path):
|
233 |
if filename.lower().endswith(tuple(extensions)):
|
234 |
file_path = os.path.join(dir_path, filename)
|