geekyrakshit committed
Commit 056188b
1 Parent(s): cfb00a4

added unpaired dataloader

enhance_me/augmentation.py CHANGED
@@ -49,3 +49,38 @@ class AugmentationFactory:
         return tf.image.rot90(input_image, condition), tf.image.rot90(
             enhanced_image, condition
         )
+
+
+class UnpairedAugmentationFactory:
+    def __init__(self, image_size) -> None:
+        self.image_size = image_size
+
+    def random_crop(self, image):
+        image_shape = tf.shape(image)[:2]
+        crop_w = tf.random.uniform(
+            shape=(), maxval=image_shape[1] - self.image_size + 1, dtype=tf.int32
+        )
+        crop_h = tf.random.uniform(
+            shape=(), maxval=image_shape[0] - self.image_size + 1, dtype=tf.int32
+        )
+        return image[
+            crop_h : crop_h + self.image_size, crop_w : crop_w + self.image_size
+        ]
+
+    def random_horizontal_flip(self, image):
+        return tf.cond(
+            tf.random.uniform(shape=(), maxval=1) < 0.5,
+            lambda: image,
+            lambda: tf.image.flip_left_right(image),
+        )
+
+    def random_vertical_flip(self, image):
+        return tf.cond(
+            tf.random.uniform(shape=(), maxval=1) < 0.5,
+            lambda: image,
+            lambda: tf.image.flip_up_down(image),
+        )
+
+    def random_rotate(self, image):
+        condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
+        return tf.image.rot90(image, condition)
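
For context, a minimal usage sketch of the new factory on a single image tensor (not part of the commit; the 400x600 random test image is an illustrative assumption):

import tensorflow as tf
from enhance_me.augmentation import UnpairedAugmentationFactory

factory = UnpairedAugmentationFactory(image_size=256)
image = tf.random.uniform((400, 600, 3))       # stand-in for a decoded low-light image
patch = factory.random_crop(image)             # random 256x256 patch
patch = factory.random_horizontal_flip(patch)  # flipped left-right with probability ~0.5
patch = factory.random_vertical_flip(patch)    # flipped up-down with probability ~0.5
patch = factory.random_rotate(patch)           # rotated by a random multiple of 90 degrees
print(patch.shape)                             # (256, 256, 3)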
enhance_me/zero_dce/__init__.py ADDED
(empty file)
enhance_me/zero_dce/dataloader.py ADDED
@@ -0,0 +1,73 @@
+import tensorflow as tf
+from typing import List
+
+from ..augmentation import UnpairedAugmentationFactory
+
+
+class UnpairedLowLightDataset:
+    def __init__(
+        self,
+        image_size: int = 256,
+        apply_random_horizontal_flip: bool = True,
+        apply_random_vertical_flip: bool = True,
+        apply_random_rotation: bool = True,
+    ) -> None:
+        self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
+        self.apply_random_horizontal_flip = apply_random_horizontal_flip
+        self.apply_random_vertical_flip = apply_random_vertical_flip
+        self.apply_random_rotation = apply_random_rotation
+
+    def load_data(self, image_path):
+        image = tf.io.read_file(image_path)
+        image = tf.image.decode_png(image, channels=3)
+        image = image / 255.0
+        return image
+
+    def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
+        dataset = tf.data.Dataset.from_tensor_slices((images))
+        dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
+        dataset = dataset.map(
+            self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
+        )
+        if is_train:
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_horizontal_flip,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_horizontal_flip
+                else dataset
+            )
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_vertical_flip,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_vertical_flip
+                else dataset
+            )
+            dataset = (
+                dataset.map(
+                    self.augmentation_factory.random_rotate,
+                    num_parallel_calls=tf.data.AUTOTUNE,
+                )
+                if self.apply_random_rotation
+                else dataset
+            )
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+        return dataset
+
+    def get_datasets(
+        self,
+        images: List[str],
+        val_split: float = 0.2,
+        batch_size: int = 16,
+    ):
+        split_index = int(len(images) * (1 - val_split))
+        train_images = images[:split_index]
+        val_images = images[split_index:]
+        print(f"Number of train data points: {len(train_images)}")
+        print(f"Number of validation data points: {len(val_images)}")
+        train_dataset = self._get_dataset(train_images, batch_size, is_train=True)
+        val_dataset = self._get_dataset(val_images, batch_size, is_train=False)
+        return train_dataset, val_dataset
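
For context, a minimal end-to-end sketch of how the new dataset class might be driven (not part of the commit; the glob pattern and directory name are illustrative assumptions):

from glob import glob
from enhance_me.zero_dce.dataloader import UnpairedLowLightDataset

image_paths = sorted(glob("./low_light_images/*.png"))  # hypothetical directory of PNG frames
dataset = UnpairedLowLightDataset(image_size=256)
train_ds, val_ds = dataset.get_datasets(image_paths, val_split=0.2, batch_size=16)
for batch in train_ds.take(1):
    print(batch.shape)  # (16, 256, 256, 3)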