CorvaeOboro commited on
Commit
589ceac
β€’
1 Parent(s): 00986c2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -94
app.py CHANGED
@@ -1,100 +1,75 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
- import argparse
6
- import functools
7
- import os
8
- import pickle
9
- import sys
10
-
11
  import gradio as gr
 
12
  import numpy as np
13
  import torch
14
- import torch.nn as nn
15
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- sys.path.insert(0, 'gen_ability_icon')
 
 
 
18
 
19
- TITLE = 'gen_ability_icon'
20
- DESCRIPTION = '''creates circular magic ability icons from stylegan2ada model trained on synthetic dataset .
 
 
 
21
  more information here : https://github.com/CorvaeOboro/gen_ability_icon.
22
- '''
23
-
24
-
25
- def parse_args() -> argparse.Namespace:
26
- parser = argparse.ArgumentParser()
27
- parser.add_argument('--device', type=str, default='cpu')
28
- parser.add_argument('--theme', type=str)
29
- parser.add_argument('--live', action='store_true')
30
- parser.add_argument('--share', action='store_true')
31
- parser.add_argument('--port', type=int)
32
- parser.add_argument('--disable-queue',
33
- dest='enable_queue',
34
- action='store_false')
35
- parser.add_argument('--allow-flagging', type=str, default='never')
36
- return parser.parse_args()
37
-
38
-
39
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
40
- return torch.from_numpy(np.random.RandomState(seed).randn(
41
- 1, z_dim)).to(device).float()
42
-
43
-
44
- @torch.inference_mode()
45
- def generate_image(seed: int, truncation_psi: float, model: nn.Module,
46
- device: torch.device) -> np.ndarray:
47
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
48
-
49
- z = generate_z(model.z_dim, seed, device)
50
- label = torch.zeros([1, model.c_dim], device=device)
51
-
52
- out = model(z, label, truncation_psi=truncation_psi, force_fp32=True)
53
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
54
- return out[0].cpu().numpy()
55
-
56
-
57
- def load_model(file_name: str, device: torch.device) -> nn.Module:
58
- path = hf_hub_download(f'CorvaeOboro/gen_ability_icon' , f'{file_name}')
59
- with open(path, 'rb') as f:
60
- model = pickle.load(f)['G_ema']
61
- model.eval()
62
- model.to(device)
63
- with torch.inference_mode():
64
- z = torch.zeros((1, model.z_dim)).to(device)
65
- label = torch.zeros([1, model.c_dim], device=device)
66
- model(z, label, force_fp32=True)
67
- return model
68
-
69
-
70
- def main():
71
- args = parse_args()
72
- device = torch.device(args.device)
73
-
74
- model = load_model('gen_ability_icon_stylegan2ada_20220801.pkl', device)
75
-
76
- func = functools.partial(generate_image, model=model, device=device)
77
- func = functools.update_wrapper(func, generate_image)
78
-
79
- gr.Interface(
80
- func,
81
- [
82
- gr.inputs.Number(default=0, label='Seed'),
83
- gr.inputs.Slider(
84
- 0, 2, step=0.05, default=0.7, label='Truncation psi'),
85
- ],
86
- gr.outputs.Image(type='numpy', label='Output'),
87
- title=TITLE,
88
- description=DESCRIPTION,
89
- theme=args.theme,
90
- allow_flagging=args.allow_flagging,
91
- live=args.live,
92
- ).launch(
93
- enable_queue=args.enable_queue,
94
- server_port=args.port,
95
- share=args.share,
96
- )
97
-
98
-
99
- if __name__ == '__main__':
100
- main()
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
  import numpy as np
4
  import torch
5
+ import pickle
6
+ import types
7
+
8
+ from huggingface_hub import hf_hub_url, cached_download
9
+
10
+ TOKEN = os.environ['TOKEN']
11
+
12
+ with open(cached_download(hf_hub_url('CorvaeOboro/gen_ability_icon', 'gen_ability_icon_stylegan2ada_20220801.pkl'), use_auth_token=TOKEN), 'rb') as f:
13
+ G = pickle.load(f)['G_ema']# torch.nn.Module
14
+
15
+ device = torch.device("cpu")
16
+ if torch.cuda.is_available():
17
+ device = torch.device("cuda")
18
+ G = G.to(device)
19
+ else:
20
+ _old_forward = G.forward
21
+
22
+ def _new_forward(self, *args, **kwargs):
23
+ kwargs["force_fp32"] = True
24
+ return _old_forward(*args, **kwargs)
25
+
26
+ G.forward = types.MethodType(_new_forward, G)
27
+
28
+ _old_synthesis_forward = G.synthesis.forward
29
+
30
+ def _new_synthesis_forward(self, *args, **kwargs):
31
+ kwargs["force_fp32"] = True
32
+ return _old_synthesis_forward(*args, **kwargs)
33
+
34
+ G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
35
+
36
+
37
+ def generate(num_images, interpolate):
38
+ if interpolate:
39
+ z1 = torch.randn([1, G.z_dim])# latent codes
40
+ z2 = torch.randn([1, G.z_dim])# latent codes
41
+ zs = torch.cat([z1 + (z2 - z1) * i / (num_images-1) for i in range(num_images)], 0)
42
+ else:
43
+ zs = torch.randn([num_images, G.z_dim])# latent codes
44
+ with torch.no_grad():
45
+ zs = zs.to(device)
46
+ img = G(zs, None, force_fp32=True, noise_mode='const')
47
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
48
+ return img.cpu().numpy()
49
+
50
+ demo = gr.Blocks()
51
 
52
+ def infer(num_images, interpolate):
53
+ img = generate(round(num_images), interpolate)
54
+ imgs = list(img)
55
+ return imgs
56
 
57
+ with demo:
58
+ gr.Markdown(
59
+ """
60
+ # gen_ability_icon
61
+ creates circular magic ability icons from stylegan2ada model trained on synthetic dataset .
62
  more information here : https://github.com/CorvaeOboro/gen_ability_icon.
63
+ """)
64
+ images_num = gr.inputs.Slider(default=1, label="Num Images", minimum=1, maximum=16, step=1)
65
+ interpolate = gr.inputs.Checkbox(default=False, label="Interpolate")
66
+ submit = gr.Button("Generate")
67
+
68
+
69
+ out = gr.Gallery()
70
+
71
+ submit.click(fn=infer,
72
+ inputs=[images_num, interpolate],
73
+ outputs=out)
74
+
75
+ demo.launch()