Pie31415 commited on
Commit
a4fd448
1 Parent(s): ae44faf

updated app

Browse files
Files changed (1) hide show
  1. app.py +45 -40
app.py CHANGED
@@ -1,29 +1,25 @@
1
- import os, sys
 
 
 
 
 
 
 
2
  import torch
3
- import argparse
4
-
5
- import numpy as np
6
  import torch
7
- import matplotlib.pyplot as plt
8
- from PIL import Image
9
-
10
- print(torch.__version__)
11
- print(torch.version.cuda)
12
 
13
  sys.path.append("./rome/")
14
  sys.path.append('./DECA')
15
 
16
- from rome.src.utils import args as args_utils
17
- from rome.src.utils.processing import process_black_shape, tensor2image
18
 
19
  # loading models ---- create model repo
20
- from huggingface_hub import hf_hub_download
21
 
22
- default_modnet_path = hf_hub_download('Pie31415/rome','modnet_photographic_portrait_matting.ckpt')
23
- default_model_path = hf_hub_download('Pie31415/rome','rome.pth')
 
24
 
25
  # parser configurations
26
- from easydict import EasyDict as edict
27
 
28
  args = edict({
29
  "save_dir": ".",
@@ -93,8 +89,8 @@ args = edict({
93
  "num_vertex": 5023,
94
  "train_basis": True,
95
  "path_to_deca": "DECA",
96
- "path_to_linear_hair_model": "data/linear_hair.pth", # N/A
97
- "path_to_mobile_model": "data/disp_model.pth", # N/A
98
  "n_scalp": 60,
99
  "use_distill": False,
100
  "use_mobile_version": False,
@@ -109,40 +105,49 @@ args = edict({
109
  })
110
 
111
  # download FLAME and DECA pretrained
112
- generic_model_path = hf_hub_download('Pie31415/rome','generic_model.pkl')
113
- deca_model_path = hf_hub_download('Pie31415/rome','deca_model.tar')
114
-
115
- import pickle
116
 
117
  with open(generic_model_path, 'rb') as f:
118
- ss = pickle.load(f, encoding='latin1')
119
 
120
- with open('./DECA/data/generic_model.pkl', 'wb') as out:
121
- pickle.dump(ss, out)
122
 
123
  with open(deca_model_path, "rb") as input:
124
- with open('./DECA/data/deca_model.tar', "wb") as out:
125
- for line in input:
126
- out.write(line)
127
 
128
  # load ROME inference model
129
- from rome.infer import Infer
130
  infer = Infer(args)
131
 
132
- def predict(source_img, driver_img):
 
 
 
133
  out = infer.evaluate(source_img, driver_img, crop_center=False)
134
  res = tensor2image(torch.cat([out['source_information']['data_dict']['source_img'][0].cpu(),
135
- out['source_information']['data_dict']['target_img'][0].cpu(),
136
- out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
 
137
  return res[..., ::-1]
138
 
139
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- gr.Interface(
142
- fn=predict,
143
- inputs=[
144
- gr.Image(type="pil"),
145
- gr.Image(type="pil")
146
- ],
147
- outputs=gr.Image(),
148
- examples=[]).launch()
 
1
+ import gradio as gr
2
+ from rome.infer import Infer
3
+ import pickle
4
+ from easydict import EasyDict as edict
5
+ from huggingface_hub import hf_hub_download
6
+ from rome.src.utils.processing import process_black_shape, tensor2image
7
+ from rome.src.utils import args as args_utils
8
+ import sys
9
  import torch
 
 
 
10
  import torch
 
 
 
 
 
11
 
12
  sys.path.append("./rome/")
13
  sys.path.append('./DECA')
14
 
 
 
15
 
16
  # loading models ---- create model repo
 
17
 
18
+ default_modnet_path = hf_hub_download(
19
+ 'Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
20
+ default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')
21
 
22
  # parser configurations
 
23
 
24
  args = edict({
25
  "save_dir": ".",
 
89
  "num_vertex": 5023,
90
  "train_basis": True,
91
  "path_to_deca": "DECA",
92
+ "path_to_linear_hair_model": "data/linear_hair.pth", # N/A
93
+ "path_to_mobile_model": "data/disp_model.pth", # N/A
94
  "n_scalp": 60,
95
  "use_distill": False,
96
  "use_mobile_version": False,
 
105
  })
106
 
107
  # download FLAME and DECA pretrained
108
+ generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
109
+ deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')
 
 
110
 
111
  with open(generic_model_path, 'rb') as f:
112
+ ss = pickle.load(f, encoding='latin1')
113
 
114
+ with open('./DECA/data/generic_model.pkl', 'wb') as out:
115
+ pickle.dump(ss, out)
116
 
117
  with open(deca_model_path, "rb") as input:
118
+ with open('./DECA/data/deca_model.tar', "wb") as out:
119
+ for line in input:
120
+ out.write(line)
121
 
122
  # load ROME inference model
 
123
  infer = Infer(args)
124
 
125
+ def image_inference(
126
+ source_img: gr.inputs.Image = None,
127
+ driver_img: gr.inputs.Image = None
128
+ ):
129
  out = infer.evaluate(source_img, driver_img, crop_center=False)
130
  res = tensor2image(torch.cat([out['source_information']['data_dict']['source_img'][0].cpu(),
131
+ out['source_information']['data_dict']['target_img'][0].cpu(
132
+ ),
133
+ out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
134
  return res[..., ::-1]
135
 
136
+ def folder_inference():
137
+ pass
138
+
139
+ with gr.Blocks() as demo:
140
+ with gr.Tab("Image Inference"):
141
+ image_input = [gr.Image(type="pil"), gr.Image(type="pil")]
142
+ image_output = gr.Image()
143
+ image_button = gr.Button("Predict")
144
+ with gr.Tab("Inference Over Folder"):
145
+ pass
146
+ with gr.Tab("Video Inference"):
147
+ pass
148
+
149
+ image_button.click(image_inference, inputs=image_input, outputs=image_output)
150
+ title = "ROME: Realistic one-shot mesh-based head avatars"
151
+ examples = gr.Examples(["examples/lincoln.jpg", "examples/tars2.jpg"])
152
 
153
+ demo.launch()