adirik committed
Commit 1c472db
1 Parent(s): e90f2c5

update app

Files changed (4):
  1. app.py +90 -169
  2. find_direction.py +1 -0
  3. generator.py +2 -2
  4. psp_wrapper.py +8 -5
app.py CHANGED
@@ -4,11 +4,27 @@ import dnnlib
 import numpy as np
 import torch
 
-from find_direction import find_direction
+import find_direction
+import generator
+import psp_wrapper
+
+
+psp_encoder_path = "./pretrained/e4e_ffhq_encode.pt"
+landmarks_path = "./pretrained/shape_predictor_68_face_landmarks.dat"
+e4e_embedder = psp_wrapper.psp_encoder(psp_encoder_path, landmarks_path)
+G_ffhq_path = "./pretrained/ffhq.pkl"
+G_metfaces_path = "./pretrained/metfaces.pkl"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-with dnnlib.util.open_url("./pretrained/ffhq.pkl") as f:
-    G = legacy.load_network_pkl(f)['G_ema'].to(device)
+
+with dnnlib.util.open_url(G_ffhq_path) as f:
+    G_ffhq = legacy.load_network_pkl(f)['G_ema'].to(device)
+
+with dnnlib.util.open_url(G_metfaces_path) as f:
+    G_metfaces = legacy.load_network_pkl(f)['G_ema'].to(device)
+
+G_dict = {"FFHQ": G_ffhq, "MetFaces": G_metfaces}
+
 
 
 DESCRIPTION = '''# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> Multi-Channel Based Fast Text-Guided Image Generation and Manipulation
@@ -16,171 +32,76 @@ DESCRIPTION = '''# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a
 FOOTER = 'This space is built by <a href = "https://github.com/catlab-team">Catlab Team</a>.'
 
 
-def main():
-    with gr.Blocks(css='style.css') as demo:
-        gr.Markdown(DESCRIPTION)
-
-        with gr.Box():
-            gr.Markdown('''## Step 1 (Finding a global manipulation direction)
-            - Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction:
-            - Hit the **Find Direction** button.
-            ''')
-            with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        text = gr.Textbox(
-                            label="Enter your prompt",
-                            show_label=False,
-                            max_lines=1,
-                            placeholder="Enter your prompt",
-                        ).style(
-                            container=False,
-                        )
-                    identity_loss_weight = gr.Slider(0.1,
-                                                     10,
-                                                     value=0.5,
-                                                     step=0.1,
-                                                     label='Identity Loss Weight',
-                                                     interactive=True)
-            btn = gr.Button("Find Direction").style(full_width=False)
-
-        with gr.Box():
-            gr.Markdown('''## Step 2 (Manipulation)
-            - Please upload an image for manipulation:
-            - You can also select the **previous directions** and determine the **manipulation strength**.
-            - Hit the **Generate** button.
-            ''')
-            with gr.Row():
-                identity_loss_weight = gr.Slider(0.1,
-                                                 100,
-                                                 value=50,
-                                                 step=0.1,
-                                                 label='Manipulation Strength',
-                                                 interactive=True)
-            with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        input_image = gr.Image(label='Input Image',
-                                               type='filepath')
-                    with gr.Row():
-                        generate_button = gr.Button('Generate')
-                with gr.Column():
-                    with gr.Row():
-                        generated_image = gr.Image(label='Generated Image',
-                                                   type='numpy',
-                                                   interactive=False)
-
-
-
-
-        # with gr.Box():
-        #     gr.Markdown('''## Step 2 (Select Style Image)
-        #     - Select **Style Type**.
-        #     - Select **Style Image Index** from the image table below.
-        #     ''')
-        #     with gr.Row():
-        #         with gr.Column():
-        #             style_type = gr.Radio(model.style_types,
-        #                                   label='Style Type')
-        #             text = get_style_image_markdown_text('cartoon')
-        #             style_image = gr.Markdown(value=text)
-        #             style_index = gr.Slider(0,
-        #                                     316,
-        #                                     value=26,
-        #                                     step=1,
-        #                                     label='Style Image Index')
-
-        #     with gr.Row():
-        #         example_styles = gr.Dataset(
-        #             components=[style_type, style_index],
-        #             samples=[
-        #                 ['cartoon', 26],
-        #                 ['caricature', 65],
-        #                 ['arcane', 63],
-        #                 ['pixar', 80],
-        #             ])
-
-        # with gr.Box():
-        #     gr.Markdown('''## Step 3 (Generate Style Transferred Image)
-        #     - Adjust **Structure Weight** and **Color Weight**.
-        #     - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
-        #     - Hit the **Generate** button.
-        #     ''')
-        #     with gr.Row():
-        #         with gr.Column():
-        #             with gr.Row():
-        #                 structure_weight = gr.Slider(0,
-        #                                              1,
-        #                                              value=0.6,
-        #                                              step=0.1,
-        #                                              label='Structure Weight')
-        #             with gr.Row():
-        #                 color_weight = gr.Slider(0,
-        #                                          1,
-        #                                          value=1,
-        #                                          step=0.1,
-        #                                          label='Color Weight')
-        #             with gr.Row():
-        #                 structure_only = gr.Checkbox(label='Structure Only')
-        #             with gr.Row():
-        #                 generate_button = gr.Button('Generate')
-
-        #         with gr.Column():
-        #             result = gr.Image(label='Result')
-
-        #     with gr.Row():
-        #         example_weights = gr.Dataset(
-        #             components=[structure_weight, color_weight],
-        #             samples=[
-        #                 [0.6, 1.0],
-        #                 [0.3, 1.0],
-        #                 [0.0, 1.0],
-        #                 [1.0, 0.0],
-        #             ])
-
-        gr.Markdown(FOOTER)
-
-        # preprocess_button.click(fn=model.detect_and_align_face,
-        #                         inputs=input_image,
-        #                         outputs=aligned_face)
-        # aligned_face.change(fn=model.reconstruct_face,
-        #                     inputs=aligned_face,
-        #                     outputs=[
-        #                         reconstructed_face,
-        #                         instyle,
-        #                     ])
-        # style_type.change(fn=update_slider,
-        #                   inputs=style_type,
-        #                   outputs=style_index)
-        # style_type.change(fn=update_style_image,
-        #                   inputs=style_type,
-        #                   outputs=style_image)
-        # generate_button.click(fn=model.generate,
-        #                       inputs=[
-        #                           style_type,
-        #                           style_index,
-        #                           structure_weight,
-        #                           color_weight,
-        #                           structure_only,
-        #                           instyle,
-        #                       ],
-        #                       outputs=result)
-        # example_images.click(fn=set_example_image,
-        #                      inputs=example_images,
-        #                      outputs=example_images.components)
-        # example_styles.click(fn=set_example_styles,
-        #                      inputs=example_styles,
-        #                      outputs=example_styles.components)
-        # example_weights.click(fn=set_example_weights,
-        #                       inputs=example_weights,
-        #                       outputs=example_weights.components)
-
-    demo.launch(
-        # enable_queue=args.enable_queue,
-        # server_port=args.port,
-        # share=args.share,
-    )
-
-
-if __name__ == '__main__':
-    main()
+def add_direction(prompt, stylegan_type, id_loss_w):
+    new_dir_name = prompt + " " + stylegan_type + " w_id_loss" + str(id_loss_w)
+    if (prompt is not None) and (new_dir_name not in direction_list):
+        print("adding direction with id:", new_dir_name)
+        direction = find_direction.find_direction(G_dict[stylegan_type], prompt)
+        print(f"new direction calculated with {stylegan_type} and id loss weight = {id_loss_w}")
+        direction_list.append(new_dir_name)
+        direction_map[new_dir_name] = {"direction": direction, "stylegan_type": stylegan_type}
+
+    return gr.Radio.update(choices=direction_list, value=None, visible=True)
+
+
+def generate_output_image(image_path, direction_id, change_power):
+    direction = direction_map[direction_id]["direction"]
+    G = G_dict["FFHQ"]
+
+    w = e4e_embedder.get_w(image_path)  # W+ latents as a numpy array
+    s = generator.w_to_s(GIn=G, wsIn=w)
+    output_image = generator.generate_from_style(
+        GIn=G,
+        styles=s,
+        styles_direction=direction,
+        change_power=change_power,
+        outdir='.'
+    )
+    return output_image
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Box():
+        gr.Markdown('''## Step 1 (Finding a global manipulation direction) - Please enter the target **text prompt** and **identity loss weight** to find a global manipulation direction. - Hit the **Find Direction** button.''')
+        with gr.Row():
+            with gr.Column():
+                style_gan_type = gr.Radio(["FFHQ", "MetFaces"], value="FFHQ", label="StyleGAN Type", interactive=True)
+            with gr.Column():
+                identity_loss_weight = gr.Slider(
+                    0.1, 10, value=0.5, step=0.1, label="Identity Loss Weight", interactive=True
+                )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Textbox(
+                        label="Enter your prompt",
+                        show_label=False,
+                        max_lines=1,
+                        placeholder="Enter your prompt"
+                    ).style(container=False)
+
+        find_direction_btn = gr.Button("Find Direction").style(full_width=False)
+
+    with gr.Box():
+        gr.Markdown('''## Step 2 (Manipulation) - Please upload an image for manipulation. - You can also select one of the **previous directions** and set the **manipulation strength**. - Hit the **Generate** button.''')
+        with gr.Row():
+            direction_radio = gr.Radio(direction_list, label="List of Directions")
+        with gr.Row():
+            manipulation_strength = gr.Slider(
+                0.1, 100, value=25, step=0.1, label="Manipulation Strength", interactive=True
+            )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_image = gr.Image(label="Input Image", type="filepath")
+                with gr.Row():
+                    generate_btn = gr.Button("Generate")
+            with gr.Column():
+                with gr.Row():
+                    generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+
+    find_direction_btn.click(add_direction, inputs=[text, style_gan_type, identity_loss_weight], outputs=direction_radio)
+    generate_btn.click(generate_output_image, inputs=[input_image, direction_radio, manipulation_strength], outputs=generated_image)
+
+demo.launch(debug=True)
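The new callbacks above read and mutate `direction_list` and `direction_map`, and use `gr` and `legacy`, none of which are defined in the hunks this commit shows (the imports presumably sit in the unchanged lines 1-3 of app.py). A minimal sketch of the module-level state the new code appears to assume — the names come from the diff, the initialization itself is an assumption:

```python
# Assumed module-level state for app.py (not shown in this diff).
import gradio as gr   # assumed: used as `gr` throughout the new UI code
import legacy         # assumed: StyleGAN2-ADA loader used for G_ffhq / G_metfaces

direction_list = []   # ids of the manipulation directions found so far
direction_map = {}    # id -> {"direction": ..., "stylegan_type": "FFHQ" | "MetFaces"}
```

With that state in place, `add_direction` registers one entry per unique (prompt, model, weight) combination and refreshes the radio choices, so a direction found in Step 1 becomes selectable in Step 2.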
find_direction.py CHANGED
@@ -22,6 +22,7 @@ from torch_utils.ops import upfirdn2d
 import id_loss
 from copy import deepcopy
 
+
 def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
     w_iter = iter(ws.unbind(dim=1))
generator.py CHANGED
@@ -31,7 +31,7 @@ from torch import linalg as LA
 import torch.nn.functional as F
 
 
-def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
+def block_forward(self, x, img, ws, shapes, force_fp32=True, fused_modconv=None, **layer_kwargs):
     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
     w_iter = iter(ws.unbind(dim=1))
     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
@@ -74,7 +74,7 @@ def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None
     return x, img
 
 
-def block_forward_from_style(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
+def block_forward_from_style(self, x, img, ws, shapes, force_fp32=True, fused_modconv=None, **layer_kwargs):
     misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
     w_iter = iter(ws.unbind(dim=1))
     dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
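Both signature changes flip the `force_fp32` default from `False` to `True`. Given the `dtype` expression in each body, the synthesis blocks now run in float32 unless a caller explicitly passes `force_fp32=False`, avoiding fp16 execution paths on devices that lack them. A minimal sketch of the selection logic, with `use_fp16` standing in for the block attribute:

```python
import torch

def pick_dtype(use_fp16: bool, force_fp32: bool = True) -> torch.dtype:
    # Same expression as in block_forward: fp16 is used only when the
    # block supports it AND the caller has not forced fp32.
    return torch.float16 if use_fp16 and not force_fp32 else torch.float32

assert pick_dtype(use_fp16=True) == torch.float32              # new default
assert pick_dtype(use_fp16=True, force_fp32=False) == torch.float16
```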
psp_wrapper.py CHANGED
@@ -31,15 +31,16 @@ EXPERIMENT_ARGS['transform'] = transforms.Compose([
 
 class psp_encoder:
     def __init__(self, model_path: str, shape_predictor_path: str):
-        self.ckpt = torch.load(model_path, map_location='cpu')
-        self.opts = self.ckpt['opts']
+        self.ckpt = torch.load(model_path, map_location="cpu")
+        self.opts = self.ckpt["opts"]
         # update the training options
-        self.opts['checkpoint_path'] = model_path
+        self.opts["checkpoint_path"] = model_path
         self.opts = Namespace(**self.opts)
         self.net = pSp(self.opts)
         self.net.eval()
         self.net.cuda()
         self.shape_predictor = dlib.shape_predictor(shape_predictor_path)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     def get_w(self, image_path):
         original_image = Image.open(image_path)
@@ -47,8 +48,10 @@ class psp_encoder:
         input_image = align_face(filepath=image_path, predictor=self.shape_predictor)
         resize_dims = (256, 256)
         input_image.resize(resize_dims)
-        img_transforms = EXPERIMENT_ARGS['transform']
+        img_transforms = EXPERIMENT_ARGS["transform"]
         transformed_image = img_transforms(input_image)
+
+
         with torch.no_grad():
-            _, latents = self.net(transformed_image.unsqueeze(0).to("cuda").float(), randomize_noise=False, return_latents=True)
+            _, latents = self.net(transformed_image.unsqueeze(0).to(self.device).float(), randomize_noise=False, return_latents=True)
         return latents.cpu().numpy()
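This wrapper is what app.py's `generate_output_image` drives: it aligns the face with dlib, encodes it with e4e, and returns W+ latents for the generator. A quick standalone check, assuming the two pretrained files referenced in app.py are present and `sample.jpg` is any face photo (a hypothetical input):

```python
import psp_wrapper

# Paths are the ones hard-coded in app.py.
enc = psp_wrapper.psp_encoder(
    "./pretrained/e4e_ffhq_encode.pt",
    "./pretrained/shape_predictor_68_face_landmarks.dat",
)
w = enc.get_w("sample.jpg")  # aligned + encoded W+ latents as a numpy array
print(w.shape)               # typically (1, 18, 512) for a 1024px FFHQ model
```

Note that the network is still moved to CUDA unconditionally via `self.net.cuda()`, so despite the new `self.device` attribute the wrapper still effectively requires a GPU.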