zzzweakman committed on
Commit
0bc2c6f
β€’
1 Parent(s): 5379bd5

fix: retargeting feature leakage

Browse files
app.py CHANGED
@@ -72,7 +72,7 @@ data_examples = [
72
  # Define components first
73
  eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
74
  lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
75
- retargeting_input_image = gr.Image(type="numpy")
76
  output_image = gr.Image(type="numpy")
77
  output_image_paste_back = gr.Image(type="numpy")
78
  output_video = gr.Video()
@@ -144,11 +144,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
144
  examples_per_page=5,
145
  cache_examples=False,
146
  )
147
- gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=False)
148
- with gr.Row(visible=False):
149
  eye_retargeting_slider.render()
150
  lip_retargeting_slider.render()
151
- with gr.Row(visible=False):
152
  process_button_retargeting = gr.Button("πŸš— Retargeting", variant="primary")
153
  process_button_reset_retargeting = gr.ClearButton(
154
  [
@@ -160,10 +160,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
160
  ],
161
  value="🧹 Clear"
162
  )
163
- with gr.Row(visible=False):
164
  with gr.Column():
165
  with gr.Accordion(open=True, label="Retargeting Input"):
166
  retargeting_input_image.render()
 
 
 
 
 
 
 
 
 
 
 
167
  with gr.Column():
168
  with gr.Accordion(open=True, label="Retargeting Result"):
169
  output_image.render()
@@ -174,7 +185,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
174
  process_button_retargeting.click(
175
  # fn=gradio_pipeline.execute_image,
176
  fn=gpu_wrapped_execute_image,
177
- inputs=[eye_retargeting_slider, lip_retargeting_slider],
178
  outputs=[output_image, output_image_paste_back],
179
  show_progress=True
180
  )
@@ -190,11 +201,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
190
  outputs=[output_video, output_video_concat],
191
  show_progress=True
192
  )
193
- image_input.change(
194
- fn=gradio_pipeline.prepare_retargeting,
195
- inputs=image_input,
196
- outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
197
- )
198
  video_input.upload(
199
  fn=is_square_video,
200
  inputs=video_input,
 
72
  # Define components first
73
  eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
74
  lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
75
+ retargeting_input_image = gr.Image(type="filepath")
76
  output_image = gr.Image(type="numpy")
77
  output_image_paste_back = gr.Image(type="numpy")
78
  output_video = gr.Video()
 
144
  examples_per_page=5,
145
  cache_examples=False,
146
  )
147
+ gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
148
+ with gr.Row(visible=True):
149
  eye_retargeting_slider.render()
150
  lip_retargeting_slider.render()
151
+ with gr.Row(visible=True):
152
  process_button_retargeting = gr.Button("πŸš— Retargeting", variant="primary")
153
  process_button_reset_retargeting = gr.ClearButton(
154
  [
 
160
  ],
161
  value="🧹 Clear"
162
  )
163
+ with gr.Row(visible=True):
164
  with gr.Column():
165
  with gr.Accordion(open=True, label="Retargeting Input"):
166
  retargeting_input_image.render()
167
+ gr.Examples(
168
+ examples=[
169
+ [osp.join(example_portrait_dir, "s9.jpg")],
170
+ [osp.join(example_portrait_dir, "s6.jpg")],
171
+ [osp.join(example_portrait_dir, "s10.jpg")],
172
+ [osp.join(example_portrait_dir, "s5.jpg")],
173
+ [osp.join(example_portrait_dir, "s7.jpg")],
174
+ ],
175
+ inputs=[retargeting_input_image],
176
+ cache_examples=False,
177
+ )
178
  with gr.Column():
179
  with gr.Accordion(open=True, label="Retargeting Result"):
180
  output_image.render()
 
185
  process_button_retargeting.click(
186
  # fn=gradio_pipeline.execute_image,
187
  fn=gpu_wrapped_execute_image,
188
+ inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
189
  outputs=[output_image, output_image_paste_back],
190
  show_progress=True
191
  )
 
201
  outputs=[output_video, output_video_concat],
202
  show_progress=True
203
  )
