jhj0517 commited on
Commit
5929ef8
1 Parent(s): 5c7c82a

Refactor to `clean_files_with_extension()`

Browse files
Files changed (2) hide show
  1. modules/sam_inference.py +4 -4
  2. modules/video_utils.py +5 -16
modules/sam_inference.py CHANGED
@@ -15,7 +15,7 @@ from modules.model_downloader import (
15
  download_sam_model_url
16
  )
17
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
18
- from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
19
  from modules.mask_utils import (
20
  save_psd_with_masks,
21
  create_mask_combined_images,
@@ -24,7 +24,7 @@ from modules.mask_utils import (
24
  create_solid_color_mask_image
25
  )
26
  from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
27
- extract_sound, clean_temp_dir, clean_image_files)
28
  from modules.utils import save_image
29
  from modules.logger_util import get_logger
30
 
@@ -277,14 +277,14 @@ class SamInference:
277
  logger.error(error_message)
278
  raise gr.Error(error_message, duration=20)
279
 
280
- clean_image_files(TEMP_OUT_DIR)
 
281
 
282
  prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
283
 
284
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
285
  obj_id = frame_idx
286
 
287
- self.video_predictor.reset_state(self.video_inference_state)
288
  idx, scores, logits = self.add_prediction_to_frame(
289
  frame_idx=frame_idx,
290
  obj_id=obj_id,
 
15
  download_sam_model_url
16
  )
17
  from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
18
+ from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT
19
  from modules.mask_utils import (
20
  save_psd_with_masks,
21
  create_mask_combined_images,
 
24
  create_solid_color_mask_image
25
  )
26
  from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
27
+ extract_sound, clean_temp_dir, clean_files_with_extension)
28
  from modules.utils import save_image
29
  from modules.logger_util import get_logger
30
 
 
277
  logger.error(error_message)
278
  raise gr.Error(error_message, duration=20)
279
 
280
+ clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
281
+ self.video_predictor.reset_state(self.video_inference_state)
282
 
283
  prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
284
 
285
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
286
  obj_id = frame_idx
287
 
 
288
  idx, scores, logits = self.add_prediction_to_frame(
289
  frame_idx=frame_idx,
290
  obj_id=obj_id,
modules/video_utils.py CHANGED
@@ -7,6 +7,7 @@ from dataclasses import dataclass
7
  import re
8
 
9
  from modules.logger_util import get_logger
 
10
  from modules.paths import TEMP_DIR, TEMP_OUT_DIR
11
 
12
  logger = get_logger()
@@ -222,24 +223,12 @@ def clean_temp_dir(temp_dir: Optional[str] = None):
222
  else:
223
  temp_out_dir = os.path.join(temp_dir, "out")
224
 
225
- clean_sound_files(temp_dir)
226
- clean_image_files(temp_dir)
227
- clean_image_files(temp_out_dir)
228
 
229
 
230
- def clean_sound_files(sound_dir: str):
231
- """Removes all sound files from the directory."""
232
- sound_extensions = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma']
233
- _clean_files_with_extension(sound_dir, sound_extensions)
234
-
235
-
236
- def clean_image_files(image_dir: str):
237
- """Removes all image files from the dir"""
238
- image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
239
- _clean_files_with_extension(image_dir, image_extensions)
240
-
241
-
242
- def _clean_files_with_extension(dir_path: str, extensions: List):
243
  for filename in os.listdir(dir_path):
244
  if filename.lower().endswith(tuple(extensions)):
245
  file_path = os.path.join(dir_path, filename)
 
7
  import re
8
 
9
  from modules.logger_util import get_logger
10
+ from modules.constants import SOUND_FILE_EXT, VIDEO_FILE_EXT, IMAGE_FILE_EXT
11
  from modules.paths import TEMP_DIR, TEMP_OUT_DIR
12
 
13
  logger = get_logger()
 
223
  else:
224
  temp_out_dir = os.path.join(temp_dir, "out")
225
 
226
+ clean_files_with_extension(temp_dir, SOUND_FILE_EXT)
227
+ clean_files_with_extension(temp_dir, IMAGE_FILE_EXT)
228
+ clean_files_with_extension(temp_out_dir, IMAGE_FILE_EXT)
229
 
230
 
231
+ def clean_files_with_extension(dir_path: str, extensions: List):
 
 
 
 
 
 
 
 
 
 
 
 
232
  for filename in os.listdir(dir_path):
233
  if filename.lower().endswith(tuple(extensions)):
234
  file_path = os.path.join(dir_path, filename)