# Spaces: Paused
import math | |
import os | |
from dataclasses import dataclass, field | |
from typing import List, Union | |
import numpy as np | |
import PIL.Image | |
import torch | |
import torch.nn.functional as F | |
import trimesh | |
from einops import rearrange | |
from huggingface_hub import hf_hub_download | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from .models.isosurface import MarchingCubeHelper | |
from .utils import ( | |
BaseModule, | |
ImagePreprocessor, | |
find_class, | |
get_spherical_cameras, | |
scale_tensor, | |
) | |
class TSR(BaseModule):
    """Image-conditioned reconstruction model built from configurable parts.

    ``configure`` instantiates an image tokenizer, a tokenizer, a backbone, a
    post-processor, a decoder and a renderer from the class paths and option
    dicts in ``cfg``. ``forward`` turns input image(s) into "scene codes";
    ``render`` and ``extract_mesh`` consume those scene codes.
    """

    # NOTE(review): ``dataclass`` is imported by this file and ``Config`` only
    # declares annotated fields, so the decorator is restored here. Confirm
    # that ``BaseModule.Config`` is itself a dataclass.
    @dataclass
    class Config(BaseModule.Config):
        # Size forwarded to the image preprocessor in ``forward``.
        cond_image_size: int

        # Each ``*_cls`` string is resolved with ``find_class`` in
        # ``configure``; the companion dict is passed to that class.
        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None
    ):
        """Build a model from a config file and load its weights.

        Args:
            pretrained_model_name_or_path: Either a local directory holding
                the config and weight files, or a repo id passed to
                ``hf_hub_download``.
            config_name: File name of the config inside the directory/repo.
            weight_name: File name of the checkpoint inside the directory/repo.
            token: Forwarded to ``hf_hub_download`` when files are fetched
                from a repo; unused for local directories.

        Returns:
            An instance of ``cls`` with the checkpoint's state dict loaded.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name, token=token
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; consider
        # ``weights_only=True`` if the installed torch version supports it.
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

    def configure(self):
        """Instantiate every sub-module named in ``self.cfg``."""
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        # Created lazily by ``set_marching_cubes_resolution``.
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        """Encode input image(s) into scene codes.

        Args:
            image: A single image or a list of images (PIL, NumPy or tensor).
            device: Device the preprocessed images are moved to.

        Returns:
            The output of the post-processor applied to the detokenized
            backbone tokens.
        """
        # ``[:, None]`` inserts a singleton view axis, giving "B Nv H W C".
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        # The image tokenizer is fed channels-first input.
        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )
        # Fold the (single) view axis into the token axis.
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )

        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        """Render ``n_views`` images for every scene code.

        Camera rays come from ``get_spherical_cameras`` using the elevation,
        distance, field of view and image size given here.

        Args:
            scene_codes: Iterable of scene codes with a ``.device`` attribute
                (e.g. the tensor returned by ``forward``).
            n_views: Number of views rendered per scene code.
            elevation_deg: Forwarded to ``get_spherical_cameras``.
            camera_distance: Forwarded to ``get_spherical_cameras``.
            fovy_deg: Forwarded to ``get_spherical_cameras``.
            height: Forwarded to ``get_spherical_cameras``.
            width: Forwarded to ``get_spherical_cameras``.
            return_type: ``"pt"`` for the raw tensor, ``"np"`` for a NumPy
                array, ``"pil"`` for a ``PIL.Image`` (values scaled by 255 and
                cast to ``uint8``).

        Returns:
            A list with one entry per scene code; each entry is a list of
            ``n_views`` rendered images in the requested format.

        Raises:
            NotImplementedError: If ``return_type`` is not one of the
                supported values (raised when the first image is converted).
        """
        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            """Convert one rendered image to the requested return type."""
            if return_type == "pt":
                return image
            elif return_type == "np":
                return image.detach().cpu().numpy()
            elif return_type == "pil":
                return Image.fromarray(
                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                )
            else:
                raise NotImplementedError(
                    f"Unsupported return_type: {return_type!r}"
                )

        images = []
        for scene_code in scene_codes:
            images_ = []
            for i in range(n_views):
                # Rendering is inference-only; skip autograd bookkeeping.
                with torch.no_grad():
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                images_.append(process_output(image))
            images.append(images_)

        return images

    def set_marching_cubes_resolution(self, resolution: int):
        """Ensure ``self.isosurface_helper`` exists for ``resolution``.

        The helper is rebuilt only when none exists yet or when the existing
        one was created for a different resolution.
        """
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
        """Extract one vertex-coloured mesh per scene code.

        Args:
            scene_codes: Iterable of scene codes with a ``.device`` attribute
                (e.g. the tensor returned by ``forward``).
            resolution: Resolution passed to the marching-cubes helper.
            threshold: Value subtracted from the queried density before the
                isosurface is extracted.

        Returns:
            A list of ``trimesh.Trimesh`` objects, one per scene code.
        """
        self.set_marching_cubes_resolution(resolution)
        helper = self.isosurface_helper
        radius = self.renderer.cfg.radius
        # Target range used when mapping helper coordinates into the
        # renderer's [-radius, radius] range.
        bounds = (-radius, radius)

        meshes = []
        for scene_code in scene_codes:
            # Query density at every grid vertex, rescaled from the helper's
            # range into the renderer's bounds.
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        helper.grid_vertices.to(scene_codes.device),
                        helper.points_range,
                        bounds,
                    ),
                    scene_code,
                )["density_act"]

            # The helper receives the negated, threshold-shifted density.
            # NOTE(review): presumably it expects negative values inside the
            # surface - confirm against MarchingCubeHelper.
            v_pos, t_pos_idx = helper(-(density - threshold))
            v_pos = scale_tensor(v_pos, helper.points_range, bounds)

            # Sample colour at the extracted vertex positions.
            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]

            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy(),
            )
            meshes.append(mesh)

        return meshes