Spaces:
Runtime error
Runtime error
geekyrakshit
commited on
Commit
•
192c48a
1
Parent(s):
c8d52e7
added download function for lol dataset
Browse files- .gitignore +3 -0
- enhance_me/commons.py +22 -2
- enhance_me/mirnet/mirnet.py +9 -1
.gitignore
CHANGED
@@ -127,3 +127,6 @@ dmypy.json
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
|
|
|
|
|
|
|
127 |
|
128 |
# Pyre type checker
|
129 |
.pyre/
|
130 |
+
|
131 |
+
# Datasets
|
132 |
+
datasets/
|
enhance_me/commons.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
import os
|
2 |
import wandb
|
3 |
-
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
def read_image(image_path):
|
8 |
image = tf.io.read_file(image_path)
|
@@ -39,5 +42,22 @@ def closest_number(n, m):
|
|
39 |
|
40 |
def init_wandb(project_name, experiment_name, wandb_api_key):
|
41 |
if project_name is not None and experiment_name is not None:
|
42 |
-
os.environ[
|
43 |
wandb.init(project=project_name, name=experiment_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import wandb
|
3 |
+
from glob import glob
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
6 |
+
import tensorflow as tf
|
7 |
+
from tensorflow.keras import utils
|
8 |
+
|
9 |
|
10 |
def read_image(image_path):
|
11 |
image = tf.io.read_file(image_path)
|
|
|
42 |
|
43 |
def init_wandb(project_name, experiment_name, wandb_api_key):
|
44 |
if project_name is not None and experiment_name is not None:
|
45 |
+
os.environ["WANDB_API_KEY"] = wandb_api_key
|
46 |
wandb.init(project=project_name, name=experiment_name)
|
47 |
+
|
48 |
+
|
49 |
+
def download_lol_dataset():
|
50 |
+
utils.get_file(
|
51 |
+
"lol_dataset.zip",
|
52 |
+
"https://github.com/soumik12345/enhance-me/releases/download/v0.1/lol_dataset.zip",
|
53 |
+
cache_dir="./",
|
54 |
+
cache_subdir="./datasets",
|
55 |
+
extract=True,
|
56 |
+
)
|
57 |
+
low_images = sorted(glob("./datasets/lol_dataset/our485/low/*"))
|
58 |
+
enhanced_images = sorted(glob("./datasets/lol_dataset/our485/high/*"))
|
59 |
+
assert len(low_images) == len(enhanced_images)
|
60 |
+
test_low_images = sorted(glob("./datasets/lol_dataset/eval15/low/*"))
|
61 |
+
test_enhanced_images = sorted(glob("./datasets/lol_dataset/eval15/high/*"))
|
62 |
+
assert len(test_low_images) == len(test_enhanced_images)
|
63 |
+
return (low_images, enhanced_images), (test_low_images, test_enhanced_images)
|
enhance_me/mirnet/mirnet.py
CHANGED
@@ -12,7 +12,12 @@ from wandb.keras import WandbCallback
|
|
12 |
from .dataloader import LowLightDataset
|
13 |
from .models import build_mirnet_model
|
14 |
from .losses import CharbonnierLoss
|
15 |
-
from ..commons import
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
class MIRNet:
|
@@ -20,12 +25,15 @@ class MIRNet:
|
|
20 |
self,
|
21 |
experiment_name: str,
|
22 |
image_size: int = 256,
|
|
|
23 |
apply_random_horizontal_flip: bool = True,
|
24 |
apply_random_vertical_flip: bool = True,
|
25 |
apply_random_rotation: bool = True,
|
26 |
wandb_api_key=None,
|
27 |
) -> None:
|
28 |
self.experiment_name = experiment_name
|
|
|
|
|
29 |
self.data_loader = LowLightDataset(
|
30 |
image_size=image_size,
|
31 |
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
|
|
12 |
from .dataloader import LowLightDataset
|
13 |
from .models import build_mirnet_model
|
14 |
from .losses import CharbonnierLoss
|
15 |
+
from ..commons import (
|
16 |
+
peak_signal_noise_ratio,
|
17 |
+
closest_number,
|
18 |
+
init_wandb,
|
19 |
+
download_lol_dataset,
|
20 |
+
)
|
21 |
|
22 |
|
23 |
class MIRNet:
|
|
|
25 |
self,
|
26 |
experiment_name: str,
|
27 |
image_size: int = 256,
|
28 |
+
dataset_label: str = "lol",
|
29 |
apply_random_horizontal_flip: bool = True,
|
30 |
apply_random_vertical_flip: bool = True,
|
31 |
apply_random_rotation: bool = True,
|
32 |
wandb_api_key=None,
|
33 |
) -> None:
|
34 |
self.experiment_name = experiment_name
|
35 |
+
if dataset_label == "lol":
|
36 |
+
download_lol_dataset()
|
37 |
self.data_loader = LowLightDataset(
|
38 |
image_size=image_size,
|
39 |
apply_random_horizontal_flip=apply_random_horizontal_flip,
|