geekyrakshit commited on
Commit
91e5f9b
1 Parent(s): 4cf5013

added mixed precision

Browse files
README.md CHANGED
@@ -8,4 +8,14 @@ app_file: app.py
8
  pinned: false
9
  ---
10
 
11
- # enhance-me
 
 
 
 
 
 
 
 
 
 
 
8
  pinned: false
9
  ---
10
 
11
+ # Enhance Me
12
+
13
+ A unified platform for image enhancement.
14
+
15
+ ## Usage
16
+
17
+ ### Train using Docker
18
+
19
+ - Build image using `docker build -t enhance-image .`
20
+
21
+ - Run notebook using `docker run -it --gpus all -p 8888:8888 -v $(pwd):/usr/src/enhance-me enhance-image`
enhance_me/zero_dce/zero_dce.py CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
5
 
6
  import tensorflow as tf
7
  from tensorflow import keras
8
- from tensorflow.keras import optimizers, Model
9
  from wandb.keras import WandbCallback
10
 
11
  from .dce_net import build_dce_net
@@ -20,9 +20,12 @@ from ..commons import download_lol_dataset, init_wandb
20
 
21
 
22
  class ZeroDCE(Model):
23
- def __init__(self, experiment_name=None, wandb_api_key=None, **kwargs):
24
  super(ZeroDCE, self).__init__(**kwargs)
25
  self.experiment_name = experiment_name
 
 
 
26
  if wandb_api_key is not None:
27
  init_wandb("zero-dce", experiment_name, wandb_api_key)
28
  self.using_wandb = True
 
5
 
6
  import tensorflow as tf
7
  from tensorflow import keras
8
+ from tensorflow.keras import optimizers, mixed_precision, Model
9
  from wandb.keras import WandbCallback
10
 
11
  from .dce_net import build_dce_net
 
20
 
21
 
22
  class ZeroDCE(Model):
23
+ def __init__(self, experiment_name=None, wandb_api_key=None, use_mixed_precision: bool = False, **kwargs):
24
  super(ZeroDCE, self).__init__(**kwargs)
25
  self.experiment_name = experiment_name
26
+ if use_mixed_precision:
27
+ policy = mixed_precision.Policy('mixed_float16')
28
+ mixed_precision.set_global_policy(policy)
29
  if wandb_api_key is not None:
30
  init_wandb("zero-dce", experiment_name, wandb_api_key)
31
  self.using_wandb = True
notebooks/enhance_me_train.ipynb CHANGED
@@ -187,15 +187,16 @@
187
  "source": [
188
  "# @title Zero-DCE Train Configs\n",
189
  "\n",
190
- "experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
191
  "image_size = 256 # @param {type:\"integer\"}\n",
192
  "dataset_label = \"lol\" # @param [\"lol\"]\n",
193
- "apply_resize = False # @param {type:\"boolean\"}\n",
 
194
  "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
195
  "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
196
  "apply_random_rotation = True # @param {type:\"boolean\"}\n",
197
  "use_mixed_precision = False # @param {type:\"boolean\"}\n",
198
- "wandb_api_key = \"\" # @param {type:\"string\"}\n",
199
  "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
200
  "batch_size = 16 # @param {type:\"integer\"}\n",
201
  "learning_rate = 1e-4 # @param {type:\"number\"}\n",
@@ -211,7 +212,8 @@
211
  "source": [
212
  "zero_dce = ZeroDCE(\n",
213
  " experiment_name=experiment_name,\n",
214
- " wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key\n",
 
215
  ")"
216
  ]
217
  },
@@ -263,6 +265,13 @@
263
  " (18, 18),\n",
264
  " )"
265
  ]
 
 
 
 
 
 
 
266
  }
267
  ],
268
  "metadata": {
 
187
  "source": [
188
  "# @title Zero-DCE Train Configs\n",
189
  "\n",
190
+ "experiment_name = \"lol_dataset_256_resize\" # @param {type:\"string\"}\n",
191
  "image_size = 256 # @param {type:\"integer\"}\n",
192
  "dataset_label = \"lol\" # @param [\"lol\"]\n",
193
+ "use_mixed_precision = False # @param {type:\"boolean\"}\n",
194
+ "apply_resize = True # @param {type:\"boolean\"}\n",
195
  "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
196
  "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
197
  "apply_random_rotation = True # @param {type:\"boolean\"}\n",
198
  "use_mixed_precision = False # @param {type:\"boolean\"}\n",
199
+ "wandb_api_key = \"8d7149fe07496df2aaab8e9856a6ed8564e2a644\" # @param {type:\"string\"}\n",
200
  "val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
201
  "batch_size = 16 # @param {type:\"integer\"}\n",
202
  "learning_rate = 1e-4 # @param {type:\"number\"}\n",
 
212
  "source": [
213
  "zero_dce = ZeroDCE(\n",
214
  " experiment_name=experiment_name,\n",
215
+ " wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key,\n",
216
+ " use_mixed_precision=use_mixed_precision\n",
217
  ")"
218
  ]
219
  },
 
265
  " (18, 18),\n",
266
  " )"
267
  ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": []
275
  }
276
  ],
277
  "metadata": {