hysts HF staff commited on
Commit
cb57f88
β€’
1 Parent(s): 7eb697d
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +41 -16
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: πŸ“š
4
  colorFrom: pink
5
  colorTo: indigo
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  ---
 
4
  colorFrom: pink
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.0.2
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,26 +2,42 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import pickle
8
  import sys
9
 
10
- sys.path.insert(0, 'stylegan3')
11
-
12
  import gradio as gr
13
  import numpy as np
14
- import PIL.Image
15
  import torch
 
16
  from huggingface_hub import hf_hub_download
17
 
 
 
 
 
 
 
18
  MODEL_REPO = 'hysts/stylegan3-anime-face-exp002-model'
19
  MODEL_FILE_NAME = '009000.pkl'
 
20
  TOKEN = os.environ['TOKEN']
21
 
22
- DEFAULT_SEED = 3407851645
23
 
24
- TITLE = 'StyleGAN3 Anime Face Generation'
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
@@ -37,13 +53,15 @@ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
37
  return mat
38
 
39
 
40
- def generate_z(seed, device):
41
  return torch.from_numpy(np.random.RandomState(seed).randn(1,
42
  512)).to(device)
43
 
44
 
45
  @torch.inference_mode()
46
- def generate_image(seed, truncation_psi, tx, ty, angle, model, device):
 
 
47
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
48
  z = generate_z(seed, device)
49
  c = torch.zeros(0).to(device)
@@ -54,10 +72,10 @@ def generate_image(seed, truncation_psi, tx, ty, angle, model, device):
54
 
55
  out = model(z, c, truncation_psi=truncation_psi)
56
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
57
- return PIL.Image.fromarray(out[0].cpu().numpy(), 'RGB')
58
 
59
 
60
- def load_model(device):
61
  path = hf_hub_download(MODEL_REPO, MODEL_FILE_NAME, use_auth_token=TOKEN)
62
  with open(path, 'rb') as f:
63
  model = pickle.load(f)
@@ -71,7 +89,8 @@ def load_model(device):
71
 
72
 
73
  def main():
74
- device = torch.device('cpu')
 
75
 
76
  model = load_model(device)
77
  func = functools.partial(generate_image, model=model, device=device)
@@ -80,19 +99,25 @@ def main():
80
  gr.Interface(
81
  func,
82
  [
83
- gr.inputs.Number(default=DEFAULT_SEED, label='Seed'),
84
  gr.inputs.Slider(
85
  0, 2, step=0.05, default=0.7, label='Truncation psi'),
86
  gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
87
  gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
88
  gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
89
  ],
90
- gr.outputs.Image(type='pil', label='Output'),
91
  title=TITLE,
92
- enable_queue=True,
93
- allow_screenshot=False,
94
- allow_flagging=False,
95
- ).launch()
 
 
 
 
 
 
96
 
97
 
98
  if __name__ == '__main__':
 
2
 
3
  from __future__ import annotations
4
 
5
+ import argparse
6
  import functools
7
  import os
8
  import pickle
9
  import sys
10
 
 
 
11
  import gradio as gr
12
  import numpy as np
 
13
  import torch
14
+ import torch.nn as nn
15
  from huggingface_hub import hf_hub_download
16
 
17
+ sys.path.insert(0, 'stylegan3')
18
+
19
+ TITLE = 'StyleGAN3 Anime Face Generation'
20
+ DESCRIPTION = 'Expected execution time on Hugging Face Spaces: 20s'
21
+ ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan3-anime-face-generation-exp002" alt="visitor badge"/></center>'
22
+
23
  MODEL_REPO = 'hysts/stylegan3-anime-face-exp002-model'
24
  MODEL_FILE_NAME = '009000.pkl'
25
+
26
  TOKEN = os.environ['TOKEN']
27
 
 
28
 
29
+ def parse_args() -> argparse.Namespace:
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument('--device', type=str, default='cpu')
32
+ parser.add_argument('--theme', type=str)
33
+ parser.add_argument('--live', action='store_true')
34
+ parser.add_argument('--share', action='store_true')
35
+ parser.add_argument('--port', type=int)
36
+ parser.add_argument('--disable-queue',
37
+ dest='enable_queue',
38
+ action='store_false')
39
+ parser.add_argument('--allow-flagging', type=str, default='never')
40
+ return parser.parse_args()
41
 
42
 
43
  def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
 
53
  return mat
54
 
55
 
56
+ def generate_z(seed: int, device: torch.device) -> torch.Tensor:
57
  return torch.from_numpy(np.random.RandomState(seed).randn(1,
58
  512)).to(device)
59
 
60
 
61
  @torch.inference_mode()
62
+ def generate_image(seed: int, truncation_psi: float, tx: float, ty: float,
63
+ angle: float, model: nn.Module,
64
+ device: torch.device) -> np.ndarray:
65
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
66
  z = generate_z(seed, device)
67
  c = torch.zeros(0).to(device)
 
72
 
73
  out = model(z, c, truncation_psi=truncation_psi)
74
  out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
75
+ return out[0].cpu().numpy()
76
 
77
 
78
+ def load_model(device: torch.device) -> nn.Module:
79
  path = hf_hub_download(MODEL_REPO, MODEL_FILE_NAME, use_auth_token=TOKEN)
80
  with open(path, 'rb') as f:
81
  model = pickle.load(f)
 
89
 
90
 
91
  def main():
92
+ args = parse_args()
93
+ device = torch.device(args.device)
94
 
95
  model = load_model(device)
96
  func = functools.partial(generate_image, model=model, device=device)
 
99
  gr.Interface(
100
  func,
101
  [
102
+ gr.inputs.Number(default=3407851645, label='Seed'),
103
  gr.inputs.Slider(
104
  0, 2, step=0.05, default=0.7, label='Truncation psi'),
105
  gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
106
  gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
107
  gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
108
  ],
109
+ gr.outputs.Image(type='numpy', label='Output'),
110
  title=TITLE,
111
+ description=DESCRIPTION,
112
+ article=ARTICLE,
113
+ theme=args.theme,
114
+ allow_flagging='never',
115
+ live=args.live,
116
+ ).launch(
117
+ enable_queue=args.enable_queue,
118
+ server_port=args.port,
119
+ share=args.share,
120
+ )
121
 
122
 
123
  if __name__ == '__main__':