init
Browse files- app.py +19 -5
- ugatit_test.py +6 -6
app.py
CHANGED
@@ -15,7 +15,7 @@ import numpy as np
|
|
15 |
import PIL.Image
|
16 |
|
17 |
from io import BytesIO
|
18 |
-
|
19 |
|
20 |
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
|
21 |
TITLE = 'taki0112/UGATIT'
|
@@ -26,6 +26,9 @@ ARTICLE = """
|
|
26 |
|
27 |
"""
|
28 |
|
|
|
|
|
|
|
29 |
def parse_args() -> argparse.Namespace:
|
30 |
parser = argparse.ArgumentParser()
|
31 |
parser.add_argument('--device', type=str, default='cpu')
|
@@ -41,13 +44,22 @@ def parse_args() -> argparse.Namespace:
|
|
41 |
return parser.parse_args()
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
def run(
|
46 |
-
image
|
|
|
47 |
) -> tuple[PIL.Image.Image]:
|
48 |
-
|
49 |
|
50 |
-
|
|
|
|
|
51 |
|
52 |
|
53 |
def main():
|
@@ -55,7 +67,9 @@ def main():
|
|
55 |
|
56 |
args = parse_args()
|
57 |
|
58 |
-
|
|
|
|
|
59 |
func = functools.update_wrapper(func, run)
|
60 |
|
61 |
|
|
|
15 |
import PIL.Image
|
16 |
|
17 |
from io import BytesIO
|
18 |
+
import ugatit_test
|
19 |
|
20 |
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
|
21 |
TITLE = 'taki0112/UGATIT'
|
|
|
26 |
|
27 |
"""
|
28 |
|
29 |
+
|
30 |
+
MODEL_REPO = 'hylee/UGATIT_model'
|
31 |
+
|
32 |
def parse_args() -> argparse.Namespace:
|
33 |
parser = argparse.ArgumentParser()
|
34 |
parser.add_argument('--device', type=str, default='cpu')
|
|
|
44 |
return parser.parse_args()
|
45 |
|
46 |
|
47 |
+
def load_checkpoint():
|
48 |
+
checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
49 |
+
'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/checkpoint',
|
50 |
+
cache_dir='UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing')
|
51 |
+
print(checkpoint_path)
|
52 |
+
return 'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing'
|
53 |
+
|
54 |
|
55 |
def run(
|
56 |
+
image,
|
57 |
+
checkpoint_dir: str,
|
58 |
) -> tuple[PIL.Image.Image]:
|
|
|
59 |
|
60 |
+
result = ugatit_test.main_test(image.name, checkpoint_dir)
|
61 |
+
|
62 |
+
return PIL.Image.open(result)
|
63 |
|
64 |
|
65 |
def main():
|
|
|
67 |
|
68 |
args = parse_args()
|
69 |
|
70 |
+
checkpoint_dir = load_checkpoint()
|
71 |
+
|
72 |
+
func = functools.partial(run, checkpoint_dir=checkpoint_dir)
|
73 |
func = functools.update_wrapper(func, run)
|
74 |
|
75 |
|
ugatit_test.py
CHANGED
@@ -8,7 +8,7 @@ from ugatit.utils import *
|
|
8 |
|
9 |
class UgatitTest:
|
10 |
|
11 |
-
def __init__(self, sess):
|
12 |
self.light = False
|
13 |
|
14 |
if self.light:
|
@@ -18,7 +18,7 @@ class UgatitTest:
|
|
18 |
|
19 |
self.sess = sess
|
20 |
self.phase = 'test'
|
21 |
-
self.checkpoint_dir =
|
22 |
self.result_dir = 'results'
|
23 |
self.log_dir = 'logs'
|
24 |
self.dataset_name = 'selfie2anime'
|
@@ -57,8 +57,8 @@ class UgatitTest:
|
|
57 |
self.img_size = 256
|
58 |
self.img_ch = 3
|
59 |
|
60 |
-
self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
|
61 |
-
check_folder(self.sample_dir)
|
62 |
|
63 |
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
64 |
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
@@ -350,12 +350,12 @@ class UgatitTest:
|
|
350 |
|
351 |
|
352 |
gan = None
|
353 |
-
def main_test(img_path):
|
354 |
# open session
|
355 |
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
356 |
global gan
|
357 |
if gan is None:
|
358 |
-
gan = UgatitTest(sess)
|
359 |
# build graph
|
360 |
gan.build_model()
|
361 |
# show network architecture
|
|
|
8 |
|
9 |
class UgatitTest:
|
10 |
|
11 |
+
def __init__(self, sess, checkpoint_dir):
|
12 |
self.light = False
|
13 |
|
14 |
if self.light:
|
|
|
18 |
|
19 |
self.sess = sess
|
20 |
self.phase = 'test'
|
21 |
+
self.checkpoint_dir = checkpoint_dir
|
22 |
self.result_dir = 'results'
|
23 |
self.log_dir = 'logs'
|
24 |
self.dataset_name = 'selfie2anime'
|
|
|
57 |
self.img_size = 256
|
58 |
self.img_ch = 3
|
59 |
|
60 |
+
#self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
|
61 |
+
#check_folder(self.sample_dir)
|
62 |
|
63 |
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
64 |
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
|
|
350 |
|
351 |
|
352 |
gan = None
|
353 |
+
def main_test(img_path, checkpoint_dir):
|
354 |
# open session
|
355 |
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
356 |
global gan
|
357 |
if gan is None:
|
358 |
+
gan = UgatitTest(sess, checkpoint_dir)
|
359 |
# build graph
|
360 |
gan.build_model()
|
361 |
# show network architecture
|