jhj0517 commited on
Commit
bbbee26
1 Parent(s): 88d3da9

Update optional return type to `get_frames_from_dir`

Browse files
Files changed (1) hide show
  1. modules/video_utils.py +7 -1
modules/video_utils.py CHANGED
@@ -1,6 +1,8 @@
1
  import subprocess
2
  import os
3
  from typing import List, Optional, Union
 
 
4
 
5
  from modules.logger_util import get_logger
6
  from modules.paths import TEMP_DIR
@@ -37,7 +39,8 @@ def extract_frames(
37
 
38
 
39
  def get_frames_from_dir(vid_dir: str,
40
- available_extensions: Optional[Union[List, str]] = None) -> List:
 
41
  """Get image file paths list from the dir"""
42
  if available_extensions is None:
43
  available_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"]
@@ -54,6 +57,9 @@ def get_frames_from_dir(vid_dir: str,
54
  frame_names.sort(key=lambda x: int(os.path.splitext(x)[0]))
55
 
56
  frames = [os.path.join(vid_dir, name) for name in frame_names]
 
 
 
57
  return frames
58
 
59
 
 
1
  import subprocess
2
  import os
3
  from typing import List, Optional, Union
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
  from modules.logger_util import get_logger
8
  from modules.paths import TEMP_DIR
 
39
 
40
 
41
  def get_frames_from_dir(vid_dir: str,
42
+ available_extensions: Optional[Union[List, str]] = None,
43
+ as_numpy: bool = False) -> List:
44
  """Get image file paths list from the dir"""
45
  if available_extensions is None:
46
  available_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"]
 
57
  frame_names.sort(key=lambda x: int(os.path.splitext(x)[0]))
58
 
59
  frames = [os.path.join(vid_dir, name) for name in frame_names]
60
+ if as_numpy:
61
+ frames = [np.array(Image.open(frame)) for frame in frames]
62
+
63
  return frames
64
 
65