ahnobari commited on
Commit
ef48a87
1 Parent(s): 41b552c

postprocess fix

Browse files
Files changed (1) hide show
  1. bikefusion/data_utils.py +7 -2
bikefusion/data_utils.py CHANGED
@@ -29,7 +29,7 @@ def pad_to_square(image, target_size=(128, 128)):
29
 
30
  def preprocess(images):
31
  # Convert arrays to tensors
32
- images = torch.from_numpy(images).float()
33
 
34
  # Apply padding to each image in the dataset
35
  images = torch.stack([pad_to_square(img) for img in images])
@@ -52,7 +52,12 @@ def un_pad(image, target_size=(80, 128)):
52
 
53
  def postprocess(images):
54
  # Convert tensors to arrays
55
- images = images.detach().cpu().numpy()
 
 
 
 
 
56
 
57
  # Unpad each image in the dataset
58
  images = np.stack([un_pad(img) for img in images])
 
29
 
30
  def preprocess(images):
31
  # Convert arrays to tensors
32
+ images = torch.tensor(images).float()
33
 
34
  # Apply padding to each image in the dataset
35
  images = torch.stack([pad_to_square(img) for img in images])
 
52
 
53
  def postprocess(images):
54
  # Convert tensors to arrays
55
+ if isinstance(images, torch.Tensor):
56
+ images = images.detach().cpu().numpy()
57
+ elif isinstance(images, np.ndarray):
58
+ pass
59
+ else:
60
+ raise ValueError("images must be either a torch.Tensor or a np.ndarray")
61
 
62
  # Unpad each image in the dataset
63
  images = np.stack([un_pad(img) for img in images])