hysts HF staff commited on
Commit
009c38e
1 Parent(s): 4fce96a
Files changed (5) hide show
  1. .gitmodules +3 -0
  2. DualStyleGAN +1 -0
  3. app.py +351 -0
  4. packages.txt +2 -0
  5. requirements.txt +7 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "DualStyleGAN"]
2
+ path = DualStyleGAN
3
+ url = https://github.com/williamyang1991/DualStyleGAN
DualStyleGAN ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 64285b179d0929e301a97c2f2c438546ff49e20d
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import os
7
+ import sys
8
+ from typing import Callable
9
+
10
+ import dlib
11
+ import gradio as gr
12
+ import huggingface_hub
13
+ import numpy as np
14
+ import PIL.Image
15
+ import torch
16
+ import torch.nn as nn
17
+ import torchvision.transforms as T
18
+
19
+ if os.environ.get('SYSTEM') == 'spaces':
20
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
21
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
22
+
23
+ sys.path.insert(0, 'DualStyleGAN')
24
+
25
+ from model.dualstylegan import DualStyleGAN
26
+ from model.encoder.align_all_parallel import align_face
27
+ from model.encoder.psp import pSp
28
+
29
+ STYLE_IMAGE_PATHS = {
30
+ 'cartoon':
31
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/cartoon_overview.jpg',
32
+ 'caricature':
33
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/caricature_overview.jpg',
34
+ 'anime':
35
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/anime_overview.jpg',
36
+ 'arcane':
37
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_arcane_overview.jpg',
38
+ 'comic':
39
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_comic_overview.jpg',
40
+ 'pixar':
41
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_pixar_overview.jpg',
42
+ 'slamdunk':
43
+ 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_slamdunk_overview.jpg',
44
+ }
45
+
46
+ TOKEN = os.environ['TOKEN']
47
+ MODEL_REPO = 'hysts/DualStyleGAN'
48
+
49
+
50
+ def parse_args() -> argparse.Namespace:
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--device', type=str, default='cpu')
53
+ parser.add_argument('--theme', type=str)
54
+ parser.add_argument('--live', action='store_true')
55
+ parser.add_argument('--share', action='store_true')
56
+ parser.add_argument('--port', type=int)
57
+ parser.add_argument('--disable-queue',
58
+ dest='enable_queue',
59
+ action='store_false')
60
+ parser.add_argument('--allow-flagging', type=str, default='never')
61
+ return parser.parse_args()
62
+
63
+
64
+ class App:
65
+
66
+ def __init__(self, device: torch.device):
67
+ self.device = device
68
+ self.face_detector = self._create_dlib_landmark_model()
69
+ self.encoder = self._load_encoder()
70
+ self.transform = self._create_transform()
71
+
72
+ self.style_types = [
73
+ 'cartoon',
74
+ 'caricature',
75
+ 'anime',
76
+ 'arcane',
77
+ 'comic',
78
+ 'pixar',
79
+ 'slamdunk',
80
+ ]
81
+ self.generator_dict = {
82
+ style_type: self._load_generator(style_type)
83
+ for style_type in self.style_types
84
+ }
85
+ self.exstyle_dict = {
86
+ style_type: self._load_exstylecode(style_type)
87
+ for style_type in self.style_types
88
+ }
89
+
90
+ @staticmethod
91
+ def _create_dlib_landmark_model():
92
+ path = huggingface_hub.hf_hub_download(
93
+ 'hysts/dlib_face_landmark_model',
94
+ 'shape_predictor_68_face_landmarks.dat',
95
+ use_auth_token=TOKEN)
96
+ return dlib.shape_predictor(path)
97
+
98
+ def _load_encoder(self) -> nn.Module:
99
+ ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
100
+ 'models/encoder.pt',
101
+ use_auth_token=TOKEN)
102
+ ckpt = torch.load(ckpt_path, map_location='cpu')
103
+ opts = ckpt['opts']
104
+ opts['device'] = self.device.type
105
+ opts['checkpoint_path'] = ckpt_path
106
+ opts = argparse.Namespace(**opts)
107
+ model = pSp(opts)
108
+ model.to(self.device)
109
+ model.eval()
110
+ return model
111
+
112
+ @staticmethod
113
+ def _create_transform() -> Callable:
114
+ transform = T.Compose([
115
+ T.Resize(256),
116
+ T.CenterCrop(256),
117
+ T.ToTensor(),
118
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
119
+ ])
120
+ return transform
121
+
122
+ def _load_generator(self, style_type: str) -> nn.Module:
123
+ model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
124
+ ckpt_path = huggingface_hub.hf_hub_download(
125
+ MODEL_REPO,
126
+ f'models/{style_type}/generator.pt',
127
+ use_auth_token=TOKEN)
128
+ ckpt = torch.load(ckpt_path, map_location='cpu')
129
+ model.load_state_dict(ckpt['g_ema'])
130
+ model.to(self.device)
131
+ model.eval()
132
+ return model
133
+
134
+ @staticmethod
135
+ def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
136
+ if style_type in ['cartoon', 'caricature', 'anime']:
137
+ filename = 'refined_exstyle_code.npy'
138
+ else:
139
+ filename = 'exstyle_code.npy'
140
+ path = huggingface_hub.hf_hub_download(
141
+ MODEL_REPO,
142
+ f'models/{style_type}/{filename}',
143
+ use_auth_token=TOKEN)
144
+ exstyles = np.load(path, allow_pickle=True).item()
145
+ return exstyles
146
+
147
+ def detect_and_align_face(self, image) -> np.ndarray:
148
+ image = align_face(filepath=image.name, predictor=self.face_detector)
149
+ return image
150
+
151
+ @staticmethod
152
+ def denormalize(tensor: torch.Tensor) -> torch.Tensor:
153
+ return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
154
+
155
+ def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
156
+ tensor = self.denormalize(tensor)
157
+ return tensor.cpu().numpy().transpose(1, 2, 0)
158
+
159
+ @torch.inference_mode()
160
+ def reconstruct_face(self,
161
+ image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
162
+ image = PIL.Image.fromarray(image)
163
+ input_data = self.transform(image).unsqueeze(0).to(self.device)
164
+ img_rec, instyle = self.encoder(input_data,
165
+ randomize_noise=False,
166
+ return_latents=True,
167
+ z_plus_latent=True,
168
+ return_z_plus_latent=True,
169
+ resize=False)
170
+ img_rec = torch.clamp(img_rec.detach(), -1, 1)
171
+ img_rec = self.postprocess(img_rec[0])
172
+ return img_rec, instyle
173
+
174
+ @torch.inference_mode()
175
+ def generate(self, style_type: str, style_id: int, structure_weight: float,
176
+ color_weight: float, structure_only: bool,
177
+ instyle: torch.Tensor) -> np.ndarray:
178
+ generator = self.generator_dict[style_type]
179
+ exstyles = self.exstyle_dict[style_type]
180
+
181
+ style_id = int(style_id)
182
+ stylename = list(exstyles.keys())[style_id]
183
+
184
+ latent = torch.tensor(exstyles[stylename]).to(self.device)
185
+ if structure_only:
186
+ latent[0, 7:18] = instyle[0, 7:18]
187
+ exstyle = generator.generator.style(
188
+ latent.reshape(latent.shape[0] * latent.shape[1],
189
+ latent.shape[2])).reshape(latent.shape)
190
+
191
+ img_gen, _ = generator([instyle],
192
+ exstyle,
193
+ z_plus_latent=True,
194
+ truncation=0.7,
195
+ truncation_latent=0,
196
+ use_res=True,
197
+ interp_weights=[structure_weight] * 7 +
198
+ [color_weight] * 11)
199
+ img_gen = torch.clamp(img_gen.detach(), -1, 1)
200
+ img_gen = self.postprocess(img_gen[0])
201
+ return img_gen
202
+
203
+
204
+ def update_slider(choice: str):
205
+ max_vals = {
206
+ 'cartoon': 316,
207
+ 'caricature': 198,
208
+ 'anime': 173,
209
+ 'arcane': 99,
210
+ 'comic': 100,
211
+ 'pixar': 121,
212
+ 'slamdunk': 119,
213
+ }
214
+ return gr.Slider.update(maximum=max_vals[choice] + 1, value=26)
215
+
216
+
217
+ def update_style_image(choice: str):
218
+ style_image_path = STYLE_IMAGE_PATHS[choice]
219
+ text = f'<center><img src="{style_image_path}" alt="style image" width="800" height="400"></center>'
220
+ return gr.Markdown.update(value=text)
221
+
222
+
223
+ def main():
224
+ args = parse_args()
225
+ app = App(device=torch.device(args.device))
226
+
227
+ with gr.Blocks(theme=args.theme) as demo:
228
+ gr.Markdown(
229
+ '''<center><h1>Portrait Style Transfer with DualStyleGAN</h1></center>
230
+
231
+ This is an unofficial demo app for https://github.com/williamyang1991/DualStyleGAN.
232
+
233
+ <center><img src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" alt="overview" width="800" height="400"></center>
234
+
235
+ Related App: https://huggingface.co/spaces/hysts/DualStyleGAN
236
+ ''')
237
+
238
+ with gr.Box():
239
+ gr.Markdown('''## Step 1
240
+
241
+ - Drop an image containing a near-frontal face to the **Input Image**.
242
+ - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
243
+ - Hit the **Detect & Align** button.
244
+ - Hit the **Reconstruct Face** button.
245
+ - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
246
+ ''')
247
+ with gr.Row():
248
+ with gr.Column():
249
+ with gr.Row():
250
+ input_image = gr.Image(label='Input Image',
251
+ type='file')
252
+ with gr.Row():
253
+ detect_button = gr.Button('Detect & Align Face')
254
+ with gr.Column():
255
+ with gr.Row():
256
+ face_image = gr.Image(label='Aligned Face',
257
+ type='numpy')
258
+ with gr.Row():
259
+ reconstruct_button = gr.Button('Reconstruct Face')
260
+ with gr.Column():
261
+ reconstructed_face = gr.Image(label='Reconstructed Face',
262
+ type='numpy')
263
+ instyle = gr.Variable()
264
+
265
+ with gr.Box():
266
+ gr.Markdown('''## Step 2
267
+
268
+ - Select **Style Type**.
269
+ - Select **Style Image Index** from the image table below.
270
+ ''')
271
+ with gr.Row():
272
+ with gr.Column():
273
+ with gr.Column():
274
+ style_type = gr.Radio(app.style_types,
275
+ label='Style Type')
276
+ with gr.Column():
277
+ style_index = gr.Slider(0,
278
+ 317,
279
+ value=26,
280
+ step=1,
281
+ label='Style Image Index',
282
+ interactive=True)
283
+ style_image_path = STYLE_IMAGE_PATHS['cartoon']
284
+ text = f'<center><img src="{style_image_path}" alt="style image" width="800" height="400"></center>'
285
+ style_image = gr.Markdown(value=text)
286
+
287
+ with gr.Box():
288
+ gr.Markdown('''## Step 3
289
+
290
+ - Adjust **Structure Weight** and **Color Weight**.
291
+ - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
292
+ - Hit the **Generate** button.
293
+ ''')
294
+ with gr.Row():
295
+ with gr.Column():
296
+ with gr.Row():
297
+ structure_weight = gr.Slider(0,
298
+ 1,
299
+ value=0.6,
300
+ step=0.1,
301
+ label='Structure Weight')
302
+ with gr.Row():
303
+ color_weight = gr.Slider(0,
304
+ 1,
305
+ value=1,
306
+ step=0.1,
307
+ label='Color Weight')
308
+ with gr.Row():
309
+ structure_only = gr.Checkbox(label='Structure Only')
310
+ with gr.Row():
311
+ generate_button = gr.Button('Generate')
312
+
313
+ with gr.Column():
314
+ output_image = gr.Image(label='Output Image')
315
+
316
+ gr.Markdown(
317
+ '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" alt="visitor badge"/></center>'
318
+ )
319
+
320
+ detect_button.click(fn=app.detect_and_align_face,
321
+ inputs=input_image,
322
+ outputs=face_image)
323
+ reconstruct_button.click(fn=app.reconstruct_face,
324
+ inputs=face_image,
325
+ outputs=[reconstructed_face, instyle])
326
+ style_type.change(fn=update_slider,
327
+ inputs=style_type,
328
+ outputs=style_index)
329
+ style_type.change(fn=update_style_image,
330
+ inputs=style_type,
331
+ outputs=style_image)
332
+ generate_button.click(fn=app.generate,
333
+ inputs=[
334
+ style_type,
335
+ style_index,
336
+ structure_weight,
337
+ color_weight,
338
+ structure_only,
339
+ instyle,
340
+ ],
341
+ outputs=output_image)
342
+
343
+ demo.launch(
344
+ enable_queue=args.enable_queue,
345
+ server_port=args.port,
346
+ share=args.share,
347
+ )
348
+
349
+
350
+ if __name__ == '__main__':
351
+ main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cmake
2
+ ninja-build
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dlib==19.23.0
2
+ numpy==1.22.3
3
+ opencv-python-headless==4.5.5.62
4
+ Pillow==9.0.1
5
+ scipy==1.8.0
6
+ torch==1.11.0
7
+ torchvision==0.12.0