import argparse

import BboxTools as bbt
import gradio as gr  # NOTE(review): not referenced anywhere in this file
import numpy as np
from PIL import Image
from pytorch3d.renderer import (
    RasterizationSettings,
    PerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    HardPhongShader,
    BlendParams,
    camera_position_from_spherical_angles,
    look_at_rotation,
    PointLights,
)
from pytorch3d.renderer import TexturesVertex as Textures
from pytorch3d.structures import Meshes
import torch

# Category name -> path of the OFF-format CAD model rendered for it.
mesh_paths = {
    "Aeroplane": "CAD_selected/aeroplane.off",
    "Bicycle": "CAD_selected/bicycle.off",
    "Boat": "CAD_selected/boat.off",
    "Bottle": "CAD_selected/bottle.off",
    "Bus": "CAD_selected/bus.off",
    "Car": "CAD_selected/car.off",
    "Chair": "CAD_selected/chair.off",
    "Diningtable": "CAD_selected/diningtable.off",
    "Motorbike": "CAD_selected/motorbike.off",
    "Sofa": "CAD_selected/sofa.off",
    "Train": "CAD_selected/train.off",
    "Tvmonitor": "CAD_selected/tvmonitor.off",
}


def parse_args():
    """Parse the command-line options for a single render.

    Every option is required: ``render`` cannot run with any of them
    missing (it calls ``float()`` on the angles, indexes ``mesh_paths``
    with the category and formats ``img_id`` as an integer).
    """
    parser = argparse.ArgumentParser(description='Render off')
    parser.add_argument('--azimuth', type=float, required=True)
    parser.add_argument('--elevation', type=float, required=True)
    parser.add_argument('--theta', type=float, required=True,
                        help='in-plane rotation; always interpreted in radians')
    parser.add_argument('--dist', type=float, required=True)
    parser.add_argument('--category', type=str, required=True,
                        choices=list(mesh_paths))
    parser.add_argument('--unit', type=str, required=True,
                        help="'Degree' means azimuth/elevation are in degrees; "
                             "any other value means radians")
    parser.add_argument('--img_id', type=int, required=True,
                        help='output file is <img_id zero-padded to 5 digits>.png')
    return parser.parse_args()


def rotation_theta(theta, device_=None):
    """Build rotation matrices about the z axis.

    Each returned matrix is::

        [[cos, -sin, 0],
         [sin,  cos, 0],
         [  0,    0, 1]]

    Args:
        theta: a Python number (one angle) or a tensor of angles, which is
            flattened to one angle per matrix. Angles are in radians because
            ``torch.cos`` / ``torch.sin`` are applied directly.
        device_: device of the result. Defaults to ``'cpu'`` for a number
            and to ``theta.device`` for a tensor.

    Returns:
        Tensor of shape ``(n, 3, 3)``; ``n == 1`` for a scalar angle.
    """
    # isinstance (not `type(...) == float`) so ints and numpy float scalars
    # take the scalar path instead of crashing on `.view` / `.device`.
    if isinstance(theta, (int, float)):
        if device_ is None:
            device_ = 'cpu'
        theta = torch.ones((1, 1, 1)).to(device_) * theta
    else:
        if device_ is None:
            device_ = theta.device
        theta = theta.view(-1, 1, 1)

    # Row 0 places cos at flat positions 0 and 4; row 1 places -sin at
    # position 1 and +sin at position 3.
    mul_ = torch.Tensor([[1, 0, 0, 0, 1, 0, 0, 0, 0],
                         [0, -1, 0, 1, 0, 0, 0, 0, 0]]).view(1, 2, 9).to(device_)
    # Constant 1 in the bottom-right corner (flat position 8).
    bia_ = torch.Tensor([0] * 8 + [1]).view(1, 1, 9).to(device_)

    # [n, 1, 2]
    cos_sin = torch.cat((torch.cos(theta), torch.sin(theta)), dim=2).to(device_)

    # [n, 1, 2] @ [1, 2, 9] + [1, 1, 9] => [n, 1, 9] => [n, 3, 3]
    trans = torch.matmul(cos_sin, mul_) + bia_
    return trans.view(-1, 3, 3)


def campos_to_R_T(campos, theta, device='cpu', at=((0, 0, 0),), up=((0, 1, 0),)):
    """Turn camera positions plus an in-plane rotation into PyTorch3D (R, T).

    Args:
        campos: camera positions, shape ``(n, 3)``.
        theta: in-plane rotation (radians) passed to ``rotation_theta``.
            ``torch.bmm`` needs equal batch sizes, so a scalar theta
            (batch 1) requires ``n == 1``.
        device: device passed to ``look_at_rotation`` and ``rotation_theta``.
        at, up: look-at target and up vector for ``look_at_rotation``.

    Returns:
        ``(R, T)`` with shapes ``(n, 3, 3)`` and ``(n, 3)``.
    """
    R = look_at_rotation(campos, at=at, device=device, up=up)  # (n, 3, 3)
    R = torch.bmm(R, rotation_theta(theta, device_=device))
    T = -torch.bmm(R.transpose(1, 2), campos.unsqueeze(2))[:, :, 0]  # (n, 3)
    return R, T


