ECON / lib /common /render.py
Yuliang's picture
upgrade to Gradio 4.14.0
e0ba903
raw
history blame
No virus
12.3 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import math
import os
import sys
import cv2
import numpy as np
import torch
from PIL import ImageColor
from pytorch3d.renderer import (
AlphaCompositor,
BlendParams,
FoVOrthographicCameras,
MeshRasterizer,
MeshRenderer,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
RasterizationSettings,
SoftSilhouetteShader,
TexturesVertex,
blending,
look_at_view_transform,
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes
import torch.nn.functional as F
from termcolor import colored
from tqdm import tqdm
import lib.common.render_utils as util
from lib.common.imutils import blend_rgb_norm
from lib.dataset.mesh_util import get_visibility
def image2vid(images, vid_path):
os.makedirs(os.path.dirname(vid_path), exist_ok=True)
w, h = images[0].size
videodims = (w, h)
fourcc = cv2.VideoWriter_fourcc(*"XVID")
video = cv2.VideoWriter(vid_path, fourcc, len(images) / 5.0, videodims)
for image in images:
video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
video.release()
def query_color(verts, faces, image, device, paint_normal=True):
"""query colors from points and image
Args:
verts ([B, 3]): [query verts]
faces ([M, 3]): [query faces]
image ([B, 3, H, W]): [full image]
Returns:
[np.float]: [return colors]
"""
verts = verts.float().to(device)
faces = faces.long().to(device)
(xy, z) = verts.split([2, 1], dim=1)
visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
colors = ((
torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) +
1.0
) * 0.5 * 255.0)
if paint_normal:
colors[visibility == 0.0] = ((
Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0
) * 0.5 * 255.0)[visibility == 0.0]
else:
colors[visibility == 0.0] = torch.tensor([0.0, 0.0, 0.0]).to(device)
return colors.detach().cpu()
class cleanShader(torch.nn.Module):
def __init__(self, blend_params=None):
super().__init__()
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs):
# get renderer output
blend_params = kwargs.get("blend_params", self.blend_params)
texels = meshes.sample_textures(fragments)
images = blending.softmax_rgb_blend(texels, fragments, blend_params, znear=-256, zfar=256)
return images
class Render:
def __init__(self, size=512, device=torch.device("cuda:0")):
self.device = device
self.size = size
# camera setting
self.dis = 100.0
self.scale = 100.0
self.mesh_y_center = 0.0
# speed control
self.fps = 30
self.step = 3
self.cam_pos = {
"front":
torch.tensor([
(0, self.mesh_y_center, self.dis),
(0, self.mesh_y_center, -self.dis),
]), "frontback":
torch.tensor([
(0, self.mesh_y_center, self.dis),
(0, self.mesh_y_center, -self.dis),
]), "four":
torch.tensor([
(0, self.mesh_y_center, self.dis),
(self.dis, self.mesh_y_center, 0),
(0, self.mesh_y_center, -self.dis),
(-self.dis, self.mesh_y_center, 0),
]), "around":
torch.tensor([(
100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
100.0 * math.sin(np.pi / 180 * angle)
) for angle in range(0, 360, self.step)])
}
self.type = "color"
self.mesh = None
self.deform_mesh = None
self.pcd = None
self.renderer = None
self.meshRas = None
self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
def get_camera_batch(self, type="four", idx=None):
if idx is None:
idx = np.arange(len(self.cam_pos[type]))
R, T = look_at_view_transform(
eye=self.cam_pos[type][idx],
at=((0, self.mesh_y_center, 0), ),
up=((0, 1, 0), ),
)
cameras = FoVOrthographicCameras(
device=self.device,
R=R,
T=T,
znear=100.0,
zfar=-100.0,
max_y=100.0,
min_y=-100.0,
max_x=100.0,
min_x=-100.0,
scale_xyz=(self.scale * np.ones(3), ) * len(R),
)
return cameras
def init_renderer(self, camera, type="mesh", bg="gray"):
blendparam = BlendParams(1e-4, 1e-8, np.array(ImageColor.getrgb(bg)) / 255.0)
if ("mesh" in type) or ("depth" in type) or ("rgb" in type):
# rasterizer
self.raster_settings_mesh = RasterizationSettings(
image_size=self.size,
blur_radius=np.log(1.0 / 1e-4) * 1e-7,
bin_size=-1,
faces_per_pixel=30,
)
self.meshRas = MeshRasterizer(cameras=camera, raster_settings=self.raster_settings_mesh)
self.renderer = MeshRenderer(
rasterizer=self.meshRas,
shader=cleanShader(blend_params=blendparam),
)
elif type == "mask":
self.raster_settings_silhouette = RasterizationSettings(
image_size=self.size,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
faces_per_pixel=50,
bin_size=-1,
cull_backfaces=True,
)
self.silhouetteRas = MeshRasterizer(
cameras=camera, raster_settings=self.raster_settings_silhouette
)
self.renderer = MeshRenderer(
rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
)
elif type == "pointcloud":
self.raster_settings_pcd = PointsRasterizationSettings(
image_size=self.size, radius=0.006, points_per_pixel=10
)
self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd)
self.renderer = PointsRenderer(
rasterizer=self.pcdRas,
compositor=AlphaCompositor(background_color=(0, 0, 0)),
)
def load_meshes(self, verts, faces):
"""load mesh into the pytorch3d renderer
Args:
verts ([N,3] / [B,N,3]): array or tensor
faces ([N,3]/ [B,N,3]): array or tensor
"""
if isinstance(verts, list):
V_lst = []
F_lst = []
for V, F in zip(verts, faces):
if not torch.is_tensor(V):
V_lst.append(torch.tensor(V).float().to(self.device))
F_lst.append(torch.tensor(F).long().to(self.device))
else:
V_lst.append(V.float().to(self.device))
F_lst.append(F.long().to(self.device))
self.meshes = Meshes(V_lst, F_lst).to(self.device)
else:
# array or tensor
if not torch.is_tensor(verts):
verts = torch.tensor(verts)
faces = torch.tensor(faces)
if verts.ndimension() == 2:
verts = verts.float().unsqueeze(0).to(self.device)
faces = faces.long().unsqueeze(0).to(self.device)
if verts.shape[0] != faces.shape[0]:
faces = faces.repeat(len(verts), 1, 1).to(self.device)
self.meshes = Meshes(verts, faces).to(self.device)
# texture only support single mesh
if len(self.meshes) == 1:
self.meshes.textures = TexturesVertex(
verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5
)
def get_image(self, cam_type="frontback", type="rgb", bg="gray"):
self.init_renderer(self.get_camera_batch(cam_type), type, bg)
img_lst = []
for mesh_id in range(len(self.meshes)):
current_mesh = self.meshes[mesh_id]
current_mesh.textures = TexturesVertex(
verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
)
if type == "depth":
fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type])))
images = fragments.zbuf[..., 0]
elif type == "rgb":
images = self.renderer(current_mesh.extend(len(self.cam_pos[cam_type])))
images = (images[:, :, :, :3].permute(0, 3, 1, 2) - 0.5) * 2.0
elif type == "mask":
images = self.renderer(current_mesh.extend(len(self.cam_pos[cam_type])))[:, :, :, 3]
else:
print(f"unknown {type}")
if cam_type == 'frontback':
images[1] = torch.flip(images[1], dims=(-1, ))
# images [N_render, 3, res, res]
img_lst.append(images.unsqueeze(1))
# meshes [N_render, N_mesh, 3, res, res]
meshes = torch.cat(img_lst, dim=1)
return list(meshes)
def get_rendered_video_multi(self, data, save_path):
height, width = data["img_raw"].shape[2:]
width = int(width / (height / 256.0))
height = int(256)
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(
save_path,
fourcc,
self.fps,
(width * 3, height),
)
pbar = tqdm(range(len(self.meshes)))
print(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue"))
mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh]
# render all the normals
for mesh_id in pbar:
current_mesh = self.meshes[mesh_id]
current_mesh.textures = TexturesVertex(
verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
)
norm_lst = []
for batch_cams_idx in np.array_split(np.arange(len(self.cam_pos["around"])), 12):
batch_cams = self.get_camera_batch(type='around', idx=batch_cams_idx)
self.init_renderer(batch_cams, "mesh", "gray")
norm_lst.append(
self.renderer(current_mesh.extend(len(batch_cams_idx))
)[..., :3].permute(0, 3, 1, 2)
)
mesh_renders.append(torch.cat(norm_lst).detach().cpu())
# generate video frame by frame
pbar = tqdm(range(len(self.cam_pos["around"])))
print(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue"))
for cam_id in pbar:
img_raw = data["img_raw"]
num_obj = len(mesh_renders) // 2
img_smpl = blend_rgb_norm((torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0,
data)
img_cloth = blend_rgb_norm((torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0,
data)
final_img = torch.cat([img_raw, img_smpl, img_cloth], dim=-1)
final_img_rescale = F.interpolate(
final_img, size=(height, width*3), mode="bilinear", align_corners=False
).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)
video.write(final_img_rescale[:, :, ::-1])
video.release()