Spaces:
Build error
Build error
implemented video inference
Browse files- README.md +1 -1
- app.py +72 -6
- requirements.txt +2 -1
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Rome
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Rome
|
3 |
+
emoji: π
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1 |
import sys
|
2 |
import torch
|
3 |
-
import gradio as gr
|
4 |
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from easydict import EasyDict as edict
|
7 |
from huggingface_hub import hf_hub_download
|
@@ -11,6 +20,7 @@ sys.path.append('./DECA')
|
|
11 |
|
12 |
from rome.infer import Infer
|
13 |
from rome.src.utils.processing import process_black_shape, tensor2image
|
|
|
14 |
|
15 |
# loading models ---- create model repo
|
16 |
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
|
@@ -128,8 +138,64 @@ def image_inference(
|
|
128 |
out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
|
129 |
return res[..., ::-1]
|
130 |
|
131 |
-
def
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
with gr.Blocks() as demo:
|
135 |
gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
|
@@ -151,8 +217,8 @@ with gr.Blocks() as demo:
|
|
151 |
image_button = gr.Button("Predict")
|
152 |
with gr.Tab("Video Inference"):
|
153 |
with gr.Row():
|
154 |
-
|
155 |
-
|
156 |
video_output = gr.Image()
|
157 |
video_button = gr.Button("Predict")
|
158 |
|
@@ -168,6 +234,6 @@ with gr.Blocks() as demo:
|
|
168 |
)
|
169 |
|
170 |
image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
|
171 |
-
video_button.click(
|
172 |
|
173 |
demo.launch()
|
|
|
1 |
import sys
|
2 |
import torch
|
|
|
3 |
import pickle
|
4 |
+
import cv2
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
from collections import defaultdict
|
10 |
+
from glob import glob
|
11 |
+
|
12 |
+
from matplotlib import pyplot as plt
|
13 |
+
from matplotlib import animation
|
14 |
|
15 |
from easydict import EasyDict as edict
|
16 |
from huggingface_hub import hf_hub_download
|
|
|
20 |
|
21 |
from rome.infer import Infer
|
22 |
from rome.src.utils.processing import process_black_shape, tensor2image
|
23 |
+
from rome.src.utils.visuals import mask_errosion
|
24 |
|
25 |
# loading models ---- create model repo
|
26 |
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
|
|
|
138 |
out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
|
139 |
return res[..., ::-1]
|
140 |
|
141 |
+
def extract_frames(driver_vid):
    """Decode every frame of a video file into a list of RGB PIL images.

    Args:
        driver_vid: Path to the video file; passed straight to
            ``cv2.VideoCapture``.

    Returns:
        list: One ``PIL.Image`` per frame, in the order the frames were
        read. The list is empty when no frame could be read.
    """
    image_frames = []
    vid = cv2.VideoCapture(driver_vid)  # path to mp4

    try:
        while True:
            success, img = vid.read()

            # read() reports failure once the stream is exhausted (or if it
            # could not be opened at all), which ends the loop.
            if not success:
                break

            # OpenCV yields BGR channel order; convert before handing to PIL.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image_frames.append(Image.fromarray(img))
    finally:
        # Always free the capture handle, even if a conversion above raises.
        vid.release()

    return image_frames
|
155 |
+
|
156 |
+
def video_inference(source_img, driver_vid):
    """Render ``source_img`` driven by frames of the uploaded driver video.

    The driver video is decoded with ``extract_frames``; a subsample of its
    frames is passed through the module-level ``infer`` model, and the
    masked renders are assembled into a matplotlib animation.

    Args:
        source_img: Source portrait, forwarded unchanged to
            ``infer.evaluate``.
        driver_vid: Path to the driver video file.

    Returns:
        matplotlib.animation.FuncAnimation: Animation over the rendered
        frames, 30 ms per frame.

    Raises:
        ValueError: If the video yields no frames to render.
    """
    # Use the frames of the uploaded video. (Previously these were replaced
    # by a glob over a hard-coded local folder, so the upload was ignored.)
    image_frames = extract_frames(driver_vid)

    resulted_imgs = defaultdict(list)

    mask_hard_threshold = 0.5
    # NOTE(review): only the first 1/20 of the clip is considered, sampled
    # every 4th frame. This looks like a compute-saving shortcut carried
    # over from a demo; confirm it is intended.
    N = len(image_frames) // 20

    # Source-side information is computed on the first evaluated frame and
    # reused afterwards.
    # NOTE(review): assumes infer.evaluate accepts None for
    # source_information_for_reuse and exposes 'source_information' in its
    # output via .get() - confirm against rome.infer.
    source_information = None
    for i in range(0, N, 4):
        new_out = infer.evaluate(source_img, image_frames[i],
                                 source_information_for_reuse=source_information)
        if source_information is None:
            source_information = new_out.get('source_information')

        # Binarize the predicted mask, then erode it before compositing.
        mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
        mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
        # Keep the prediction inside the mask; fill the outside with 1.
        render = new_out['pred_target_img'].cpu() * (mask_pred) + (1 - mask_pred)

        normals = process_black_shape(((new_out['pred_target_normal'][0].cpu() + 1) / 2 * mask_pred + (1 - mask_pred)))
        normals[normals == 0.5] = 1.

        resulted_imgs['res_normal'].append(tensor2image(normals))
        resulted_imgs['res_mesh_images'].append(tensor2image(new_out['pred_target_shape_img'][0]))
        resulted_imgs['res_renders'].append(tensor2image(render[0]))

    if not resulted_imgs['res_renders']:
        # Without this guard, video[0] below fails with an opaque IndexError.
        raise ValueError(
            "Driver video is too short or could not be decoded: "
            f"{len(image_frames)} frame(s) read, at least 20 are required."
        )

    video = np.array(resulted_imgs['res_renders'])

    fig = plt.figure()
    # The last axis is reversed for display (channel order flip).
    im = plt.imshow(video[0, :, :, ::-1])
    plt.axis('off')
    plt.close()  # this is required to not display the generated image

    def init():
        im.set_data(video[0, :, :, ::-1])

    def animate(i):
        im.set_data(video[i, :, :, ::-1])
        return im

    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=video.shape[0], interval=30)

    # NOTE(review): the click handler routes this return value to a
    # gr.Image output; confirm that component can render a FuncAnimation,
    # or export the animation to a file and use a video output instead.
    return anim
