# Extracted from commit 37aeb5b (original file size: 2,161 bytes).
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from typing import List
from mesh_reconstruction.remesh import calc_vertex_normals
from mesh_reconstruction.opt import MeshOptimizer
from mesh_reconstruction.func import make_star_cameras_orthographic
from mesh_reconstruction.render import NormalsRenderer
from scripts.utils import to_py3d_mesh, init_target
def reconstruct_stage1(
    pils: List[Image.Image],
    steps=100,
    vertices=None,
    faces=None,
    start_edge_len=0.15,
    end_edge_len=0.005,
    decay=0.995,
    return_mesh=True,
    loss_expansion_weight=0.1,
    gain=0.1,
    oob_bound=0.99,
    oob_weight=10.0,
):
    """Refine a coarse mesh so its rendered images match four target views.

    The mesh is optimized with ``MeshOptimizer``, which is asked to remesh
    (``remesh(poisson=False)``) after every step. It is rendered with
    ``NormalsRenderer`` from the four orthographic cameras returned by
    ``make_star_cameras_orthographic(4, 1)``.

    Args:
        pils: Exactly four target images. They are converted with
            ``init_target`` using a black background, and their last channel
            is used as the silhouette/alpha mask. NOTE(review): presumably
            normal maps in a fixed view order — confirm against the caller.
        steps: Number of optimization iterations.
        vertices: Initial vertex tensor of the coarse mesh. Required despite
            the ``None`` default; it is moved to CUDA.
        faces: Initial face tensor of the coarse mesh. Required; moved to CUDA.
        start_edge_len: Passed to ``MeshOptimizer`` as the second element of
            ``edge_len_lims=(end_edge_len, start_edge_len)``.
        end_edge_len: Passed as the first element of ``edge_len_lims``.
        decay: Multiplied into the optimizer's learning rate every step.
        return_mesh: If True, return ``to_py3d_mesh(vertices, faces)``.
            Otherwise return the detached ``(vertices, faces)`` tensors.
        loss_expansion_weight: Weight of the term that pulls each vertex
            along its vertex normal.
        gain: Passed through to ``MeshOptimizer``.
        oob_bound: Coordinate magnitude beyond which vertices are penalized.
        oob_weight: Weight of the out-of-box penalty. ``0`` disables it,
            which matches the historical behavior where the penalty had no
            gradient.

    Returns:
        The result of ``to_py3d_mesh`` if ``return_mesh`` is true, otherwise
        a ``(vertices, faces)`` tuple of detached tensors.

    Raises:
        ValueError: If ``vertices`` or ``faces`` is missing, or if ``pils``
            does not contain exactly four images.
    """
    if vertices is None or faces is None:
        raise ValueError("reconstruct_stage1 requires an initial mesh (vertices and faces)")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(pils) != 4:
        raise ValueError(f"reconstruct_stage1 expects exactly 4 images, got {len(pils)}")
    vertices, faces = vertices.to("cuda"), faces.to("cuda")

    mv, proj = make_star_cameras_orthographic(4, 1)
    renderer = NormalsRenderer(mv, proj, list(pils[0].size))

    target_images = init_target(pils, new_bkgd=(0., 0., 0.))
    # Reorder the targets (original note: "no rotate"). This presumably aligns
    # them with the camera order of make_star_cameras_orthographic — confirm.
    target_images = target_images[[0, 3, 2, 1]]
    # Target background pixels (target alpha < 0.5). Fixed for the whole run.
    bg_mask = target_images[..., -1] < 0.5

    # Initialize the optimizer from the coarse mesh. It owns the vertex
    # tensor being optimized, so work with opt.vertices from here on.
    opt = MeshOptimizer(vertices, faces, local_edgelen=False, gain=gain,
                        edge_len_lims=(end_edge_len, start_edge_len))
    vertices = opt.vertices

    for _ in tqdm(range(steps)):
        opt.zero_grad()
        # Exponential learning-rate decay, applied through MeshOptimizer's
        # private `_lr` attribute.
        opt._lr *= decay

        normals = calc_vertex_normals(vertices, faces)
        images = renderer.render(vertices, normals, faces)

        # Target (v + n) is detached, so the gradient moves each vertex along
        # its normal: an inflation force.
        loss_expand = 0.5 * ((vertices + normals).detach() - vertices).pow(2).mean()

        # Squared error on every channel, over pixels the mesh currently
        # covers (rendered alpha > 0.5).
        fg_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[fg_mask] - target_images[fg_mask]).pow(2).mean()
        # Squared alpha error where the target is background. This penalizes
        # silhouette spilling outside the target.
        loss_alpha_target_mask_l2 = (
            images[..., -1][bg_mask] - target_images[..., -1][bg_mask]
        ).pow(2).mean()

        # Out-of-box penalty. The previous form `(|v| > 0.99).float().mean()`
        # was built from a comparison, which carries no gradient, so it never
        # affected the optimization. A hinge on |v| - bound is differentiable.
        loss_oob = (vertices.abs() - oob_bound).clamp(min=0.0).mean() * oob_weight

        loss = (loss_target_l2
                + loss_alpha_target_mask_l2
                + loss_expand * loss_expansion_weight
                + loss_oob)
        loss.backward()
        opt.step()
        # Remeshing changes the topology, so both tensors are replaced.
        vertices, faces = opt.remesh(poisson=False)

    vertices, faces = vertices.detach(), faces.detach()
    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    return vertices, faces
|