hylee commited on
Commit
1b9c487
1 Parent(s): c0c7d9b
Files changed (2) hide show
  1. app.py +19 -5
  2. ugatit_test.py +6 -6
app.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
  import PIL.Image
16
 
17
  from io import BytesIO
18
-
19
 
20
  ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
21
  TITLE = 'taki0112/UGATIT'
@@ -26,6 +26,9 @@ ARTICLE = """
26
 
27
  """
28
 
 
 
 
29
  def parse_args() -> argparse.Namespace:
30
  parser = argparse.ArgumentParser()
31
  parser.add_argument('--device', type=str, default='cpu')
@@ -41,13 +44,22 @@ def parse_args() -> argparse.Namespace:
41
  return parser.parse_args()
42
 
43
 
 
 
 
 
 
 
 
44
 
45
  def run(
46
- image
 
47
  ) -> tuple[PIL.Image.Image]:
48
-
49
 
50
- return PIL.Image.open(image.name)
 
 
51
 
52
 
53
  def main():
@@ -55,7 +67,9 @@ def main():
55
 
56
  args = parse_args()
57
 
58
- func = functools.partial(run)
 
 
59
  func = functools.update_wrapper(func, run)
60
 
61
 
 
15
  import PIL.Image
16
 
17
  from io import BytesIO
18
+ import ugatit_test
19
 
20
  ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
21
  TITLE = 'taki0112/UGATIT'
 
26
 
27
  """
28
 
29
+
30
+ MODEL_REPO = 'hylee/UGATIT_model'
31
+
32
  def parse_args() -> argparse.Namespace:
33
  parser = argparse.ArgumentParser()
34
  parser.add_argument('--device', type=str, default='cpu')
 
44
  return parser.parse_args()
45
 
46
 
47
+ def load_checkpoint():
48
+ checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
49
+ 'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/checkpoint',
50
+ cache_dir='UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing')
51
+ print(checkpoint_path)
52
+ return 'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing'
53
+
54
 
55
  def run(
56
+ image,
57
+ checkpoint_dir: str,
58
  ) -> tuple[PIL.Image.Image]:
 
59
 
60
+ result = ugatit_test.main_test(image.name, checkpoint_dir)
61
+
62
+ return PIL.Image.open(result)
63
 
64
 
65
  def main():
 
67
 
68
  args = parse_args()
69
 
70
+ checkpoint_dir = load_checkpoint()
71
+
72
+ func = functools.partial(run, checkpoint_dir=checkpoint_dir)
73
  func = functools.update_wrapper(func, run)
74
 
75
 
ugatit_test.py CHANGED
@@ -8,7 +8,7 @@ from ugatit.utils import *
8
 
9
  class UgatitTest:
10
 
11
- def __init__(self, sess):
12
  self.light = False
13
 
14
  if self.light:
@@ -18,7 +18,7 @@ class UgatitTest:
18
 
19
  self.sess = sess
20
  self.phase = 'test'
21
- self.checkpoint_dir = '/home/hylee/cartoon/UGATIT/checkpoint'
22
  self.result_dir = 'results'
23
  self.log_dir = 'logs'
24
  self.dataset_name = 'selfie2anime'
@@ -57,8 +57,8 @@ class UgatitTest:
57
  self.img_size = 256
58
  self.img_ch = 3
59
 
60
- self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
61
- check_folder(self.sample_dir)
62
 
63
  # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
64
  self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
@@ -350,12 +350,12 @@ class UgatitTest:
350
 
351
 
352
  gan = None
353
- def main_test(img_path):
354
  # open session
355
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
356
  global gan
357
  if gan is None:
358
- gan = UgatitTest(sess)
359
  # build graph
360
  gan.build_model()
361
  # show network architecture
 
8
 
9
  class UgatitTest:
10
 
11
+ def __init__(self, sess, checkpoint_dir):
12
  self.light = False
13
 
14
  if self.light:
 
18
 
19
  self.sess = sess
20
  self.phase = 'test'
21
+ self.checkpoint_dir = checkpoint_dir
22
  self.result_dir = 'results'
23
  self.log_dir = 'logs'
24
  self.dataset_name = 'selfie2anime'
 
57
  self.img_size = 256
58
  self.img_ch = 3
59
 
60
+ #self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
61
+ #check_folder(self.sample_dir)
62
 
63
  # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
64
  self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
 
350
 
351
 
352
  gan = None
353
+ def main_test(img_path, checkpoint_dir):
354
  # open session
355
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
356
  global gan
357
  if gan is None:
358
+ gan = UgatitTest(sess, checkpoint_dir)
359
  # build graph
360
  gan.build_model()
361
  # show network architecture