Pie31415 committed
Commit ab66a38
1 Parent(s): 9941f21

updated app

Files changed (1): app.py (+7 -11)
app.py CHANGED
@@ -13,9 +13,7 @@ sys.path.append("./rome/")
13
  sys.path.append('./DECA')
14
 
15
  # loading models ---- create model repo
16
-
17
- default_modnet_path = hf_hub_download(
18
- 'Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
19
  default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')
20
 
21
  # parser configurations
@@ -126,25 +124,23 @@ def image_inference(
126
  ):
127
  out = infer.evaluate(source_img, driver_img, crop_center=False)
128
  res = tensor2image(torch.cat([out['source_information']['data_dict']['source_img'][0].cpu(),
129
- out['source_information']['data_dict']['target_img'][0].cpu(
130
- ),
131
- out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
132
  return res[..., ::-1]
133
 
134
- def folder_inference():
135
  pass
136
 
137
  with gr.Blocks() as demo:
138
  with gr.Tab("Image Inference"):
139
- image_input = [gr.Image(type="pil"), gr.Image(type="pil")]
140
  image_output = gr.Image()
141
  image_button = gr.Button("Predict")
142
- with gr.Tab("Inference Over Folder"):
143
- pass
144
  with gr.Tab("Video Inference"):
 
145
  pass
146
 
147
- image_button.click(image_inference, inputs=image_input, outputs=image_output)
148
  title = "ROME: Realistic one-shot mesh-based head avatars"
149
  examples = gr.Examples(["examples/lincoln.jpg", "examples/tars2.jpg"])
150
 
 
13
  sys.path.append('./DECA')
14
 
15
  # loading models ---- create model repo
16
+ default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
 
 
17
  default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')
18
 
19
  # parser configurations
 
124
  ):
125
  out = infer.evaluate(source_img, driver_img, crop_center=False)
126
  res = tensor2image(torch.cat([out['source_information']['data_dict']['source_img'][0].cpu(),
127
+ out['source_information']['data_dict']['target_img'][0].cpu(),
128
+ out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
 
129
  return res[..., ::-1]
130
 
131
+ def video_inference():
132
  pass
133
 
134
  with gr.Blocks() as demo:
135
  with gr.Tab("Image Inference"):
136
+ image_inputs = [gr.Image(type="pil"), gr.Image(type="pil")]
137
  image_output = gr.Image()
138
  image_button = gr.Button("Predict")
 
 
139
  with gr.Tab("Video Inference"):
140
+ video_inputs = [gr.Video(), gr.Image()]
141
  pass
142
 
143
+ image_button.click(image_inference, inputs=image_inputs, outputs=image_output)
144
  title = "ROME: Realistic one-shot mesh-based head avatars"
145
  examples = gr.Examples(["examples/lincoln.jpg", "examples/tars2.jpg"])
146
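
For context, a minimal sketch of what the consolidated download calls in the first hunk do: hf_hub_download resolves a file in the 'Pie31415/rome' repo to a cached local path, so the checkpoints are fetched only once per cache. The print statements below are illustrative only and are not part of app.py.

from huggingface_hub import hf_hub_download

# Each call downloads the file on first use and reuses the cached copy afterwards,
# returning the local filesystem path to pass to the model loaders.
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')

print(default_modnet_path)  # a path inside the local Hugging Face cache
print(default_model_path)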
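The reworked tensor2image call in the second hunk concatenates four C x H x W tensors along dim=2, i.e. side by side along the width axis, and res[..., ::-1] then reverses the last axis of the resulting (presumably H x W x C) image, swapping RGB and BGR ordering. A toy sketch of that composition, with random tensors standing in for the ROME outputs:

import torch

# Four stand-in panels (3 channels, 256x256) in place of source, target,
# masked render and predicted shape.
panels = [torch.rand(3, 256, 256) for _ in range(4)]

# dim=2 is the width axis for C x H x W tensors, so this builds one wide strip.
strip = torch.cat(panels, dim=2)
print(strip.shape)  # torch.Size([3, 256, 1024])

# After conversion to an H x W x C array, [..., ::-1] flips the channel order
# (RGB <-> BGR), mirroring the `return res[..., ::-1]` in image_inference.
img = strip.permute(1, 2, 0).numpy()
print(img[..., ::-1].shape)  # (256, 1024, 3)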
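Finally, a self-contained sketch of the Blocks layout this commit leaves in place: only the image tab is wired to a handler, while the video tab declares inputs but no click event yet. The placeholder predictor and the demo.launch() call are assumptions made so the example runs on its own; they are not part of this diff.

import gradio as gr
import numpy as np

def fake_image_inference(source_img, driver_img):
    # Hypothetical stand-in for image_inference: returns a blank frame
    # instead of running the ROME avatar model.
    return np.zeros((256, 256, 3), dtype=np.uint8)

with gr.Blocks() as demo:
    with gr.Tab("Image Inference"):
        image_inputs = [gr.Image(type="pil"), gr.Image(type="pil")]
        image_output = gr.Image()
        image_button = gr.Button("Predict")
    with gr.Tab("Video Inference"):
        video_inputs = [gr.Video(), gr.Image()]  # declared but not yet wired

    # Only the image tab is hooked up; video_inputs has no handler yet.
    image_button.click(fake_image_inference, inputs=image_inputs, outputs=image_output)

demo.launch()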