|
199 |
|
200 |
with gr.Blocks() as demo:
|
201 |
gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
|
|
|
217 |
image_button = gr.Button("Predict")
|
218 |
with gr.Tab("Video Inference"):
|
219 |
with gr.Row():
|
220 |
+
source_img2 = gr.Image(type="pil", label="source image", show_label=True)
|
221 |
+
driver_vid = gr.Video(label="driver video")
|
222 |
video_output = gr.Image()
|
223 |
video_button = gr.Button("Predict")
|
224 |
|
|
|
234 |
)
|
235 |
|
236 |
image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
|
237 |
+
video_button.click(video_inference, inputs=[source_img2, driver_vid], outputs=video_output)
|
238 |
|
239 |
demo.launch()
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ matplotlib
|
|
8 |
pillow
|
9 |
https://download.pytorch.org/whl/cu101/torch-1.6.0%2Bcu101-cp38-cp38-linux_x86_64.whl
|
10 |
https://download.pytorch.org/whl/cu101/torchvision-0.7.0%2Bcu101-cp38-cp38-linux_x86_64.whl
|
11 |
-
easydict
|
|
|
|
8 |
pillow
|
9 |
https://download.pytorch.org/whl/cu101/torch-1.6.0%2Bcu101-cp38-cp38-linux_x86_64.whl
|
10 |
https://download.pytorch.org/whl/cu101/torchvision-0.7.0%2Bcu101-cp38-cp38-linux_x86_64.whl
|
11 |
+
easydict
|
12 |
+
opencv-python-headless
|