Spaces:
Build error
Build error
import os
import pickle
import shutil
import sys

import gradio as gr
import torch
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download

# Make the vendored ROME and DECA sources importable before importing from them.
sys.path.append("./rome/")
sys.path.append('./DECA')

from rome.infer import Infer
from rome.src.utils.processing import process_black_shape, tensor2image
# Fetch the MODNet portrait-matting checkpoint and the ROME checkpoint from the
# 'Pie31415/rome' Hub repo (MODNet first, then ROME); each call yields a local
# file path that the config below points at.
default_modnet_path, default_model_path = (
    hf_hub_download('Pie31415/rome', ckpt_name)
    for ckpt_name in ('modnet_photographic_portrait_matting.ckpt', 'rome.pth')
)
# Inference options for ROME -- stand-ins for the values its argument parser
# would otherwise supply -- fixed for this demo and wrapped in an EasyDict.
args = edict(dict(
    # I/O, checkpoints and run-level switches
    save_dir=".",
    save_render=True,
    model_checkpoint=default_model_path,
    modnet_path=default_modnet_path,
    random_seed=0,
    debug=False,
    verbose=False,
    # input size / alignment / mesh
    model_image_size=256,
    align_source=True,
    align_target=False,
    align_scale=1.25,
    use_mesh_deformations=False,
    subdivide_mesh=False,
    # renderer
    renderer_sigma=1e-08,
    renderer_zfar=100.0,
    renderer_type="soft_mesh",
    renderer_texture_type="texture_uv",
    renderer_normalized_alphas=False,
    # data locations
    deca_path="DECA",
    rome_data_dir="rome/data",
    # autoencoder
    autoenc_cat_alphas=False,
    autoenc_align_inputs=False,
    autoenc_use_warp=False,
    autoenc_num_channels=64,
    autoenc_max_channels=512,
    autoenc_num_groups=4,
    autoenc_num_bottleneck_groups=0,
    autoenc_num_blocks=2,
    autoenc_num_layers=4,
    autoenc_block_type="bottleneck",
    # neural texture
    neural_texture_channels=8,
    num_harmonic_encoding_funcs=6,
    # U-Net
    unet_num_channels=64,
    unet_max_channels=512,
    unet_num_groups=4,
    unet_num_blocks=1,
    unet_num_layers=2,
    unet_block_type="conv",
    unet_skip_connection_type="cat",
    unet_use_normals_cond=True,
    unet_use_vertex_cond=False,
    unet_use_uvs_cond=False,
    unet_pred_mask=False,
    use_separate_seg_unet=True,
    # layer types
    norm_layer_type="gn",
    activation_type="relu",
    conv_layer_type="ws_conv",
    deform_norm_layer_type="gn",
    deform_activation_type="relu",
    deform_conv_layer_type="ws_conv",
    # segmentation and deformation
    unet_seg_weight=0.0,
    unet_seg_type="bce_with_logits",
    deform_face_tightness=0.0001,
    use_whole_segmentation=False,
    mask_hair_for_neck=False,
    use_hair_from_avatar=False,
    use_scalp_deforms=True,
    use_neck_deforms=True,
    use_basis_deformer=False,
    use_unet_deformer=True,
    # deformation basis
    pretrained_encoder_basis_path="",
    pretrained_vertex_basis_path="",
    num_basis=50,
    basis_init="pca",
    num_vertex=5023,
    train_basis=True,
    # auxiliary models and miscellaneous options
    path_to_deca="DECA",
    path_to_linear_hair_model="data/linear_hair.pth",  # marked N/A upstream
    path_to_mobile_model="data/disp_model.pth",  # marked N/A upstream
    n_scalp=60,
    use_distill=False,
    use_mobile_version=False,
    deformer_path="data/rome.pth",
    output_unet_deformer_feats=32,
    use_deca_details=False,
    use_flametex=False,
    upsample_type="nearest",
    num_frequencies=6,
    deform_face_scale_coef=0.0,
    device="cpu",
))
# Download the FLAME generic model and the DECA weights from the Hub and copy
# them into ./DECA/data/ (presumably where the vendored DECA code loads them
# from -- confirm against DECA's config).
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')
os.makedirs('./DECA/data', exist_ok=True)

# Re-serialise the FLAME model. encoding='latin1' is what lets Python 3 read
# pickles written by Python 2 (presumably how this file was produced -- confirm);
# dumping it again writes a native Python 3 pickle.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable because the file comes from a fixed, trusted Hub repo;
# never point this at user-supplied data.
with open(generic_model_path, 'rb') as f:
    flame_model = pickle.load(f, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as out:
    pickle.dump(flame_model, out)

# Byte-for-byte copy of the binary checkpoint. Iterating a binary file "by line"
# splits on arbitrary b'\n' bytes, which is slow and can create huge chunks.
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')
# Load the ROME inference model once at module level, configured by `args`
# above; image_inference() reuses this single instance for every request.
infer = Infer(args)
def image_inference(source_img=None, driver_img=None):
    """Re-render the source head using the driver image and return a comparison strip.

    Args:
        source_img: The uploaded source portrait, or ``None`` if nothing was
            uploaded. Presumably a PIL image, since the UI components use
            ``type="pil"``.
        driver_img: The uploaded driver portrait, in the same format as
            ``source_img``.

    Returns:
        The image produced by ``tensor2image``, with its last axis reversed.
        The image is built from the source image, the target image, the masked
        render and the predicted target shape, concatenated along tensor dim 2.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of a crash inside the model.
    """
    # The original annotations (gr.inputs.Image) named a deprecated namespace that
    # Gradio 4 removed, which breaks at definition time. They also described the
    # component rather than the value actually passed in.
    if source_img is None or driver_img is None:
        raise gr.Error("Please upload both a source image and a driver image.")

    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data = out['source_information']['data_dict']
    panels = [
        data['source_img'][0].cpu(),
        data['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ]
    # dim=2 is presumably the width axis of (C, H, W) tensors, giving a
    # side-by-side strip -- TODO confirm.
    res = tensor2image(torch.cat(panels, dim=2))
    # Reverse the channel axis; presumably tensor2image yields BGR and Gradio
    # expects RGB -- confirm.
    return res[..., ::-1]
def video_inference():
    """Placeholder for video-driven inference.

    Not implemented yet: takes no inputs, does nothing, and returns ``None``.
    """
    return None
# Gradio UI: a working "Image Inference" tab and a placeholder "Video Inference" tab.
with gr.Blocks() as demo:
    gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")

    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil", label="source image", show_label=True)
            driver_img = gr.Image(type="pil", label="driver image", show_label=True)
        image_output = gr.Image()
        image_button = gr.Button("Predict")
        # The examples fill this tab's components, so they live inside this tab
        # rather than showing (and being clickable) while the Video tab is active.
        # cache_examples=True runs image_inference on them ahead of time.
        gr.Examples(
            examples=[
                ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
                ["./examples/lincoln.jpg", "./examples/taras1.jpg"],
            ],
            inputs=[source_img, driver_img],
            outputs=[image_output],
            fn=image_inference,
            cache_examples=True,
        )

    with gr.Tab("Video Inference"):
        with gr.Row():
            source_video = gr.Video(label="source video")
            driver_image_for_vid = gr.Image(type="pil", label="driver image", show_label=True)
        video_output = gr.Image()
        video_button = gr.Button("Predict")

    image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
    # No Python callback is attached (fn=None): video inference is not implemented yet.
    video_button.click(None, inputs=[source_video, driver_image_for_vid], outputs=video_output)

demo.launch()