Spaces:
Running
Running
Update
Browse files
app.py
CHANGED
@@ -7,7 +7,6 @@ import functools
|
|
7 |
import os
|
8 |
import pathlib
|
9 |
import sys
|
10 |
-
import tarfile
|
11 |
from typing import Callable
|
12 |
|
13 |
if os.environ.get('SYSTEM') == 'spaces':
|
@@ -29,6 +28,24 @@ from model.encoder.align_all_parallel import align_face
|
|
29 |
from model.encoder.psp import pSp
|
30 |
from util import load_image, visualize
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
TOKEN = os.environ['TOKEN']
|
33 |
|
34 |
MODEL_REPO = 'hysts/DualStyleGAN'
|
@@ -49,17 +66,6 @@ def parse_args() -> argparse.Namespace:
|
|
49 |
return parser.parse_args()
|
50 |
|
51 |
|
52 |
-
def download_cartoon_images() -> None:
|
53 |
-
image_dir = pathlib.Path('cartoon')
|
54 |
-
if not image_dir.exists():
|
55 |
-
path = huggingface_hub.hf_hub_download('hysts/DualStyleGAN-Cartoon',
|
56 |
-
'cartoon.tar.gz',
|
57 |
-
repo_type='dataset',
|
58 |
-
use_auth_token=TOKEN)
|
59 |
-
with tarfile.open(path) as f:
|
60 |
-
f.extractall()
|
61 |
-
|
62 |
-
|
63 |
def load_encoder(device: torch.device) -> nn.Module:
|
64 |
ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
65 |
'models/encoder.pt',
|
@@ -188,13 +194,7 @@ def run(
|
|
188 |
img_gen1 = postprocess(img_gen[1])
|
189 |
img_gen2 = postprocess(img_gen2[0])
|
190 |
|
191 |
-
|
192 |
-
style_image_dir = pathlib.Path(style_type)
|
193 |
-
style_image = PIL.Image.open(style_image_dir / stylename)
|
194 |
-
except Exception:
|
195 |
-
style_image = None
|
196 |
-
|
197 |
-
return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
|
198 |
|
199 |
|
200 |
def main():
|
@@ -221,7 +221,6 @@ def main():
|
|
221 |
for style_type in style_types
|
222 |
}
|
223 |
|
224 |
-
download_cartoon_images()
|
225 |
dlib_landmark_model = create_dlib_landmark_model()
|
226 |
encoder = load_encoder(device)
|
227 |
transform = create_transform()
|
@@ -235,14 +234,6 @@ def main():
|
|
235 |
device=device)
|
236 |
func = functools.update_wrapper(func, run)
|
237 |
|
238 |
-
repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
|
239 |
-
title = 'williamyang1991/DualStyleGAN'
|
240 |
-
description = f"""A demo for {repo_url}
|
241 |
-
|
242 |
-
You can select style images for cartoon from the table below.
|
243 |
-
"""
|
244 |
-
article = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
|
245 |
-
|
246 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
247 |
examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
|
248 |
|
@@ -260,7 +251,6 @@ def main():
|
|
260 |
],
|
261 |
[
|
262 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
263 |
-
gr.outputs.Image(type='pil', label='Selected Style Image'),
|
264 |
gr.outputs.Image(type='pil', label='Reconstructed'),
|
265 |
gr.outputs.Image(type='pil', label='Result 1'),
|
266 |
gr.outputs.Image(type='pil', label='Result 2'),
|
@@ -268,9 +258,9 @@ def main():
|
|
268 |
],
|
269 |
examples=examples,
|
270 |
theme=args.theme,
|
271 |
-
title=
|
272 |
-
description=
|
273 |
-
article=
|
274 |
allow_screenshot=args.allow_screenshot,
|
275 |
allow_flagging=args.allow_flagging,
|
276 |
live=args.live,
|
|
|
7 |
import os
|
8 |
import pathlib
|
9 |
import sys
|
|
|
10 |
from typing import Callable
|
11 |
|
12 |
if os.environ.get('SYSTEM') == 'spaces':
|
|
|
28 |
from model.encoder.psp import pSp
|
29 |
from util import load_image, visualize
|
30 |
|
31 |
+
ORIGINAL_REPO_URL = 'https://github.com/williamyang1991/DualStyleGAN'
|
32 |
+
TITLE = 'williamyang1991/DualStyleGAN'
|
33 |
+
DESCRIPTION = f"""A demo for {ORIGINAL_REPO_URL}
|
34 |
+
|
35 |
+
You can select style images for cartoon from the table below.
|
36 |
+
|
37 |
+
The style image index should be in the following range:
|
38 |
+
|
39 |
+
- cartoon: 0-316
|
40 |
+
- caricature: 0-198
|
41 |
+
- anime: 0-173
|
42 |
+
- arcane: 0-99
|
43 |
+
- comic: 0-100
|
44 |
+
- pixar: 0-121
|
45 |
+
- slamdunk: 0-119
|
46 |
+
"""
|
47 |
+
ARTICLE = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
|
48 |
+
|
49 |
TOKEN = os.environ['TOKEN']
|
50 |
|
51 |
MODEL_REPO = 'hysts/DualStyleGAN'
|
|
|
66 |
return parser.parse_args()
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
def load_encoder(device: torch.device) -> nn.Module:
|
70 |
ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
71 |
'models/encoder.pt',
|
|
|
194 |
img_gen1 = postprocess(img_gen[1])
|
195 |
img_gen2 = postprocess(img_gen2[0])
|
196 |
|
197 |
+
return image, img_rec, img_gen0, img_gen1, img_gen2
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
|
200 |
def main():
|
|
|
221 |
for style_type in style_types
|
222 |
}
|
223 |
|
|
|
224 |
dlib_landmark_model = create_dlib_landmark_model()
|
225 |
encoder = load_encoder(device)
|
226 |
transform = create_transform()
|
|
|
234 |
device=device)
|
235 |
func = functools.update_wrapper(func, run)
|
236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
238 |
examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
|
239 |
|
|
|
251 |
],
|
252 |
[
|
253 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
|
|
254 |
gr.outputs.Image(type='pil', label='Reconstructed'),
|
255 |
gr.outputs.Image(type='pil', label='Result 1'),
|
256 |
gr.outputs.Image(type='pil', label='Result 2'),
|
|
|
258 |
],
|
259 |
examples=examples,
|
260 |
theme=args.theme,
|
261 |
+
title=TITLE,
|
262 |
+
description=DESCRIPTION,
|
263 |
+
article=ARTICLE,
|
264 |
allow_screenshot=args.allow_screenshot,
|
265 |
allow_flagging=args.allow_flagging,
|
266 |
live=args.live,
|