def load_off(off_file_name, to_torch=False):
    """Read a triangle mesh from an OFF file.

    Expects line 0 to be the ``OFF`` header and line 1 to start with the
    vertex count. That line is followed by one ``x y z`` line per vertex and
    then one ``3 i j k`` line per face. Only triangular faces are supported,
    because faces are reshaped to rows of 4 integers.

    Args:
        off_file_name: path of the OFF file.
        to_torch: return torch tensors instead of numpy arrays.

    Returns:
        ``(vertices, faces)``: float32 ``(V, 3)`` and int32 ``(F, 3)``.
    """
    # Context manager so the file handle is always closed.
    with open(off_file_name) as file_handle:
        file_list = file_handle.readlines()

    # split() (not split(' ')) tolerates leading/tab-separated whitespace.
    n_points = int(file_list[1].split()[0])

    vertex_text = ''.join(file_list[2:2 + n_points])
    verts = np.fromstring(vertex_text, dtype=np.float32, sep='\n').reshape((-1, 3))

    face_text = ''.join(file_list[2 + n_points:])
    # Drop the leading per-face vertex count (always 3 for triangles).
    faces = np.fromstring(face_text, dtype=np.int32, sep='\n').reshape((-1, 4))[:, 1:]

    if to_torch:
        return torch.from_numpy(verts), torch.from_numpy(faces)
    return verts, faces


def pre_process_mesh_pascal(verts):
    """Reorder vertex axes: ``(x, y, z) -> (x, z, -y)`` for an ``(V, 3)`` tensor."""
    verts = torch.cat((verts[:, 0:1], verts[:, 2:3], -verts[:, 1:2]), dim=1)
    return verts


def render(azimuth, elevation, theta, dist, category, unit, img_id):
    """Render the CAD model of ``category`` from one viewpoint and save a PNG.

    Args:
        azimuth, elevation: camera angles, in degrees if ``unit == 'Degree'``
            and in radians otherwise.
        theta: in-plane rotation. It is NOT converted according to ``unit``
            and is always treated as radians by ``rotation_theta``.
            NOTE(review): confirm this is intended when ``unit == 'Degree'``.
        dist: camera distance from the origin.
        category: key of ``mesh_paths``.
        unit: ``'Degree'``, or any other value for radians.
        img_id: integer; the image is written to ``f'{img_id:05d}.png'``.

    Raises:
        KeyError: if ``category`` is not in ``mesh_paths``.
    """
    azimuth = float(azimuth)
    elevation = float(elevation)
    theta = float(theta)
    dist = float(dist)

    h, w = 256, 256
    render_image_size = max(h, w)
    crop_size = (256, 256)
    device = 'cpu'

    cameras = PerspectiveCameras(focal_length=12.0, device=device)
    raster_settings = RasterizationSettings(
        image_size=render_image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=0,
    )
    lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=HardPhongShader(device=device, lights=lights, cameras=cameras),
    )

    x3d, xface = load_off(mesh_paths[category])
    verts = pre_process_mesh_pascal(torch.from_numpy(x3d).to(device))
    faces = torch.from_numpy(xface).to(device)

    # All vertex colours are 1 (white); appearance comes from Phong shading.
    verts_rgb = torch.ones_like(verts)[None]
    textures = Textures(verts_rgb.to(device))
    meshes = Meshes(verts=[verts], faces=[faces], textures=textures)

    C = camera_position_from_spherical_angles(
        dist, elevation, azimuth, degrees=(unit == 'Degree'), device=device
    )
    R, T = campos_to_R_T(C, theta, device=device)
    image = phong_renderer(meshes_world=meshes.clone(), R=R, T=T)
    image = image[:, ..., :3]  # keep the first 3 channels (drop alpha)

    # Presumably a crop_size box centred on the image centre
    # (BboxTools.box_by_shape semantics - confirm).
    box_ = bbt.box_by_shape(crop_size, (render_image_size // 2,) * 2)
    bbox = box_.bbox
    image = image[:, bbox[0][0]:bbox[0][1], bbox[1][0]:bbox[1][1], :]
    image = torch.squeeze(image).detach().cpu().numpy()

    # Normalise to [0, 255]; guard against an all-black frame (0 / 0 -> NaN).
    peak = image.max()
    if peak > 0:
        image = image / peak
    image = (image * 255).astype(np.uint8)

    # Shift the image by (dx, dy) pixels, padding with edge pixels. With the
    # constants below dx == dy == 0, so this is currently a no-op.
    cx, cy = (128, 128)
    dx = int(-cx + w / 2)
    dy = int(-cy + h / 2)
    image_pad = np.pad(image, ((abs(dy), abs(dy)), (abs(dx), abs(dx)), (0, 0)), mode='edge')
    image = image_pad[dy + abs(dy):dy + abs(dy) + image.shape[0],
                      dx + abs(dx):dx + abs(dx) + image.shape[1]]

    Image.fromarray(image).save(f'{img_id:05d}.png')


if __name__ == '__main__':
    args = parse_args()
    render(args.azimuth, args.elevation, args.theta, args.dist, args.category, args.unit, args.img_id)