hysts commited on
Commit
7a39e7e
1 Parent(s): 2243d0c
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "stylegan3"]
2
+ path = stylegan3
3
+ url = https://github.com/NVlabs/stylegan3
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ sys.path.insert(0, 'stylegan3')
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ ORIGINAL_REPO_URL = 'https://github.com/self-distilled-stylegan/self-distilled-internet-photos'
20
+ TITLE = 'Self-Distilled StyleGAN'
21
+ DESCRIPTION = f'This is a demo for models provided in {ORIGINAL_REPO_URL}.'
22
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples'
23
+ ARTICLE = f'''## Generated images
24
+ - truncation: 0.7
25
+ ### Dogs
26
+ - size: 1024x1024
27
+ - seed: 0-99
28
+ ![Dogs]({SAMPLE_IMAGE_DIR}/dogs.jpg)
29
+ ### Elephants
30
+ - size: 512x512
31
+ - seed: 0-99
32
+ ![Elephants]({SAMPLE_IMAGE_DIR}/elephants.jpg)
33
+ ### Horses
34
+ - size: 256x256
35
+ - seed: 0-99
36
+ ![Horses]({SAMPLE_IMAGE_DIR}/horses.jpg)
37
+ ### Bicycles
38
+ - size: 256x256
39
+ - seed: 0-99
40
+ ![Bicycles]({SAMPLE_IMAGE_DIR}/bicycles.jpg)
41
+ ### Lions
42
+ - size: 512x512
43
+ - seed: 0-99
44
+ ![Lions]({SAMPLE_IMAGE_DIR}/lions.jpg)
45
+ ### Giraffes
46
+ - size: 512x512
47
+ - seed: 0-99
48
+ ![Giraffes]({SAMPLE_IMAGE_DIR}/giraffes.jpg)
49
+ ### Parrots
50
+ - size: 512x512
51
+ - seed: 0-99
52
+ ![Parrots]({SAMPLE_IMAGE_DIR}/parrots.jpg)
53
+ '''
54
+
55
+ TOKEN = os.environ['TOKEN']
56
+
57
+
58
+ def parse_args() -> argparse.Namespace:
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument('--device', type=str, default='cpu')
61
+ parser.add_argument('--theme', type=str)
62
+ parser.add_argument('--live', action='store_true')
63
+ parser.add_argument('--share', action='store_true')
64
+ parser.add_argument('--port', type=int)
65
+ parser.add_argument('--disable-queue',
66
+ dest='enable_queue',
67
+ action='store_false')
68
+ parser.add_argument('--allow-flagging', type=str, default='never')
69
+ parser.add_argument('--allow-screenshot', action='store_true')
70
+ return parser.parse_args()
71
+
72
+
73
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
74
+ return torch.from_numpy(np.random.RandomState(seed).randn(
75
+ 1, z_dim)).to(device).float()
76
+
77
+
78
+ @torch.inference_mode()
79
+ def generate_image(model_name: str, seed: int, truncation_psi: float,
80
+ model_dict: dict[str, nn.Module],
81
+ device: torch.device) -> np.ndarray:
82
+ model = model_dict[model_name]
83
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
84
+
85
+ z = generate_z(model.z_dim, seed, device)
86
+ label = torch.zeros([1, model.c_dim], device=device)
87
+
88
+ out = model(z, label, truncation_psi=truncation_psi)
89
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
90
+ return out[0].cpu().numpy()
91
+
92
+
93
+ def load_model(model_name: str, device: torch.device) -> nn.Module:
94
+ path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
95
+ f'models/{model_name}_pytorch.pkl',
96
+ use_auth_token=TOKEN)
97
+ with open(path, 'rb') as f:
98
+ model = pickle.load(f)['G_ema']
99
+ model.eval()
100
+ model.to(device)
101
+ with torch.inference_mode():
102
+ z = torch.zeros((1, model.z_dim)).to(device)
103
+ label = torch.zeros([1, model.c_dim], device=device)
104
+ model(z, label)
105
+ return model
106
+
107
+
108
+ def main():
109
+ gr.close_all()
110
+
111
+ args = parse_args()
112
+ device = torch.device(args.device)
113
+
114
+ model_names = [
115
+ 'dogs_1024',
116
+ 'elephants_512',
117
+ 'horses_256',
118
+ 'bicycles_256',
119
+ 'lions_512',
120
+ 'giraffes_512',
121
+ 'parrots_512',
122
+ ]
123
+
124
+ model_dict = {name: load_model(name, device) for name in model_names}
125
+
126
+ func = functools.partial(generate_image,
127
+ model_dict=model_dict,
128
+ device=device)
129
+ func = functools.update_wrapper(func, generate_image)
130
+
131
+ gr.Interface(
132
+ func,
133
+ [
134
+ gr.inputs.Radio(
135
+ model_names, type='value', default='dogs_1024', label='Model'),
136
+ gr.inputs.Number(default=0, label='Seed'),
137
+ gr.inputs.Slider(
138
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
139
+ ],
140
+ gr.outputs.Image(type='numpy', label='Output'),
141
+ title=TITLE,
142
+ description=DESCRIPTION,
143
+ article=ARTICLE,
144
+ theme=args.theme,
145
+ allow_screenshot=args.allow_screenshot,
146
+ allow_flagging=args.allow_flagging,
147
+ live=args.live,
148
+ ).launch(
149
+ enable_queue=args.enable_queue,
150
+ server_port=args.port,
151
+ share=args.share,
152
+ )
153
+
154
+
155
+ if __name__ == '__main__':
156
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ scipy==1.8.0
4
+ torch==1.11.0
5
+ torchvision==0.12.0
samples/bicycles.jpg ADDED

Git LFS Details

  • SHA256: 781759a7093044308a92f52f8899b46e6bfb890af75bb0e26a7fc49ef06f8eca
  • Pointer size: 132 Bytes
  • Size of remote file: 3.06 MB
samples/dogs.jpg ADDED

Git LFS Details

  • SHA256: 59c447950969db71a7027390d832545b28ec5788d76e16236c971b9b3e31406b
  • Pointer size: 133 Bytes
  • Size of remote file: 32.5 MB
samples/elephants.jpg ADDED

Git LFS Details

  • SHA256: 106b68d11a1c9d3d9b9f51ef80674bf351d1fd78291e8690d2ab4ed259986493
  • Pointer size: 133 Bytes
  • Size of remote file: 12.1 MB
samples/giraffes.jpg ADDED

Git LFS Details

  • SHA256: 3d91a32b61056874a698af5749e3d002e20bff608055b6104413081d845bedb4
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
samples/horses.jpg ADDED

Git LFS Details

  • SHA256: 3bc6d05771a64f9852fe9f8f2da2c61a274da34295e61c1df9c1aef07a8bfc6a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
samples/lions.jpg ADDED

Git LFS Details

  • SHA256: 4216e153da49fbff81ef41484f48f1c68c6c1d455cba0a1eed8458aa64dacccc
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
samples/parrots.jpg ADDED

Git LFS Details

  • SHA256: c3c7fb15868b09eb0b3844778d2e99ccd648725e50c2c4d3f62a6d4b1cb1367d
  • Pointer size: 132 Bytes
  • Size of remote file: 6.72 MB
stylegan3 ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit a5a69f58294509598714d1e88c9646c3d7c6ec94