use pytorch3d to render, instead of nvdiffrast
Browse files- gradio_app/gradio_3dgen.py +0 -5
- mesh_reconstruction/recon.py +2 -2
- mesh_reconstruction/refine.py +2 -2
- mesh_reconstruction/render.py +118 -0
- scripts/project_mesh.py +14 -18
gradio_app/gradio_3dgen.py
CHANGED
@@ -10,13 +10,8 @@ from scripts.refine_lr_to_sr import run_sr_fast
|
|
10 |
from scripts.utils import save_glb_and_video
|
11 |
from scripts.multiview_inference import geo_reconstruct
|
12 |
|
13 |
-
|
14 |
-
import nvdiffrast.torch as dr
|
15 |
-
dr.RasterizeGLContext(output_db=False)
|
16 |
@spaces.GPU
|
17 |
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
|
18 |
-
dr.RasterizeGLContext(output_db=False) # BUG: cuda_runtime_api.h: No such file or directory
|
19 |
-
|
20 |
if preview_img is None:
|
21 |
raise gr.Error("preview_img is none")
|
22 |
if isinstance(preview_img, str):
|
|
|
10 |
from scripts.utils import save_glb_and_video
|
11 |
from scripts.multiview_inference import geo_reconstruct
|
12 |
|
|
|
|
|
|
|
13 |
@spaces.GPU
|
14 |
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
|
|
|
|
|
15 |
if preview_img is None:
|
16 |
raise gr.Error("preview_img is none")
|
17 |
if isinstance(preview_img, str):
|
mesh_reconstruction/recon.py
CHANGED
@@ -6,14 +6,14 @@ from typing import List
|
|
6 |
from mesh_reconstruction.remesh import calc_vertex_normals
|
7 |
from mesh_reconstruction.opt import MeshOptimizer
|
8 |
from mesh_reconstruction.func import make_star_cameras_orthographic
|
9 |
-
from mesh_reconstruction.render import NormalsRenderer
|
10 |
from scripts.utils import to_py3d_mesh, init_target
|
11 |
|
12 |
def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
|
13 |
vertices, faces = vertices.to("cuda"), faces.to("cuda")
|
14 |
assert len(pils) == 4
|
15 |
mv,proj = make_star_cameras_orthographic(4, 1)
|
16 |
-
renderer =
|
17 |
|
18 |
target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
|
19 |
# 1. no rotate
|
|
|
6 |
from mesh_reconstruction.remesh import calc_vertex_normals
|
7 |
from mesh_reconstruction.opt import MeshOptimizer
|
8 |
from mesh_reconstruction.func import make_star_cameras_orthographic
|
9 |
+
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
|
10 |
from scripts.utils import to_py3d_mesh, init_target
|
11 |
|
12 |
def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
|
13 |
vertices, faces = vertices.to("cuda"), faces.to("cuda")
|
14 |
assert len(pils) == 4
|
15 |
mv,proj = make_star_cameras_orthographic(4, 1)
|
16 |
+
renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
|
17 |
|
18 |
target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
|
19 |
# 1. no rotate
|
mesh_reconstruction/refine.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List
|
|
5 |
from mesh_reconstruction.remesh import calc_vertex_normals
|
6 |
from mesh_reconstruction.opt import MeshOptimizer
|
7 |
from mesh_reconstruction.func import make_star_cameras_orthographic
|
8 |
-
from mesh_reconstruction.render import NormalsRenderer
|
9 |
from scripts.project_mesh import multiview_color_projection, get_cameras_list
|
10 |
from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
|
11 |
|
@@ -18,7 +18,7 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_e
|
|
18 |
|
19 |
assert len(pils) == 4
|
20 |
mv,proj = make_star_cameras_orthographic(4, 1)
|
21 |
-
renderer =
|
22 |
|
23 |
target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
|
24 |
# 1. no rotate
|
|
|
5 |
from mesh_reconstruction.remesh import calc_vertex_normals
|
6 |
from mesh_reconstruction.opt import MeshOptimizer
|
7 |
from mesh_reconstruction.func import make_star_cameras_orthographic
|
8 |
+
from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
|
9 |
from scripts.project_mesh import multiview_color_projection, get_cameras_list
|
10 |
from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
|
11 |
|
|
|
18 |
|
19 |
assert len(pils) == 4
|
20 |
mv,proj = make_star_cameras_orthographic(4, 1)
|
21 |
+
renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
|
22 |
|
23 |
target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
|
24 |
# 1. no rotate
|
mesh_reconstruction/render.py
CHANGED
@@ -49,3 +49,121 @@ class NormalsRenderer:
|
|
49 |
col = torch.concat((col,alpha),dim=-1) #C,H,W,4
|
50 |
col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
|
51 |
return col #C,H,W,4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
col = torch.concat((col,alpha),dim=-1) #C,H,W,4
|
50 |
col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
|
51 |
return col #C,H,W,4
|
52 |
+
|
53 |
+
from pytorch3d.structures import Meshes
|
54 |
+
from pytorch3d.renderer.mesh.shader import ShaderBase
|
55 |
+
from pytorch3d.renderer import (
|
56 |
+
RasterizationSettings,
|
57 |
+
MeshRendererWithFragments,
|
58 |
+
TexturesVertex,
|
59 |
+
MeshRasterizer,
|
60 |
+
BlendParams,
|
61 |
+
FoVOrthographicCameras,
|
62 |
+
look_at_view_transform,
|
63 |
+
hard_rgb_blend,
|
64 |
+
)
|
65 |
+
|
66 |
+
class VertexColorShader(ShaderBase):
|
67 |
+
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
68 |
+
blend_params = kwargs.get("blend_params", self.blend_params)
|
69 |
+
texels = meshes.sample_textures(fragments)
|
70 |
+
return hard_rgb_blend(texels, fragments, blend_params)
|
71 |
+
|
72 |
+
def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
|
73 |
+
if len(mesh) != len(cameras):
|
74 |
+
if len(cameras) % len(mesh) == 0:
|
75 |
+
mesh = mesh.extend(len(cameras))
|
76 |
+
else:
|
77 |
+
raise NotImplementedError()
|
78 |
+
|
79 |
+
# render requires everything in float16 or float32
|
80 |
+
input_dtype = dtype
|
81 |
+
blend_params = BlendParams(1e-4, 1e-4, bkgd)
|
82 |
+
|
83 |
+
# Define the settings for rasterization and shading
|
84 |
+
raster_settings = RasterizationSettings(
|
85 |
+
image_size=(H, W),
|
86 |
+
blur_radius=blur_radius,
|
87 |
+
faces_per_pixel=faces_per_pixel,
|
88 |
+
clip_barycentric_coords=True,
|
89 |
+
bin_size=None,
|
90 |
+
max_faces_per_bin=500000,
|
91 |
+
)
|
92 |
+
|
93 |
+
# Create a renderer by composing a rasterizer and a shader
|
94 |
+
# We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used)
|
95 |
+
renderer = MeshRendererWithFragments(
|
96 |
+
rasterizer=MeshRasterizer(
|
97 |
+
cameras=cameras,
|
98 |
+
raster_settings=raster_settings
|
99 |
+
),
|
100 |
+
shader=VertexColorShader(
|
101 |
+
device=device,
|
102 |
+
cameras=cameras,
|
103 |
+
blend_params=blend_params
|
104 |
+
)
|
105 |
+
)
|
106 |
+
|
107 |
+
# render RGB and depth, get mask
|
108 |
+
with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
|
109 |
+
images, _ = renderer(mesh)
|
110 |
+
return images # BHW4
|
111 |
+
|
112 |
+
class Pytorch3DNormalsRenderer:
|
113 |
+
def __init__(self, cameras, image_size, device):
|
114 |
+
self.cameras = cameras.to(device)
|
115 |
+
self._image_size = image_size
|
116 |
+
self.device = device
|
117 |
+
|
118 |
+
def render(self,
|
119 |
+
vertices: torch.Tensor, #V,3 float
|
120 |
+
normals: torch.Tensor, #V,3 float in [-1, 1]
|
121 |
+
faces: torch.Tensor, #F,3 long
|
122 |
+
) ->torch.Tensor: #C,H,W,4
|
123 |
+
mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
|
124 |
+
return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
|
125 |
+
|
126 |
+
def get_camera(R, T, focal_length=1 / (2**0.5)):
|
127 |
+
focal_length = 1 / focal_length
|
128 |
+
camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
|
129 |
+
return camera
|
130 |
+
|
131 |
+
def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
|
132 |
+
R, T = look_at_view_transform(dist, 0, azim_list)
|
133 |
+
focal_length = 1 / focal
|
134 |
+
return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
|
135 |
+
|
136 |
+
def save_tensor_to_img(tensor, save_dir):
|
137 |
+
from PIL import Image
|
138 |
+
import numpy as np
|
139 |
+
for idx, img in enumerate(tensor):
|
140 |
+
img = img[..., :3].cpu().numpy()
|
141 |
+
img = (img * 255).astype(np.uint8)
|
142 |
+
img = Image.fromarray(img)
|
143 |
+
img.save(save_dir + f"{idx}.png")
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
import sys
|
147 |
+
import os
|
148 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
149 |
+
from mesh_reconstruction.func import make_star_cameras_orthographic
|
150 |
+
cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
|
151 |
+
mv,proj = make_star_cameras_orthographic(4, 1)
|
152 |
+
resolution = 1024
|
153 |
+
renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
|
154 |
+
renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
|
155 |
+
vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
|
156 |
+
normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
|
157 |
+
faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)
|
158 |
+
|
159 |
+
import time
|
160 |
+
t0 = time.time()
|
161 |
+
r1 = renderer1.render(vertices, normals, faces)
|
162 |
+
print("time r1:", time.time() - t0)
|
163 |
+
|
164 |
+
t0 = time.time()
|
165 |
+
r2 = renderer2.render(vertices, normals, faces)
|
166 |
+
print("time r2:", time.time() - t0)
|
167 |
+
|
168 |
+
for i in range(4):
|
169 |
+
print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
|
scripts/project_mesh.py
CHANGED
@@ -13,17 +13,6 @@ from pytorch3d.renderer import (
|
|
13 |
)
|
14 |
from pytorch3d.renderer import MeshRasterizer
|
15 |
|
16 |
-
def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
|
17 |
-
# pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
|
18 |
-
R = world_to_cam[:3, :3].t()[None, ...]
|
19 |
-
T = world_to_cam[:3, 3][None, ...]
|
20 |
-
if cam_type == 'fov':
|
21 |
-
camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
|
22 |
-
else:
|
23 |
-
focal_length = 1 / focal_length
|
24 |
-
camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
|
25 |
-
return camera
|
26 |
-
|
27 |
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
|
28 |
"""
|
29 |
Renders pix2face of visible faces.
|
@@ -98,11 +87,11 @@ class Pix2FacesRenderer:
|
|
98 |
pix2faces_renderer = None
|
99 |
|
100 |
def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
|
101 |
-
global pix2faces_renderer
|
102 |
-
if pix2faces_renderer is None:
|
103 |
-
|
104 |
-
|
105 |
-
pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
|
106 |
|
107 |
unique_faces = torch.unique(pix_to_face.flatten())
|
108 |
unique_faces = unique_faces[unique_faces != -1]
|
@@ -313,12 +302,19 @@ def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], ca
|
|
313 |
del meshes
|
314 |
return ret_mesh
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
|
317 |
ret = []
|
318 |
for azim in azim_list:
|
319 |
R, T = look_at_view_transform(dist, 0, azim)
|
320 |
-
|
321 |
-
cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device)
|
322 |
ret.append(cameras)
|
323 |
return ret
|
324 |
|
|
|
13 |
)
|
14 |
from pytorch3d.renderer import MeshRasterizer
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
|
17 |
"""
|
18 |
Renders pix2face of visible faces.
|
|
|
87 |
pix2faces_renderer = None
|
88 |
|
89 |
def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
|
90 |
+
# global pix2faces_renderer
|
91 |
+
# if pix2faces_renderer is None:
|
92 |
+
# pix2faces_renderer = Pix2FacesRenderer()
|
93 |
+
pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
|
94 |
+
# pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
|
95 |
|
96 |
unique_faces = torch.unique(pix_to_face.flatten())
|
97 |
unique_faces = unique_faces[unique_faces != -1]
|
|
|
302 |
del meshes
|
303 |
return ret_mesh
|
304 |
|
305 |
+
def get_camera(R, T, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
|
306 |
+
if cam_type == 'fov':
|
307 |
+
camera = FoVPerspectiveCameras(device=R.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
|
308 |
+
else:
|
309 |
+
focal_length = 1 / focal_length
|
310 |
+
camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
|
311 |
+
return camera
|
312 |
+
|
313 |
def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
|
314 |
ret = []
|
315 |
for azim in azim_list:
|
316 |
R, T = look_at_view_transform(dist, 0, azim)
|
317 |
+
cameras: OrthographicCameras = get_camera(R, T, focal_length=focal, cam_type='orthogonal').to(device)
|
|
|
318 |
ret.append(cameras)
|
319 |
return ret
|
320 |
|