204
+ # image_input.change(
205
+ # fn=gradio_pipeline.prepare_retargeting,
206
+ # inputs=image_input,
207
+ # outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
208
+ # )
209
  video_input.upload(
210
  fn=is_square_video,
211
  inputs=video_input,
assets/gradio_description_retargeting.md CHANGED
@@ -1 +1 @@
1
- <span style="font-size: 1.2em;">πŸ”₯ To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>πŸš— Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
 
1
+ <span style="font-size: 1.2em;">πŸ”₯ To change the eyes-open and lip-open ratios of the source portrait, please drag the sliders and then click the <strong>πŸš— Retargeting</strong> button. The results will be shown in the blocks below. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
src/gradio_pipeline.py CHANGED
@@ -26,16 +26,6 @@ class GradioPipeline(LivePortraitPipeline):
26
  super().__init__(inference_cfg, crop_cfg)
27
  # self.live_portrait_wrapper = self.live_portrait_wrapper
28
  self.args = args
29
- # for single image retargeting
30
- self.start_prepare = False
31
- self.f_s_user = None
32
- self.x_c_s_info_user = None
33
- self.x_s_user = None
34
- self.source_lmk_user = None
35
- self.mask_ori = None
36
- self.img_rgb = None
37
- self.crop_M_c2o = None
38
-
39
 
40
  def execute_video(
41
  self,
@@ -66,30 +56,23 @@ class GradioPipeline(LivePortraitPipeline):
66
  else:
67
  raise gr.Error("The input source portrait or driving video hasn't been prepared yet πŸ’₯!", duration=5)
68
 
69
- def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
70
  """ for single image retargeting
71
  """
 
 
 
 
72
  if input_eye_ratio is None or input_eye_ratio is None:
73
  raise gr.Error("Invalid ratio input πŸ’₯!", duration=5)
74
- elif self.f_s_user is None:
75
- if self.start_prepare:
76
- raise gr.Error(
77
- "The source portrait is under processing πŸ’₯! Please wait for a second.",
78
- duration=5
79
- )
80
- else:
81
- raise gr.Error(
82
- "The source portrait hasn't been prepared yet πŸ’₯! Please scroll to the top of the page to upload.",
83
- duration=5
84
- )
85
  else:
86
- x_s_user = self.x_s_user.to("cuda")
87
- f_s_user = self.f_s_user.to("cuda")
88
  # βˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
89
- combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
90
  eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
91
  # βˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
92
- combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
93
  lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
94
  num_kp = x_s_user.shape[1]
95
  # default: use x_s
@@ -97,21 +80,20 @@ class GradioPipeline(LivePortraitPipeline):
97
  # D(W(f_s; x_s, xβ€²_d))
98
  out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
99
  out = self.live_portrait_wrapper.parse_output(out['out'])[0]
100
- out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
101
  # gr.Info("Run successfully!", duration=2)
102
  return out, out_to_ori_blend
103
 
104
 
105
- def prepare_retargeting(self, input_image_path, flag_do_crop = True):
106
  """ for single image retargeting
107
  """
108
- if input_image_path is not None:
109
  # gr.Info("Upload successfully!", duration=2)
110
- self.start_prepare = True
111
  inference_cfg = self.live_portrait_wrapper.cfg
112
  ######## process source portrait ########
113
- img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
114
- log(f"Load source image from {input_image_path}.")
115
  crop_info = self.cropper.crop_single_image(img_rgb)
116
  if flag_do_crop:
117
  I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
@@ -120,23 +102,13 @@ class GradioPipeline(LivePortraitPipeline):
120
  x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
121
  R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
122
  ############################################
123
-
124
- # record global info for next time use
125
- self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
126
- self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
127
- self.x_s_info_user = x_s_info
128
- self.source_lmk_user = crop_info['lmk_crop']
129
- self.img_rgb = img_rgb
130
- self.crop_M_c2o = crop_info['M_c2o']
131
- self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
132
- # update slider
133
- eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
134
- eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
135
- lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
136
- lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
137
- # for vis
138
- self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
139
- return eye_close_ratio, lip_close_ratio, self.I_s_vis
140
  else:
141
  # when press the clear button, go here
142
- return 0.8, 0.8, self.I_s_vis
 
 
26
  super().__init__(inference_cfg, crop_cfg)
27
  # self.live_portrait_wrapper = self.live_portrait_wrapper
28
  self.args = args
 
 
 
 
 
 
 
 
 
 
29
 
30
  def execute_video(
31
  self,
 
56
  else:
57
  raise gr.Error("The input source portrait or driving video hasn't been prepared yet πŸ’₯!", duration=5)
58
 
59
+ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop = True):
60
  """ for single image retargeting
61
  """
62
+ # disposable feature
63
+ f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
64
+ self.prepare_retargeting(input_image, flag_do_crop)
65
+
66
  if input_eye_ratio is None or input_eye_ratio is None:
67
  raise gr.Error("Invalid ratio input πŸ’₯!", duration=5)
 
 
 
 
 
 
 
 
 
 
 
68
  else:
69
+ x_s_user = x_s_user.to("cuda")
70
+ f_s_user = f_s_user.to("cuda")
71
  # βˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
72
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
73
  eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
74
  # βˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
75
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
76
  lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
77
  num_kp = x_s_user.shape[1]
78
  # default: use x_s
 
80
  # D(W(f_s; x_s, xβ€²_d))
81
  out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
82
  out = self.live_portrait_wrapper.parse_output(out['out'])[0]
83
+ out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
84
  # gr.Info("Run successfully!", duration=2)
85
  return out, out_to_ori_blend
86
 
87
 
88
+ def prepare_retargeting(self, input_image, flag_do_crop = True):
89
  """ for single image retargeting
90
  """
91
+ if input_image is not None:
92
  # gr.Info("Upload successfully!", duration=2)
 
93
  inference_cfg = self.live_portrait_wrapper.cfg
94
  ######## process source portrait ########
95
+ img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
96
+ log(f"Load source image from {input_image}.")
97
  crop_info = self.cropper.crop_single_image(img_rgb)
98
  if flag_do_crop:
99
  I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
 
102
  x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
103
  R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
104
  ############################################
105
+ f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
106
+ x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
107
+ source_lmk_user = crop_info['lmk_crop']
108
+ crop_M_c2o = crop_info['M_c2o']
109
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
110
+ return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
 
 
 
 
 
 
 
 
 
 
 
111
  else:
112
  # when press the clear button, go here
113
+ raise gr.Error("The retargeting input hasn't been prepared yet πŸ’₯!", duration=5)
114
+