diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3757098f5ad8f04599078346feb0eea63f25ebaf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..373c34284d08b190956df9a7ce5c27c8050d7fbe --- /dev/null +++ b/.gitignore @@ -0,0 +1,35 @@ +__pycache__/ +build/ +*.egg-info/ +*.so +venv_*/ +.vs/ +.vscode/ +.idea/ + +tmp_* +data? +data?? +scripts2 + +model_cache + +logs +videos +images +*.mp4 + +vis_data*/ +logs*/ +data*/ +eval_data*/ + + +*.sh +*.out +batchscript* + +pretrained/ +diff-gaussian-rasterization/ + +tmp_data/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..6bab82d5b99220e064c3c0bd2ea7d16b4b9bb28b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "diff-gaussian-rasterization"] + path = diff-gaussian-rasterization + url = https://github.com/ashawkey/diff-gaussian-rasterization \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a670ca2766b051923d5f8ac7934bad7ac162ab3e --- /dev/null +++ b/app.py @@ -0,0 +1,135 @@ +import gradio as gr +import os +from PIL import Image +import subprocess +from gradio_model4dgs import Model4DGS +import numpy +import hashlib + +os.system('pip install -e ./simple-knn') +os.system('pip install -e ./diff-gaussian-rasterization') + +from huggingface_hub import hf_hub_download +ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors") + +js_func = """ +function refresh() { + const url = new URL(window.location); + + if (url.searchParams.get('__theme') !== 'light') { + url.searchParams.set('__theme', 'light'); + window.location.href = url.href; + } +} +""" + +# check if there is a picture uploaded or selected +def check_img_input(control_image): + if control_image is None: + raise gr.Error("Please select or upload an input image") + +# check if there is a picture uploaded or selected +def check_video_input(image_block: Image.Image): + img_hash = hashlib.sha256(image_block.tobytes()).hexdigest() + if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')): + raise gr.Error("Please generate a video first") + + +def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int): + if not os.path.exists('tmp_data'): + os.makedirs('tmp_data') + img_hash = hashlib.sha256(image_block.tobytes()).hexdigest() + if preprocess_chk: + # save image to a designated path + image_block.save(os.path.join('tmp_data', f'{img_hash}.png')) + + # preprocess image + print(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}') + subprocess.run(f'python scripts/process.py {os.path.join("tmp_data", f"{img_hash}.png")}', shell=True) + else: + image_block.save(os.path.join('tmp_data', f'{img_hash}_rgba.png')) + + # stage 1 + subprocess.run(f'export MKL_THREADING_LAYER=GNU;export MKL_SERVICE_FORCE_INTEL=1;python scripts/gen_vid.py --path tmp_data/{img_hash}_rgba.png --seed {seed_slider} --bg white', shell=True) + + # return [os.path.join('logs', 'tmp_rgba_model.ply')] + return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4') + + +def optimize_stage_2(image_block: Image.Image, seed_slider: int): + img_hash = hashlib.sha256(image_block.tobytes()).hexdigest() + subprocess.run(f'python lgm/infer.py big --resume {ckpt_path} --test_path tmp_data/{img_hash}_rgba.png', shell=True) + # stage 2 + subprocess.run(f'python main_4d.py --config {os.path.join("configs", "4d_demo.yaml")} input={os.path.join("tmp_data", f"{img_hash}_rgba.png")}', shell=True) + # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames')) + image_dir = os.path.join('logs', f'{img_hash}_rgba_frames') + # return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')] + return [image_dir+f'/{t:03d}.ply' for t in range(28)] + + +if __name__ == "__main__": + _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting''' + + _DESCRIPTION = ''' +
+ We present DreamGausssion4D, an efficient 4D generation framework that builds on Gaussian Splatting. + ''' + _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**." + + # load images in 'data' folder as examples + example_folder = os.path.join(os.path.dirname(__file__), 'data') + example_fns = os.listdir(example_folder) + example_fns.sort() + examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')] + + # Compose demo layout & data flow + with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + + # Image-to-3D + with gr.Row(variant='panel'): + with gr.Column(scale=4): + image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image') + + # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle') + seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed') + gr.Markdown( + "random seed for video generation.") + + preprocess_chk = gr.Checkbox(True, + label='Preprocess image automatically (remove background and recenter object)') + + gr.Examples( + examples=examples_full, # NOTE: elements must match inputs list! + inputs=[image_block], + outputs=[image_block], + cache_examples=False, + label='Examples (click one of the images below to start)', + examples_per_page=40 + ) + img_run_btn = gr.Button("Generate Video") + fourd_run_btn = gr.Button("Generate 4D") + img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True) + + with gr.Column(scale=5): + obj3d = gr.Video(label="video",height=290) + obj4d = Model4DGS(label="4D Model", height=500, fps=14) + + img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1, + inputs=[image_block, + preprocess_chk, + seed_slider], + outputs=[ + obj3d]) + fourd_run_btn.click(check_video_input, inputs=[image_block], queue=False).success(optimize_stage_2, inputs=[image_block, seed_slider], outputs=[obj4d]) + + # demo.queue().launch(share=True) + demo.queue(max_size=10) # <-- Sets up a queue with default parameters + demo.launch(share=True) \ No newline at end of file diff --git a/cam_utils.py b/cam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05e730690fcdab48255c73b0f8298ce165149758 --- /dev/null +++ b/cam_utils.py @@ -0,0 +1,146 @@ +import numpy as np +from scipy.spatial.transform import Rotation as R + +import torch + +def dot(x, y): + if isinstance(x, np.ndarray): + return np.sum(x * y, -1, keepdims=True) + else: + return torch.sum(x * y, -1, keepdim=True) + + +def length(x, eps=1e-20): + if isinstance(x, np.ndarray): + return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) + else: + return torch.sqrt(torch.clamp(dot(x, x), min=eps)) + + +def safe_normalize(x, eps=1e-20): + return x / length(x, eps) + + +def look_at(campos, target, opengl=True): + # campos: [N, 3], camera/eye position + # target: [N, 3], object to look at + # return: [N, 3, 3], rotation matrix + if not opengl: + # camera forward aligns with -z + forward_vector = safe_normalize(target - campos) + up_vector = np.array([0, 1, 0], dtype=np.float32) + right_vector = safe_normalize(np.cross(forward_vector, up_vector)) + up_vector = safe_normalize(np.cross(right_vector, forward_vector)) + else: + # camera forward aligns with +z + forward_vector = safe_normalize(campos - target) + up_vector = np.array([0, 1, 0], dtype=np.float32) + right_vector = safe_normalize(np.cross(up_vector, forward_vector)) + up_vector = safe_normalize(np.cross(forward_vector, right_vector)) + R = np.stack([right_vector, up_vector, forward_vector], axis=1) + return R + + +# elevation & azimuth to pose (cam2world) matrix +def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): + # radius: scalar + # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) + # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) + # return: [4, 4], camera pose matrix + if is_degree: + elevation = np.deg2rad(elevation) + azimuth = np.deg2rad(azimuth) + x = radius * np.cos(elevation) * np.sin(azimuth) + y = - radius * np.sin(elevation) + z = radius * np.cos(elevation) * np.cos(azimuth) + if target is None: + target = np.zeros([3], dtype=np.float32) + campos = np.array([x, y, z]) + target # [3] + T = np.eye(4, dtype=np.float32) + T[:3, :3] = look_at(campos, target, opengl) + T[:3, 3] = campos + return T + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = np.deg2rad(fovy) # deg 2 rad + self.near = near + self.far = far + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_matrix(np.eye(3)) + self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! + + @property + def fovx(self): + return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H) + + @property + def campos(self): + return self.pose[:3, 3] + + # pose (c2w) + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] = self.radius # opengl convention... + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + # view (w2c) + @property + def view(self): + return np.linalg.inv(self.pose) + + # projection (perspective) + @property + def perspective(self): + y = np.tan(self.fovy / 2) + aspect = self.W / self.H + return np.array( + [ + [1 / (y * aspect), 0, 0, 0], + [0, -1 / y, 0, 0], + [ + 0, + 0, + -(self.far + self.near) / (self.far - self.near), + -(2 * self.far * self.near) / (self.far - self.near), + ], + [0, 0, -1, 0], + ], + dtype=np.float32, + ) + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(self.fovy / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) + + @property + def mvp(self): + return self.perspective @ np.linalg.inv(self.pose) # [4, 4] + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] + rotvec_x = self.up * np.radians(-0.05 * dx) + rotvec_y = side * np.radians(-0.05 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz]) \ No newline at end of file diff --git a/configs/4d.yaml b/configs/4d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8bd4838c9bf813981aa2bf47d153352a8b55c463 --- /dev/null +++ b/configs/4d.yaml @@ -0,0 +1,121 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 0.5 + +### Output +outdir: logs +mesh_format: frames +save_path: '' +save_model: False + +### Training +# guidance loss weights (0 to disable) +mvdream: False +imagedream: False +lambda_sd: 0 +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: True +lambda_svd: 0 +# training batch size per iter +batch_size: 14 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. +n_views: 4 +t_max: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +optimize_gaussians: True +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 + +num_pts: 5000 +sh_degree: 0 +percent_dense: 0.1 +density_start_iter: 3000 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +# deformation field +deformation_lr_init: 0.00064 +deformation_lr_final: 0.00064 +deformation_lr_delay_mult: 0.01 +grid_lr_init: 0.0064 +grid_lr_final: 0.0064 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +deformation: + net_width: 64 + timebase_pe: 4 + defor_depth: 1 + posebase_pe: 10 + scale_rotation_pe: 2 + opacity_pe: 2 + timenet_width: 64 + timenet_output: 32 + bounds: 1.6 + plane_tv_weight: 0.0001 + time_smoothness_weight: 0.01 + l1_time_planes: 0.0001 + kplanes_config: + grid_dimensions: 2 + input_coordinate_dim: 4 + output_coordinate_dim: 32 + resolution: [32, 32, 32, 12] + multires: [1] + no_grid: False + no_mlp: False + no_ds: False + no_dr: False + no_do: True + use_res: True + +data_mode: svd +downsample_rate: 1 +# data_mode: c4d +# downsample_rate: 2 \ No newline at end of file diff --git a/configs/4d_c4d.yaml b/configs/4d_c4d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68e1dd6603879ac84d87dd6662b532cd93e6fe16 --- /dev/null +++ b/configs/4d_c4d.yaml @@ -0,0 +1,119 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 0.5 + +### Output +outdir: logs +mesh_format: frames +save_path: '' +save_model: False + +### Training +# guidance loss weights (0 to disable) +mvdream: False +imagedream: False +lambda_sd: 0 +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: True +lambda_svd: 0 +# training batch size per iter +batch_size: 32 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. +n_views: 4 +t_max: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +optimize_gaussians: True +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 + +num_pts: 5000 +sh_degree: 0 +percent_dense: 0.1 +density_start_iter: 3000 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +# deformation field +deformation_lr_init: 0.00064 +deformation_lr_final: 0.00064 +deformation_lr_delay_mult: 0.01 +grid_lr_init: 0.0064 +grid_lr_final: 0.0064 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +deformation: + net_width: 64 + timebase_pe: 4 + defor_depth: 1 + posebase_pe: 10 + scale_rotation_pe: 2 + opacity_pe: 2 + timenet_width: 64 + timenet_output: 32 + bounds: 1.6 + plane_tv_weight: 0.0001 + time_smoothness_weight: 0.01 + l1_time_planes: 0.0001 + kplanes_config: + grid_dimensions: 2 + input_coordinate_dim: 4 + output_coordinate_dim: 32 + resolution: [32, 32, 32, 32] + multires: [1] + no_grid: False + no_mlp: False + no_ds: False + no_dr: False + no_do: True + use_res: True + +data_mode: c4d +downsample_rate: 1 \ No newline at end of file diff --git a/configs/4d_c4d_low.yaml b/configs/4d_c4d_low.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd39269e114eaf804243fcc368324ed3d38a7e13 --- /dev/null +++ b/configs/4d_c4d_low.yaml @@ -0,0 +1,119 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 0.5 + +### Output +outdir: logs +mesh_format: frames +save_path: '' +save_model: False + +### Training +# guidance loss weights (0 to disable) +mvdream: False +imagedream: False +lambda_sd: 0 +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: True +lambda_svd: 0 +# training batch size per iter +batch_size: 8 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. +n_views: 1 +t_max: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +optimize_gaussians: True +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 + +num_pts: 5000 +sh_degree: 0 +percent_dense: 0.1 +density_start_iter: 3000 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +# deformation field +deformation_lr_init: 0.00064 +deformation_lr_final: 0.00064 +deformation_lr_delay_mult: 0.01 +grid_lr_init: 0.0064 +grid_lr_final: 0.0064 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +deformation: + net_width: 64 + timebase_pe: 4 + defor_depth: 1 + posebase_pe: 10 + scale_rotation_pe: 2 + opacity_pe: 2 + timenet_width: 64 + timenet_output: 32 + bounds: 1.6 + plane_tv_weight: 0.0001 + time_smoothness_weight: 0.01 + l1_time_planes: 0.0001 + kplanes_config: + grid_dimensions: 2 + input_coordinate_dim: 4 + output_coordinate_dim: 32 + resolution: [32, 32, 32, 12] + multires: [1] + no_grid: False + no_mlp: False + no_ds: False + no_dr: False + no_do: True + use_res: True + +data_mode: c4d +downsample_rate: 4 \ No newline at end of file diff --git a/configs/4d_demo.yaml b/configs/4d_demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b1d5274ed1f0b6663f6f98dfc5b47b4cd5c8b79e --- /dev/null +++ b/configs/4d_demo.yaml @@ -0,0 +1,121 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 0.5 + +### Output +outdir: logs +mesh_format: frames +save_path: '' +save_model: False + +### Training +# guidance loss weights (0 to disable) +mvdream: False +imagedream: False +lambda_sd: 0 +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: True +lambda_svd: 0 +# training batch size per iter +batch_size: 7 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. +n_views: 1 +t_max: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +optimize_gaussians: True +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 + +num_pts: 5000 +sh_degree: 0 +percent_dense: 0.1 +density_start_iter: 3000 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +# deformation field +deformation_lr_init: 0.00064 +deformation_lr_final: 0.00064 +deformation_lr_delay_mult: 0.01 +grid_lr_init: 0.0064 +grid_lr_final: 0.0064 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +deformation: + net_width: 64 + timebase_pe: 4 + defor_depth: 1 + posebase_pe: 10 + scale_rotation_pe: 2 + opacity_pe: 2 + timenet_width: 64 + timenet_output: 32 + bounds: 1.6 + plane_tv_weight: 0.0001 + time_smoothness_weight: 0.01 + l1_time_planes: 0.0001 + kplanes_config: + grid_dimensions: 2 + input_coordinate_dim: 4 + output_coordinate_dim: 32 + resolution: [32, 32, 32, 12] + multires: [1] + no_grid: False + no_mlp: False + no_ds: False + no_dr: False + no_do: True + use_res: True + +data_mode: svd +downsample_rate: 2 +# data_mode: c4d +# downsample_rate: 2 \ No newline at end of file diff --git a/configs/4d_low.yaml b/configs/4d_low.yaml new file mode 100644 index 0000000000000000000000000000000000000000..944892d24aa7471d090b80fc87d84cf0ee013b85 --- /dev/null +++ b/configs/4d_low.yaml @@ -0,0 +1,121 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 0.5 + +### Output +outdir: logs +mesh_format: frames +save_path: '' +save_model: False + +### Training +# guidance loss weights (0 to disable) +mvdream: False +imagedream: False +lambda_sd: 0 +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: True +lambda_svd: 0 +# training batch size per iter +batch_size: 14 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. +n_views: 1 +t_max: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +optimize_gaussians: True +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 + +num_pts: 5000 +sh_degree: 0 +percent_dense: 0.1 +density_start_iter: 3000 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +# deformation field +deformation_lr_init: 0.00064 +deformation_lr_final: 0.00064 +deformation_lr_delay_mult: 0.01 +grid_lr_init: 0.0064 +grid_lr_final: 0.0064 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +deformation: + net_width: 64 + timebase_pe: 4 + defor_depth: 1 + posebase_pe: 10 + scale_rotation_pe: 2 + opacity_pe: 2 + timenet_width: 64 + timenet_output: 32 + bounds: 1.6 + plane_tv_weight: 0.0001 + time_smoothness_weight: 0.01 + l1_time_planes: 0.0001 + kplanes_config: + grid_dimensions: 2 + input_coordinate_dim: 4 + output_coordinate_dim: 32 + resolution: [32, 32, 32, 22] + multires: [1] + no_grid: False + no_mlp: False + no_ds: False + no_dr: False + no_do: True + use_res: True + +data_mode: svd +downsample_rate: 1 +# data_mode: c4d +# downsample_rate: 2 \ No newline at end of file diff --git a/configs/dg.yaml b/configs/dg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..edaf1a1a5b7db6070436ebf9c6a4c01b2dc4784b --- /dev/null +++ b/configs/dg.yaml @@ -0,0 +1,85 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +negative_prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 1 + +### Output +outdir: logs +mesh_format: obj +save_path: '' + +### Training +# use mvdream instead of sd 2.1 +mvdream: False +# use imagedream +imagedream: False +# use stable-zero123 instead of zero123-xl +stable_zero123: False +# guidance loss weights (0 to disable) +lambda_sd: 0 +lambda_zero123: 1 +# warmup rgb supervision for image-to-3d +warmup_rgb_loss: True +# training batch size per iter +batch_size: 1 +# training iterations for stage 1 +iters: 500 +# whether to linearly anneal timestep +anneal_timestep: True +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 2 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# training camera min elevation +min_ver: -30 +# training camera max elevation +max_ver: 30 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +num_pts: 5000 +sh_degree: 0 +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 +percent_dense: 0.1 +density_start_iter: 100 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.5 + + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 \ No newline at end of file diff --git a/configs/dghd.yaml b/configs/dghd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4551c56164646736ac3b8101c0226e0f597f351 --- /dev/null +++ b/configs/dghd.yaml @@ -0,0 +1,72 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 1 + +### Output +outdir: logs +mesh_format: obj +save_path: '' + +### Training +# guidance loss weights (0 to disable) +lambda_sd: 0 +mvdream: False +lambda_zero123: 1 +# use stable-zero123 instead of zero123-xl +stable_zero123: False +# training batch size per iter +batch_size: 16 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 2 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0. + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +num_pts: 5000 +sh_degree: 0 +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 +percent_dense: 0.1 +density_start_iter: 0 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.05 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 \ No newline at end of file diff --git a/configs/refine.yaml b/configs/refine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce19ece9de0324a26b9f3353492e7a460af86a67 --- /dev/null +++ b/configs/refine.yaml @@ -0,0 +1,79 @@ +### Input +# input rgba image path (default to None, can be load in GUI too) +input: +# input text prompt (default to None, can be input in GUI too) +prompt: +# input mesh for stage 2 (auto-search from stage 1 output path if None) +mesh: +# estimated elevation angle for input image +elevation: 0 +# reference image resolution +ref_size: 256 +# density thresh for mesh extraction +density_thresh: 1 + +### Output +outdir: logs +mesh_format: obj +save_path: '' + +### Training +# guidance loss weights (0 to disable) +lambda_sd: 0 +mvdream: False +lambda_zero123: 0 +# use stable-zero123 instead of zero123-xl +stable_zero123: False +lambda_svd: 1 +# training batch size per iter +batch_size: 1 +# training iterations for stage 1 +iters: 500 +# training iterations for stage 2 +iters_refine: 50 +# training camera radius +radius: 1.5 +# training camera fovy +fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 +# checkpoint to load for stage 1 (should be a ply file) +load: +# whether allow geom training in stage 2 +train_geo: False +# prob to invert background color during training (0 = always black, 1 = always white) +invert_bg_prob: 0.5 + + +### GUI +gui: False +force_cuda_rast: False +# GUI resolution +H: 800 +W: 800 + +### Gaussian splatting +num_pts: 5000 +sh_degree: 0 +position_lr_init: 0.001 +position_lr_final: 0.00002 +position_lr_delay_mult: 0.02 +position_lr_max_steps: 500 +feature_lr: 0.01 +opacity_lr: 0.05 +scaling_lr: 0.005 +rotation_lr: 0.005 +percent_dense: 0.1 +density_start_iter: 100 +density_end_iter: 3000 +densification_interval: 100 +opacity_reset_interval: 700 +densify_grad_threshold: 0.5 + +### Textured Mesh +geom_lr: 0.0001 +texture_lr: 0.2 + +static_model: lgm +data_mode: svd +downsample_rate: 2 + +oom_hack: False \ No newline at end of file diff --git a/data/anya_rgba.png b/data/anya_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..089499e16e410207c890b45bc865627352df967d Binary files /dev/null and b/data/anya_rgba.png differ diff --git a/data/catstatue_rgba.png b/data/catstatue_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..8d64139587a7bc1e951e1750e862c8b530d5689d Binary files /dev/null and b/data/catstatue_rgba.png differ diff --git a/data/csm_luigi_rgba.png b/data/csm_luigi_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..162f6bbaa00fe87393aaf0040695b714fdf09493 Binary files /dev/null and b/data/csm_luigi_rgba.png differ diff --git a/data/zelda_rgba.png b/data/zelda_rgba.png new file mode 100644 index 0000000000000000000000000000000000000000..4ee1642f52ae6dd8ec19e83bbd9f778ecacc5bea Binary files /dev/null and b/data/zelda_rgba.png differ diff --git a/gaussian_model_4d.py b/gaussian_model_4d.py new file mode 100644 index 0000000000000000000000000000000000000000..13f952ada5708977dd8aa2853b8117850ec6fedb --- /dev/null +++ b/gaussian_model_4d.py @@ -0,0 +1,773 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from random import randint +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation +from scene.deformation import deform_network +from scene.regulation import compute_plane_smoothness + + +def gaussian_3d_coeff(xyzs, covs): + # xyzs: [N, 3] + # covs: [N, 6] + x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2] + a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5] + + # eps must be small enough !!! + inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24) + inv_a = (d * f - e**2) * inv_det + inv_b = (e * c - b * f) * inv_det + inv_c = (e * b - c * d) * inv_det + inv_d = (a * f - c**2) * inv_det + inv_e = (b * c - e * a) * inv_det + inv_f = (a * d - b**2) * inv_det + + power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e + + power[power > 0] = -1e10 # abnormal values... make weights 0 + + return torch.exp(power) + +class GaussianModel: + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + + def __init__(self, sh_degree : int, args): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + # self._deformation = torch.empty(0) + self._deformation = deform_network(args) + # self.grid = TriPlaneGrid() + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self._deformation_table = torch.empty(0) + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._deformation.state_dict(), + self._deformation_table, + # self.grid, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) + + def restore(self, model_args, training_args): + (self.active_sh_degree, + self._xyz, + self._deformation_table, + self._deformation, + # self.grid, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale) = model_args + self.training_setup(training_args) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + self.optimizer.load_state_dict(opt_dict) + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity) + + + def get_deformed_everything(self, time): + means3D = self.get_xyz + time = torch.tensor(time).to(means3D.device).repeat(means3D.shape[0],1) + time = ((time.float() / self.T) - 0.5) * 2 + + opacity = self._opacity + scales = self._scaling + rotations = self._rotation + + deformation_point = self._deformation_table + means3D_deform, scales_deform, rotations_deform, opacity_deform = self._deformation(means3D[deformation_point], scales[deformation_point], + rotations[deformation_point], opacity[deformation_point], + time[deformation_point]) + + means3D_final = means3D + means3D_deform + rotations_final = rotations + rotations_deform + scales_final = scales + scales_deform + opacity_final = opacity + + return means3D_final, rotations_final, scales_final, opacity_final + + + + @torch.no_grad() + def extract_fields_t(self, resolution=128, num_blocks=16, relax_ratio=1.5, t=0): + # resolution: resolution of field + + block_size = 2 / num_blocks + + assert resolution % block_size == 0 + split_size = resolution // num_blocks + + xyzs, rotation, scale, opacities = self.get_deformed_everything(t) + + scale = self.scaling_activation(scale) + opacities = self.opacity_activation(opacities) + + # pre-filter low opacity gaussians to save computation + mask = (opacities > 0.005).squeeze(1) + + opacities = opacities[mask] + xyzs = xyzs[mask] + stds = scale[mask] + + # normalize to ~ [-1, 1] + mn, mx = xyzs.amin(0), xyzs.amax(0) + self.center = (mn + mx) / 2 + self.scale = 1.8 / (mx - mn).amax().item() + + xyzs = (xyzs - self.center) * self.scale + stds = stds * self.scale + + covs = self.covariance_activation(stds, 1, rotation[mask]) + + # tile + device = opacities.device + occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device) + + X = torch.linspace(-1, 1, resolution).split(split_size) + Y = torch.linspace(-1, 1, resolution).split(split_size) + Z = torch.linspace(-1, 1, resolution).split(split_size) + + + # loop blocks (assume max size of gaussian is small than relax_ratio * block_size !!!) + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs) + # sample points [M, 3] + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device) + # in-tile gaussians mask + vmin, vmax = pts.amin(0), pts.amax(0) + vmin -= block_size * relax_ratio + vmax += block_size * relax_ratio + mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1) + # if hit no gaussian, continue to next block + if not mask.any(): + continue + mask_xyzs = xyzs[mask] # [L, 3] + mask_covs = covs[mask] # [L, 6] + mask_opas = opacities[mask].view(1, -1) # [L, 1] --> [1, L] + + # query per point-gaussian pair. + g_pts = pts.unsqueeze(1).repeat(1, mask_covs.shape[0], 1) - mask_xyzs.unsqueeze(0) # [M, L, 3] + g_covs = mask_covs.unsqueeze(0).repeat(pts.shape[0], 1, 1) # [M, L, 6] + + # batch on gaussian to avoid OOM + batch_g = 1024 + val = 0 + for start in range(0, g_covs.shape[1], batch_g): + end = min(start + batch_g, g_covs.shape[1]) + w = gaussian_3d_coeff(g_pts[:, start:end].reshape(-1, 3), g_covs[:, start:end].reshape(-1, 6)).reshape(pts.shape[0], -1) # [M, l] + val += (mask_opas[:, start:end] * w).sum(-1) + + # kiui.lo(val, mask_opas, w) + + occ[xi * split_size: xi * split_size + len(xs), + yi * split_size: yi * split_size + len(ys), + zi * split_size: zi * split_size + len(zs)] = val.reshape(len(xs), len(ys), len(zs)) + return occ + + def extract_mesh_t(self, path, density_thresh=1, t=0, resolution=128, decimate_target=1e5): + from mesh import Mesh + from mesh_utils import decimate_mesh, clean_mesh + + os.makedirs(os.path.dirname(path), exist_ok=True) + + occ = self.extract_fields_t(resolution, t=t).detach().cpu().numpy() + + import mcubes + vertices, triangles = mcubes.marching_cubes(occ, density_thresh) + vertices = vertices / (resolution - 1.0) * 2 - 1 + + # transform back to the original space + vertices = vertices / self.scale + self.center.detach().cpu().numpy() + + vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.015) + if decimate_target > 0 and triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh(vertices, triangles, decimate_target) + + v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda() + f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda() + + print( + f"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}" + ) + + mesh = Mesh(v=v, f=f, device='cuda') + + return mesh + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float, time_line: int): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._deformation = self._deformation.to("cuda") + # self.grid = self.grid.to("cuda") + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") + self.T = training_args.batch_size + + if training_args.optimize_gaussians: + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"}, + {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + ] + else: + l = [ + {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"}, + {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"}, + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale, + lr_final=training_args.deformation_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.deformation_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + self.grid_scheduler_args = get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale, + lr_final=training_args.grid_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.deformation_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + if "grid" in param_group["name"]: + lr = self.grid_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + elif param_group["name"] == "deformation": + lr = self.deformation_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + def compute_deformation(self,time): + + deform = self._deformation[:,:,:time].sum(dim=-1) + xyz = self._xyz + deform + return xyz + + def load_model(self, path, name): + print("loading model from exists{}".format(path)) + weight_dict = torch.load(os.path.join(path, name+"_deformation.pth"),map_location="cuda") + self._deformation.load_state_dict(weight_dict) + self._deformation = self._deformation.to("cuda") + self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) + self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") + if os.path.exists(os.path.join(path, name+"_deformation_table.pth")): + self._deformation_table = torch.load(os.path.join(path, name+"_deformation_table.pth"),map_location="cuda") + if os.path.exists(os.path.join(path,name+"_deformation_accum.pth")): + self._deformation_accum = torch.load(os.path.join(path, name+"_deformation_accum.pth"),map_location="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def save_deformation(self, path, name): + torch.save(self._deformation.state_dict(),os.path.join(path, name+"_deformation.pth")) + torch.save(self._deformation_table,os.path.join(path, name+"_deformation_table.pth")) + torch.save(self._deformation_accum,os.path.join(path, name+"_deformation_accum.pth")) + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def save_frame_ply(self, path, t): + mkdir_p(os.path.dirname(path)) + + xyzs, rotation, scale, opacities = self.get_deformed_everything(t) + + xyz = xyzs.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = opacities.detach().cpu().numpy() + scale = scale.detach().cpu().numpy() + rotation = rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + # def save_frame_ply(self, path, t): + # mkdir_p(os.path.dirname(path)) + + # xyz = self._xyz.detach().cpu().numpy() + # normals = np.zeros_like(xyz) + # f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + # f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + # opacities = self._opacity.detach().cpu().numpy() + # scale = self._scaling.detach().cpu().numpy() + # rotation = self._rotation.detach().cpu().numpy() + + # dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + # elements = np.empty(xyz.shape[0], dtype=dtype_full) + # attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + # elements[:] = list(map(tuple, attributes)) + # el = PlyElement.describe(elements, 'vertex') + # PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + self.spatial_lr_scale = 1 + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self.active_sh_degree = self.max_sh_degree + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self._deformation = self._deformation.to("cuda") + self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) # everything deformed + + print(self._xyz.shape) + + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if len(group["params"]) > 1: + continue + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + self._deformation_accum = self._deformation_accum[valid_points_mask] + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + self._deformation_table = self._deformation_table[valid_points_mask] + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if len(group["params"])>1:continue + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation, + # "deformation": new_deformation + } + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + # self._deformation = optimizable_tensors["deformation"] + + self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1) + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + if not selected_pts_mask.any(): + return + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N) + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + # - 0.001 * self._xyz.grad[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + new_deformation_table = self._deformation_table[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table) + + def prune(self, min_opacity, extent, max_screen_size): + prune_mask = (self.get_opacity < min_opacity).squeeze() + # prune_mask_2 = torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device="cuda"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device="cuda")) + # prune_mask = torch.logical_or(prune_mask, prune_mask_2) + # deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1) + # deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to("cuda"))) + # prune_mask = prune_mask & deformation_mask + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(prune_mask, big_points_vs) + + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + def densify(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + def standard_constaint(self): + + means3D = self._xyz.detach() + scales = self._scaling.detach() + rotations = self._rotation.detach() + opacity = self._opacity.detach() + time = torch.tensor(0).to("cuda").repeat(means3D.shape[0],1) + means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time) + position_error = (means3D_deform - means3D)**2 + rotation_error = (rotations_deform - rotations)**2 + scaling_erorr = (scales_deform - scales)**2 + return position_error.mean() + rotation_error.mean() + scaling_erorr.mean() + + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 + @torch.no_grad() + def update_deformation_table(self,threshold): + # print("origin deformation point nums:",self._deformation_table.sum()) + self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold) + def print_deformation_weight_grad(self): + for name, weight in self._deformation.named_parameters(): + if weight.requires_grad: + if weight.grad is None: + + print(name," :",weight.grad) + else: + if weight.grad.mean() != 0: + print(name," :",weight.grad.mean(), weight.grad.min(), weight.grad.max()) + print("-"*50) + def _plane_regulation(self): + multi_res_grids = self._deformation.deformation_net.grid.grids + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids = [0,1,3] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return total + def _time_regulation(self): + multi_res_grids = self._deformation.deformation_net.grid.grids + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids =[2, 4, 5] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return total + def _l1_regulation(self): + # model.grids is 6 x [1, rank * F_dim, reso, reso] + multi_res_grids = self._deformation.deformation_net.grid.grids + + total = 0.0 + for grids in multi_res_grids: + if len(grids) == 3: + continue + else: + # These are the spatiotemporal grids + spatiotemporal_grids = [2, 4, 5] + for grid_id in spatiotemporal_grids: + total += torch.abs(1 - grids[grid_id]).mean() + return total + def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight): + return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation() + + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() diff --git a/grid_put.py b/grid_put.py new file mode 100644 index 0000000000000000000000000000000000000000..0086cc4efa7527b77b9e583642ca9dfa9ae467fe --- /dev/null +++ b/grid_put.py @@ -0,0 +1,300 @@ +import torch +import torch.nn.functional as F + +def stride_from_shape(shape): + stride = [1] + for x in reversed(shape[1:]): + stride.append(stride[-1] * x) + return list(reversed(stride)) + + +def scatter_add_nd(input, indices, values): + # input: [..., C], D dimension + C channel + # indices: [N, D], long + # values: [N, C] + + D = indices.shape[-1] + C = input.shape[-1] + size = input.shape[:-1] + stride = stride_from_shape(size) + + assert len(size) == D + + input = input.view(-1, C) # [HW, C] + flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N] + + input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) + + return input.view(*size, C) + + +def scatter_add_nd_with_count(input, count, indices, values, weights=None): + # input: [..., C], D dimension + C channel + # count: [..., 1], D dimension + # indices: [N, D], long + # values: [N, C] + + D = indices.shape[-1] + C = input.shape[-1] + size = input.shape[:-1] + stride = stride_from_shape(size) + + assert len(size) == D + + input = input.view(-1, C) # [HW, C] + count = count.view(-1, 1) + + flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N] + + if weights is None: + weights = torch.ones_like(values[..., :1]) + + input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values) + count.scatter_add_(0, flatten_indices.unsqueeze(1), weights) + + return input.view(*size, C), count.view(*size, 1) + +def nearest_grid_put_2d(H, W, coords, values, return_count=False): + # coords: [N, 2], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + indices = (coords * 0.5 + 0.5) * torch.tensor( + [H - 1, W - 1], dtype=torch.float32, device=coords.device + ) + indices = indices.round().long() # [N, 2] + + result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] + count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] + weights = torch.ones_like(values[..., :1]) # [N, 1] + + result, count = scatter_add_nd_with_count(result, count, indices, values, weights) + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + + +def linear_grid_put_2d(H, W, coords, values, return_count=False): + # coords: [N, 2], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + indices = (coords * 0.5 + 0.5) * torch.tensor( + [H - 1, W - 1], dtype=torch.float32, device=coords.device + ) + indices_00 = indices.floor().long() # [N, 2] + indices_00[:, 0].clamp_(0, H - 2) + indices_00[:, 1].clamp_(0, W - 2) + indices_01 = indices_00 + torch.tensor( + [0, 1], dtype=torch.long, device=indices.device + ) + indices_10 = indices_00 + torch.tensor( + [1, 0], dtype=torch.long, device=indices.device + ) + indices_11 = indices_00 + torch.tensor( + [1, 1], dtype=torch.long, device=indices.device + ) + + h = indices[..., 0] - indices_00[..., 0].float() + w = indices[..., 1] - indices_00[..., 1].float() + w_00 = (1 - h) * (1 - w) + w_01 = (1 - h) * w + w_10 = h * (1 - w) + w_11 = h * w + + result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] + count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] + weights = torch.ones_like(values[..., :1]) # [N, 1] + + result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1)) + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + +def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False): + # coords: [N, 2], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C] + count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1] + + cur_H, cur_W = H, W + + while min(cur_H, cur_W) > min_resolution: + + # try to fill the holes + mask = (count.squeeze(-1) == 0) + if not mask.any(): + break + + cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True) + result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask] + count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask] + cur_H //= 2 + cur_W //= 2 + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + +def nearest_grid_put_3d(H, W, D, coords, values, return_count=False): + # coords: [N, 3], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + indices = (coords * 0.5 + 0.5) * torch.tensor( + [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device + ) + indices = indices.round().long() # [N, 2] + + result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, C] + count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, 1] + weights = torch.ones_like(values[..., :1]) # [N, 1] + + result, count = scatter_add_nd_with_count(result, count, indices, values, weights) + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + + +def linear_grid_put_3d(H, W, D, coords, values, return_count=False): + # coords: [N, 3], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + indices = (coords * 0.5 + 0.5) * torch.tensor( + [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device + ) + indices_000 = indices.floor().long() # [N, 3] + indices_000[:, 0].clamp_(0, H - 2) + indices_000[:, 1].clamp_(0, W - 2) + indices_000[:, 2].clamp_(0, D - 2) + + indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device) + indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device) + indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device) + indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device) + indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device) + indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device) + indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device) + + h = indices[..., 0] - indices_000[..., 0].float() + w = indices[..., 1] - indices_000[..., 1].float() + d = indices[..., 2] - indices_000[..., 2].float() + + w_000 = (1 - h) * (1 - w) * (1 - d) + w_001 = (1 - h) * w * (1 - d) + w_010 = h * (1 - w) * (1 - d) + w_011 = h * w * (1 - d) + w_100 = (1 - h) * (1 - w) * d + w_101 = (1 - h) * w * d + w_110 = h * (1 - w) * d + w_111 = h * w * d + + result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C] + count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1] + weights = torch.ones_like(values[..., :1]) # [N, 1] + + result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1)) + result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1)) + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + +def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False): + # coords: [N, 3], float in [-1, 1] + # values: [N, C] + + C = values.shape[-1] + + result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C] + count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1] + cur_H, cur_W, cur_D = H, W, D + + while min(min(cur_H, cur_W), cur_D) > min_resolution: + + # try to fill the holes + mask = (count.squeeze(-1) == 0) + if not mask.any(): + break + + cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True) + result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask] + count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask] + cur_H //= 2 + cur_W //= 2 + cur_D //= 2 + + if return_count: + return result, count + + mask = (count.squeeze(-1) > 0) + result[mask] = result[mask] / count[mask].repeat(1, C) + + return result + + +def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False): + # shape: [D], list/tuple + # coords: [N, D], float in [-1, 1] + # values: [N, C] + + D = len(shape) + assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}' + + if mode == 'nearest': + if D == 2: + return nearest_grid_put_2d(*shape, coords, values, return_raw) + else: + return nearest_grid_put_3d(*shape, coords, values, return_raw) + elif mode == 'linear': + if D == 2: + return linear_grid_put_2d(*shape, coords, values, return_raw) + else: + return linear_grid_put_3d(*shape, coords, values, return_raw) + elif mode == 'linear-mipmap': + if D == 2: + return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw) + else: + return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw) + else: + raise NotImplementedError(f"got mode {mode}") \ No newline at end of file diff --git a/gs_renderer_4d.py b/gs_renderer_4d.py new file mode 100644 index 0000000000000000000000000000000000000000..962b3e58c4a9223a9ae979c3d31821a892703c26 --- /dev/null +++ b/gs_renderer_4d.py @@ -0,0 +1,277 @@ +import math +import numpy as np + +import torch + +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) + +from sh_utils import eval_sh, SH2RGB, RGB2SH + +from gaussian_model_4d import GaussianModel, BasicPointCloud + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 1 / tanHalfFovX + P[1, 1] = 1 / tanHalfFovY + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +class MiniCam: + def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, time=0, gs_convention=True): + # c2w (pose) should be in NeRF convention. + + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + + w2c = np.linalg.inv(c2w) + + if gs_convention: + # rectify... + w2c[1:3, :3] *= -1 + w2c[:3, 3] *= -1 + + self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda() + self.projection_matrix = ( + getProjectionMatrix( + znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy + ) + .transpose(0, 1) + .cuda() + ) + self.full_proj_transform = self.world_view_transform @ self.projection_matrix + self.camera_center = -torch.tensor(c2w[:3, 3]).cuda() + + self.time = time + + +class Renderer: + def __init__(self, opt, sh_degree=3, white_background=True, radius=1): + + self.sh_degree = sh_degree + self.white_background = white_background + self.radius = radius + self.opt = opt + self.T = self.opt.batch_size + + self.gaussians = GaussianModel(sh_degree, opt.deformation) + + self.bg_color = torch.tensor( + [1, 1, 1] if white_background else [0, 0, 0], + dtype=torch.float32, + device="cuda", + ) + self.means3D_deform_T = None + self.opacity_deform_T = None + self.scales_deform_T = None + self.rotations_deform_T = None + + + + def initialize(self, input=None, num_pts=5000, radius=0.5): + # load checkpoint + if input is None: + # init from random point cloud + + phis = np.random.random((num_pts,)) * 2 * np.pi + costheta = np.random.random((num_pts,)) * 2 - 1 + thetas = np.arccos(costheta) + mu = np.random.random((num_pts,)) + radius = radius * np.cbrt(mu) + x = radius * np.sin(thetas) * np.cos(phis) + y = radius * np.sin(thetas) * np.sin(phis) + z = radius * np.cos(thetas) + xyz = np.stack((x, y, z), axis=1) + # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud( + points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) + ) + # self.gaussians.create_from_pcd(pcd, 10) + self.gaussians.create_from_pcd(pcd, 10, 1) + elif isinstance(input, BasicPointCloud): + # load from a provided pcd + self.gaussians.create_from_pcd(input, 1) + else: + # load from saved ply + self.gaussians.load_ply(input) + + def prepare_render( + self, + ): + means3D = self.gaussians.get_xyz + opacity = self.gaussians._opacity + scales = self.gaussians._scaling + rotations = self.gaussians._rotation + + means3D_T = [] + opacity_T = [] + scales_T = [] + rotations_T = [] + time_T = [] + + for t in range(self.T): + time = torch.tensor(t).to(means3D.device).repeat(means3D.shape[0],1) + time = ((time.float() / self.T) - 0.5) * 2 + + means3D_T.append(means3D) + opacity_T.append(opacity) + scales_T.append(scales) + rotations_T.append(rotations) + time_T.append(time) + + means3D_T = torch.cat(means3D_T) + opacity_T = torch.cat(opacity_T) + scales_T = torch.cat(scales_T) + rotations_T = torch.cat(rotations_T) + time_T = torch.cat(time_T) + + + means3D_deform_T, scales_deform_T, rotations_deform_T, opacity_deform_T = self.gaussians._deformation(means3D_T, scales_T, + rotations_T, opacity_T, + time_T) # time is not none + self.means3D_deform_T = means3D_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1]) + self.opacity_deform_T = opacity_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1]) + self.scales_deform_T = scales_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1]) + self.rotations_deform_T = rotations_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1]) + + + def render( + self, + viewpoint_camera, + scaling_modifier=1.0, + bg_color=None, + override_color=None, + compute_cov3D_python=False, + convert_SHs_python=False, + ): + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = ( + torch.zeros_like( + self.gaussians.get_xyz, + dtype=self.gaussians.get_xyz.dtype, + requires_grad=True, + device="cuda", + ) + + 0 + ) + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=self.bg_color if bg_color is None else bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=self.gaussians.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=False, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = self.gaussians.get_xyz + time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1) + time = ((time.float() / self.T) - 0.5) * 2 + + means2D = screenspace_points + opacity = self.gaussians._opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if compute_cov3D_python: + cov3D_precomp = self.gaussians.get_covariance(scaling_modifier) + else: + scales = self.gaussians._scaling + rotations = self.gaussians._rotation + + means3D_deform, scales_deform, rotations_deform, opacity_deform = self.means3D_deform_T[viewpoint_camera.time], self.scales_deform_T[viewpoint_camera.time], self.rotations_deform_T[viewpoint_camera.time], self.opacity_deform_T[viewpoint_camera.time] + + + means3D_final = means3D + means3D_deform + rotations_final = rotations + rotations_deform + scales_final = scales + scales_deform + opacity_final = opacity + opacity_deform + + + + scales_final = self.gaussians.scaling_activation(scales_final) + rotations_final = self.gaussians.rotation_activation(rotations_final) + opacity = self.gaussians.opacity_activation(opacity) + + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if colors_precomp is None: + if convert_SHs_python: + shs_view = self.gaussians.get_features.transpose(1, 2).view( + -1, 3, (self.gaussians.max_sh_degree + 1) ** 2 + ) + dir_pp = self.gaussians.get_xyz - viewpoint_camera.camera_center.repeat( + self.gaussians.get_features.shape[0], 1 + ) + dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh( + self.gaussians.active_sh_degree, shs_view, dir_pp_normalized + ) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = self.gaussians.get_features + else: + colors_precomp = override_color + + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D = means3D_final, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales_final, + rotations = rotations_final, + cov3D_precomp = cov3D_precomp) + + + rendered_image = rendered_image.clamp(0, 1) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return { + "image": rendered_image, + "depth": rendered_depth, + "alpha": rendered_alpha, + "viewspace_points": screenspace_points, + "visibility_filter": radii > 0, + "radii": radii, + } diff --git a/guidance/imagedream_utils.py b/guidance/imagedream_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54dcb91b8bd4ea218372282157350873fd4e969e --- /dev/null +++ b/guidance/imagedream_utils.py @@ -0,0 +1,326 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF + +from imagedream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera +from imagedream.model_zoo import build_model +from imagedream.ldm.models.diffusion.ddim import DDIMSampler + +from diffusers import DDIMScheduler + +class ImageDream(nn.Module): + def __init__( + self, + device, + model_name='sd-v2.1-base-4view-ipmv', + ckpt_path=None, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.device = device + self.model_name = model_name + self.ckpt_path = ckpt_path + + self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device) + self.model.device = device + for p in self.model.parameters(): + p.requires_grad_(False) + + self.dtype = torch.float32 + + self.num_train_timesteps = 1000 + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + + self.image_embeddings = {} + self.embeddings = {} + + self.scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype + ) + + @torch.no_grad() + def get_image_text_embeds(self, image, prompts, negative_prompts): + + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image_pil = TF.to_pil_image(image[0]) + image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1) # [5, 257, 1280] + self.image_embeddings['pos'] = image_embeddings + self.image_embeddings['neg'] = torch.zeros_like(image_embeddings) + + self.image_embeddings['ip_img'] = self.encode_imgs(image) + self.image_embeddings['neg_ip_img'] = torch.zeros_like(self.image_embeddings['ip_img']) + + pos_embeds = self.encode_text(prompts).repeat(5,1,1) + neg_embeds = self.encode_text(negative_prompts).repeat(5,1,1) + self.embeddings['pos'] = pos_embeds + self.embeddings['neg'] = neg_embeds + + return self.image_embeddings['pos'], self.image_embeddings['neg'], self.image_embeddings['ip_img'], self.image_embeddings['neg_ip_img'], self.embeddings['pos'], self.embeddings['neg'] + + def encode_text(self, prompt): + # prompt: [str] + embeddings = self.model.get_learned_conditioning(prompt).to(self.device) + return embeddings + + @torch.no_grad() + def refine(self, pred_rgb, camera, + guidance_scale=5, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + real_batch_size = batch_size // 4 + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + + self.scheduler.set_timesteps(steps) + init_step = int(steps * strength) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + + # extra view + camera = camera.view(real_batch_size, 4, 16) + camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] + camera = camera.view(real_batch_size * 5, 16) + + camera = camera.repeat(2, 1) + embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + image_embeddings = torch.cat([self.image_embeddings['neg'].repeat(real_batch_size, 1, 1), self.image_embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + ip_img_embeddings= torch.cat([self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) + + context = { + "context": embeddings, + "ip": image_embeddings, + "ip_img": ip_img_embeddings, + "camera": camera, + "num_frames": 4 + 1 + } + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + # extra view + + latents = latents.view(real_batch_size, 4, 4, 32, 32) + latents = torch.cat([latents, torch.zeros_like(latents[:, :1])], dim=1).view(-1, 4, 32, 32) + latent_model_input = torch.cat([latents] * 2) + + tt = torch.cat([t.unsqueeze(0).repeat(real_batch_size * 5)] * 2).to(self.device) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + + # remove extra view + noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + latents = latents.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + return imgs + + def train_step( + self, + pred_rgb, # [B, C, H, W] + camera, # [B, 4, 4] + step_ratio=None, + guidance_scale=5, + as_latent=False, + ): + + batch_size = pred_rgb.shape[0] + real_batch_size = batch_size // 4 + pred_rgb = pred_rgb.to(self.dtype) + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1 + else: + # interp to 256x256 to be fed into vae. + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_256) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (real_batch_size,), dtype=torch.long, device=self.device).repeat(4) + + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + + # extra view + camera = camera.view(real_batch_size, 4, 16) + camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] + camera = camera.view(real_batch_size * 5, 16) + + camera = camera.repeat(2, 1) + embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + image_embeddings = torch.cat([self.image_embeddings['neg'].repeat(real_batch_size, 1, 1), self.image_embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + ip_img_embeddings= torch.cat([self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) + + context = { + "context": embeddings, + "ip": image_embeddings, + "ip_img": ip_img_embeddings, + "camera": camera, + "num_frames": 4 + 1 + } + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.model.q_sample(latents, t, noise) # [B=4, 4, 32, 32] + # extra view + t = t.view(real_batch_size, 4) + t = torch.cat([t, t[:, :1]], dim=1).view(-1) + latents_noisy = latents_noisy.view(real_batch_size, 4, 4, 32, 32) + latents_noisy = torch.cat([latents_noisy, torch.zeros_like(latents_noisy[:, :1])], dim=1).view(-1, 4, 32, 32) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + + # import kiui + # kiui.lo(latent_model_input, t, context['context'], context['camera']) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + + # remove extra view + noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + grad = (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + + return loss + + def decode_latents(self, latents): + imgs = self.model.decode_first_stage(latents) + imgs = ((imgs + 1) / 2).clamp(0, 1) + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, 256, 256] + imgs = 2 * imgs - 1 + latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs)) + return latents # [B, 4, 32, 32] + + @torch.no_grad() + def prompt_to_img( + self, + image, + prompts, + negative_prompts="", + height=256, + width=256, + num_inference_steps=50, + guidance_scale=5.0, + latents=None, + elevation=0, + azimuth_start=0, + ): + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + real_batch_size = len(prompts) + batch_size = len(prompts) * 5 + + # Text embeds -> img latents + sampler = DDIMSampler(self.model) + shape = [4, height // 8, width // 8] + + c_ = {"context": self.encode_text(prompts).repeat(5,1,1)} + uc_ = {"context": self.encode_text(negative_prompts).repeat(5,1,1)} + + # image embeddings + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image_pil = TF.to_pil_image(image[0]) + image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1).to(self.device) + c_["ip"] = image_embeddings + uc_["ip"] = torch.zeros_like(image_embeddings) + + ip_img = self.encode_imgs(image) + c_["ip_img"] = ip_img + uc_["ip_img"] = torch.zeros_like(ip_img) + + camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start, extra_view=True) + camera = camera.repeat(real_batch_size, 1).to(self.device) + + c_["camera"] = uc_["camera"] = camera + c_["num_frames"] = uc_["num_frames"] = 5 + + kiui.lo(image_embeddings, ip_img, camera) + + latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_, + batch_size=batch_size, shape=shape, + verbose=False, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc_, + eta=0, x_T=None) + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [4, 3, 256, 256] + + kiui.lo(latents, imgs) + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + +if __name__ == "__main__": + import argparse + import matplotlib.pyplot as plt + import kiui + + parser = argparse.ArgumentParser() + parser.add_argument("image", type=str) + parser.add_argument("prompt", type=str) + parser.add_argument("--negative", default="", type=str) + parser.add_argument("--steps", type=int, default=30) + opt = parser.parse_args() + + device = torch.device("cuda") + + sd = ImageDream(device) + + image = kiui.read_image(opt.image, mode='tensor') + image = image.permute(2, 0, 1).unsqueeze(0).to(device) + + while True: + imgs = sd.prompt_to_img(image, opt.prompt, opt.negative, num_inference_steps=opt.steps) + + grid = np.concatenate([ + np.concatenate([imgs[0], imgs[1]], axis=1), + np.concatenate([imgs[2], imgs[3]], axis=1), + ], axis=0) + + # visualize image + plt.imshow(grid) + plt.show() \ No newline at end of file diff --git a/guidance/mvdream_utils.py b/guidance/mvdream_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..833a16bd250dafcebf8c34078231187f220ce355 --- /dev/null +++ b/guidance/mvdream_utils.py @@ -0,0 +1,271 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mvdream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera +from mvdream.model_zoo import build_model +from mvdream.ldm.models.diffusion.ddim import DDIMSampler + +from diffusers import DDIMScheduler + +class MVDream(nn.Module): + def __init__( + self, + device, + model_name='sd-v2.1-base-4view', + ckpt_path=None, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.device = device + self.model_name = model_name + self.ckpt_path = ckpt_path + + self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device) + self.model.device = device + for p in self.model.parameters(): + p.requires_grad_(False) + + self.dtype = torch.float32 + + self.num_train_timesteps = 1000 + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + + self.embeddings = None + + self.scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype + ) + + @torch.no_grad() + def get_text_embeds(self, prompts, negative_prompts): + pos_embeds = self.encode_text(prompts).repeat(4,1,1) # [1, 77, 768] + neg_embeds = self.encode_text(negative_prompts).repeat(4,1,1) + self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + def encode_text(self, prompt): + # prompt: [str] + embeddings = self.model.get_learned_conditioning(prompt).to(self.device) + return embeddings + + @torch.no_grad() + def refine(self, pred_rgb, camera, + guidance_scale=100, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) + + self.scheduler.set_timesteps(steps) + init_step = int(steps * strength) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + camera = camera.repeat(2, 1) + context = {"context": self.embeddings, "camera": camera, "num_frames": 4} + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + latent_model_input = torch.cat([latents] * 2) + + tt = torch.cat([t.unsqueeze(0).repeat(batch_size)] * 2).to(self.device) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + return imgs + + def train_step( + self, + pred_rgb, # [B, C, H, W], B is multiples of 4 + camera, # [B, 4, 4] + step_ratio=None, + guidance_scale=50, + as_latent=False, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb = pred_rgb.to(self.dtype) + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1 + else: + # interp to 256x256 to be fed into vae. + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_256) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + # camera = convert_opengl_to_blender(camera) + # flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).unsqueeze(0) + # camera = torch.matmul(flip_yz.to(camera), camera) + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + + ############### + # sampler = DDIMSampler(self.model) + # shape = [4, 32, 32] + # c_ = {"context": self.embeddings[4:]} + # uc_ = {"context": self.embeddings[:4]} + + # # print(camera) + + # # camera = get_camera(4, elevation=0, azimuth_start=0) + # # camera = camera.repeat(batch_size // 4, 1).to(self.device) + + # # print(camera) + + # c_["camera"] = uc_["camera"] = camera + # c_["num_frames"] = uc_["num_frames"] = 4 + + # latents_, _ = sampler.sample(S=30, conditioning=c_, + # batch_size=batch_size, shape=shape, + # verbose=False, + # unconditional_guidance_scale=guidance_scale, + # unconditional_conditioning=uc_, + # eta=0, x_T=None) + + # # Img latents -> imgs + # imgs = self.decode_latents(latents_) # [4, 3, 256, 256] + # import kiui + # kiui.vis.plot_image(imgs) + ############### + + camera = camera.repeat(2, 1) + context = {"context": self.embeddings, "camera": camera, "num_frames": 4} + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.model.q_sample(latents, t, noise) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + + # import kiui + # kiui.lo(latent_model_input, t, context['context'], context['camera']) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + grad = (noise_pred - noise) + grad = torch.nan_to_num(grad) + + # seems important to avoid NaN... + # grad = grad.clamp(-1, 1) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + + return loss + + def decode_latents(self, latents): + imgs = self.model.decode_first_stage(latents) + imgs = ((imgs + 1) / 2).clamp(0, 1) + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, 256, 256] + imgs = 2 * imgs - 1 + latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs)) + return latents # [B, 4, 32, 32] + + @torch.no_grad() + def prompt_to_img( + self, + prompts, + negative_prompts="", + height=256, + width=256, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + elevation=0, + azimuth_start=0, + ): + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + batch_size = len(prompts) * 4 + + # Text embeds -> img latents + sampler = DDIMSampler(self.model) + shape = [4, height // 8, width // 8] + c_ = {"context": self.encode_text(prompts).repeat(4,1,1)} + uc_ = {"context": self.encode_text(negative_prompts).repeat(4,1,1)} + + camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start) + camera = camera.repeat(batch_size // 4, 1).to(self.device) + + c_["camera"] = uc_["camera"] = camera + c_["num_frames"] = uc_["num_frames"] = 4 + + latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_, + batch_size=batch_size, shape=shape, + verbose=False, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc_, + eta=0, x_T=None) + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [4, 3, 256, 256] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + +if __name__ == "__main__": + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument("prompt", type=str) + parser.add_argument("--negative", default="", type=str) + parser.add_argument("--steps", type=int, default=30) + opt = parser.parse_args() + + device = torch.device("cuda") + + sd = MVDream(device) + + while True: + imgs = sd.prompt_to_img(opt.prompt, opt.negative, num_inference_steps=opt.steps) + + grid = np.concatenate([ + np.concatenate([imgs[0], imgs[1]], axis=1), + np.concatenate([imgs[2], imgs[3]], axis=1), + ], axis=0) + + # visualize image + plt.imshow(grid) + plt.show() diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f930cd05e06006c2b6ed3e08aab3c6efd2ce7fe2 --- /dev/null +++ b/guidance/sd_utils.py @@ -0,0 +1,334 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + PNDMScheduler, + DDIMScheduler, + StableDiffusionPipeline, +) +from diffusers.utils.import_utils import is_xformers_available + +# suppress partial model loading warning +logging.set_verbosity_error() + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + + +class StableDiffusion(nn.Module): + def __init__( + self, + device, + fp16=True, + vram_O=False, + sd_version="2.1", + hf_key=None, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.device = device + self.sd_version = sd_version + + if hf_key is not None: + print(f"[INFO] using hugging face custom model key: {hf_key}") + model_key = hf_key + elif self.sd_version == "2.1": + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == "2.0": + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == "1.5": + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError( + f"Stable-diffusion version {self.sd_version} not supported." + ) + + self.dtype = torch.float16 if fp16 else torch.float32 + + # Create model + pipe = StableDiffusionPipeline.from_pretrained( + model_key, torch_dtype=self.dtype + ) + + if vram_O: + pipe.enable_sequential_cpu_offload() + pipe.enable_vae_slicing() + pipe.unet.to(memory_format=torch.channels_last) + pipe.enable_attention_slicing(1) + # pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + self.vae = pipe.vae + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + + self.scheduler = DDIMScheduler.from_pretrained( + model_key, subfolder="scheduler", torch_dtype=self.dtype + ) + + del pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + self.embeddings = None + + @torch.no_grad() + def get_text_embeds(self, prompts, negative_prompts): + pos_embeds = self.encode_text(prompts) # [1, 77, 768] + neg_embeds = self.encode_text(negative_prompts) + self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + def encode_text(self, prompt): + # prompt: [str] + inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + return embeddings + + @torch.no_grad() + def refine(self, pred_rgb, + guidance_scale=100, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) + # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) + + self.scheduler.set_timesteps(steps) + init_step = int(steps * strength) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + latent_model_input = torch.cat([latents] * 2) + + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=self.embeddings, + ).sample + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + return imgs + + def train_step( + self, + pred_rgb, + step_ratio=None, + guidance_scale=100, + as_latent=False, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb = pred_rgb.to(self.dtype) + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + + noise_pred = self.unet( + latent_model_input, tt, encoder_hidden_states=self.embeddings.repeat(batch_size, 1, 1) + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_pos - noise_pred_uncond + ) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + # seems important to avoid NaN... + # grad = grad.clamp(-1, 1) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + + return loss + + @torch.no_grad() + def produce_latents( + self, + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + ): + if latents is None: + latents = torch.randn( + ( + self.embeddings.shape[0] // 2, + self.unet.in_channels, + height // 8, + width // 8, + ), + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([latents] * 2) + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=self.embeddings + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + return latents + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents + + def prompt_to_img( + self, + prompts, + negative_prompts="", + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + ): + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + self.get_text_embeds(prompts, negative_prompts) + + # Text embeds -> img latents + latents = self.produce_latents( + height=height, + width=width, + latents=latents, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + +if __name__ == "__main__": + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument("prompt", type=str) + parser.add_argument("--negative", default="", type=str) + parser.add_argument( + "--sd_version", + type=str, + default="2.1", + choices=["1.5", "2.0", "2.1"], + help="stable diffusion version", + ) + parser.add_argument( + "--hf_key", + type=str, + default=None, + help="hugging face Stable diffusion model key", + ) + parser.add_argument("--fp16", action="store_true", help="use float16 for training") + parser.add_argument( + "--vram_O", action="store_true", help="optimization for low VRAM usage" + ) + parser.add_argument("-H", type=int, default=512) + parser.add_argument("-W", type=int, default=512) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--steps", type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device("cuda") + + sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() diff --git a/guidance/svd_utils.py b/guidance/svd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d15e9a6c66584af7b641282c446e9124e132d1fc --- /dev/null +++ b/guidance/svd_utils.py @@ -0,0 +1,221 @@ +import torch + +from svd import StableVideoDiffusionPipeline +from diffusers import DDIMScheduler + +from PIL import Image +import numpy as np + +import torch.nn as nn +import torch.nn.functional as F + + +class StableVideoDiffusion: + def __init__( + self, + device, + fp16=True, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.guidance_type = [ + 'sds', + 'pixel reconstruction', + 'latent reconstruction' + ][1] + + self.device = device + self.dtype = torch.float16 if fp16 else torch.float32 + + # Create model + pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16" + ) + pipe.to(device) + + self.pipe = pipe + + self.num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps if self.guidance_type == 'sds' else 25 + self.pipe.scheduler.set_timesteps(self.num_train_timesteps, device=device) # set sigma for euler discrete scheduling + + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device) # for convenience + + self.embeddings = None + self.image = None + self.target_cache = None + + @torch.no_grad() + def get_img_embeds(self, image): + self.image = Image.fromarray(np.uint8(image*255)) + + def encode_image(self, image): + image = image * 2 -1 + latents = self.pipe._encode_vae_image(image, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=False) + latents = self.pipe.vae.config.scaling_factor * latents + return latents + + def refine(self, + pred_rgb, + steps=25, strength=0.8, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + ): + # strength = 0.8 + batch_size = pred_rgb.shape[0] + pred_rgb = pred_rgb.to(self.dtype) + + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + + # latents = [] + # for i in range(batch_size): + # latent = self.encode_image(pred_rgb_512[i:i+1]) + # latents.append(latent) + # latents = torch.cat(latents, 0) + + latents = self.encode_image(pred_rgb_512) + latents = latents.unsqueeze(0) + + if strength == 0: + init_step = 0 + latents = torch.randn_like(latents) + else: + init_step = int(steps * strength) + latents = self.pipe.scheduler.add_noise(latents, torch.randn_like(latents), self.pipe.scheduler.timesteps[init_step:init_step+1]) + + target = self.pipe( + image=self.image, + height=512, + width=512, + latents=latents, + denoise_beg=init_step, + denoise_end=steps, + output_type='frame', + num_frames=batch_size, + min_guidance_scale=min_guidance_scale, + max_guidance_scale=max_guidance_scale, + num_inference_steps=steps, + decode_chunk_size=1 + ).frames[0] + target = (target + 1) * 0.5 + target = target.permute(1,0,2,3) + return target + + # frames = self.pipe( + # image=self.image, + # height=512, + # width=512, + # latents=latents, + # denoise_beg=init_step, + # denoise_end=steps, + # num_frames=batch_size, + # min_guidance_scale=min_guidance_scale, + # max_guidance_scale=max_guidance_scale, + # num_inference_steps=steps, + # decode_chunk_size=1 + # ).frames[0] + # export_to_gif(frames, f"tmp.gif") + # raise + + def train_step( + self, + pred_rgb, + step_ratio=None, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb = pred_rgb.to(self.dtype) + + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + # latents = self.pipe._encode_image(pred_rgb_512, self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True) + latents = self.encode_image(pred_rgb_512) + latents = latents.unsqueeze(0) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((1,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (1,), dtype=torch.long, device=self.device) + # print(t) + + w = (1 - self.alphas[t]).view(1, 1, 1, 1) + + + if self.guidance_type == 'sds': + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + t = self.num_train_timesteps - t.item() + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.pipe.scheduler.add_noise(latents, noise, self.pipe.scheduler.timesteps[t:t+1]) # t=0 noise;t=999 clean + noise_pred = self.pipe( + image=self.image, + # image_embeddings=self.embeddings, + height=512, + width=512, + latents=latents_noisy, + output_type='noise', + denoise_beg=t, + denoise_end=t + 1, + min_guidance_scale=min_guidance_scale, + max_guidance_scale=max_guidance_scale, + num_frames=batch_size, + num_inference_steps=self.num_train_timesteps + ).frames[0] + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[1] + print(loss.item()) + return loss + + elif self.guidance_type == 'pixel reconstruction': + # pixel space reconstruction + if self.target_cache is None: + with torch.no_grad(): + self.target_cache = self.pipe( + image=self.image, + height=512, + width=512, + output_type='frame', + num_frames=batch_size, + num_inference_steps=self.num_train_timesteps, + decode_chunk_size=1 + ).frames[0] + self.target_cache = (self.target_cache + 1) * 0.5 + self.target_cache = self.target_cache.permute(1,0,2,3) + + loss = 0.5 * F.mse_loss(pred_rgb_512.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1] + print(loss.item()) + + return loss + + elif self.guidance_type == 'latent reconstruction': + # latent space reconstruction + if self.target_cache is None: + with torch.no_grad(): + self.target_cache = self.pipe( + image=self.image, + height=512, + width=512, + output_type='latent', + num_frames=batch_size, + num_inference_steps=self.num_train_timesteps, + ).frames[0] + + loss = 0.5 * F.mse_loss(latents.float(), self.target_cache.detach().float(), reduction='sum') / latents.shape[1] + print(loss.item()) + + return loss \ No newline at end of file diff --git a/guidance/zero123_utils.py b/guidance/zero123_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a10e90735e892c46c5577eba25691d17256fe4ed --- /dev/null +++ b/guidance/zero123_utils.py @@ -0,0 +1,237 @@ +from diffusers import DDIMScheduler +import torchvision.transforms.functional as TF + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import sys +sys.path.append('./') + +from zero123 import Zero123Pipeline + + +class Zero123(nn.Module): + def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"): + super().__init__() + + self.device = device + self.fp16 = fp16 + self.dtype = torch.float16 if fp16 else torch.float32 + + assert self.fp16, 'Only zero123 fp16 is supported for now.' + + # model_key = "ashawkey/zero123-xl-diffusers" + # model_key = './model_cache/stable_zero123_diffusers' + + self.pipe = Zero123Pipeline.from_pretrained( + model_key, + torch_dtype=self.dtype, + trust_remote_code=True, + ).to(self.device) + + # stable-zero123 has a different camera embedding + self.use_stable_zero123 = 'stable' in model_key + + self.pipe.image_encoder.eval() + self.pipe.vae.eval() + self.pipe.unet.eval() + self.pipe.clip_camera_projection.eval() + + self.vae = self.pipe.vae + self.unet = self.pipe.unet + + self.pipe.set_progress_bar_config(disable=True) + + self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + self.embeddings = None + + @torch.no_grad() + def get_img_embeds(self, x): + # x: image tensor in [0, 1] + x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) + x_pil = [TF.to_pil_image(image) for image in x] + x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) + c = self.pipe.image_encoder(x_clip).image_embeds + v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor + self.embeddings = [c, v] + return c, v + + def get_cam_embeddings(self, elevation, azimuth, radius, default_elevation=0): + if self.use_stable_zero123: + T = np.stack([np.deg2rad(elevation), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(elevation))], axis=-1) + else: + # original zero123 camera embedding + T = np.stack([np.deg2rad(elevation), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) + T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device) # [8, 1, 4] + return T + + @torch.no_grad() + def refine(self, pred_rgb, elevation, azimuth, radius, + guidance_scale=5, steps=50, strength=0.8, default_elevation=0, + ): + + batch_size = pred_rgb.shape[0] + + self.scheduler.set_timesteps(steps) + + if strength == 0: + init_step = 0 + latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype) + else: + init_step = int(steps * strength) + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + T = self.get_cam_embeddings(elevation, azimuth, radius, default_elevation) + cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) + cc_emb = self.pipe.clip_camera_projection(cc_emb) + cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) + + vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1) + vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + x_in = torch.cat([latents] * 2) + t_in = t.view(1).to(self.device) + + noise_pred = self.unet( + torch.cat([x_in, vae_emb], dim=1), + t_in.to(self.unet.dtype), + encoder_hidden_states=cc_emb, + ).sample + + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 256, 256] + return imgs + + def train_step(self, pred_rgb, elevation, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False, default_elevation=0): + # pred_rgb: tensor [1, 3, H, W] in [0, 1] + + batch_size = pred_rgb.shape[0] + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + # t = self.max_step - (self.max_step - self.min_step) * (step_ratio ** 2) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + + T = self.get_cam_embeddings(elevation, azimuth, radius, default_elevation) + cc_emb = torch.cat([self.embeddings[0].unsqueeze(1), T], dim=-1) + cc_emb = self.pipe.clip_camera_projection(cc_emb) + cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) + + vae_emb = self.embeddings[1] + vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) + + noise_pred = self.unet( + torch.cat([x_in, vae_emb], dim=1), + t_in.to(self.unet.dtype), + encoder_hidden_states=cc_emb, + ).sample + + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') + + return loss + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs, mode=False): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + if mode: + latents = posterior.mode() + else: + latents = posterior.sample() + latents = latents * self.vae.config.scaling_factor + + return latents + + +if __name__ == '__main__': + import cv2 + import argparse + import numpy as np + import matplotlib.pyplot as plt + import kiui + + parser = argparse.ArgumentParser() + + parser.add_argument('input', type=str) + parser.add_argument('--elevation', type=float, default=0, help='delta elevation angle in [-90, 90]') + parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]') + parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]') + parser.add_argument('--stable', action='store_true') + + opt = parser.parse_args() + + device = torch.device('cuda') + + print(f'[INFO] loading image from {opt.input} ...') + image = kiui.read_image(opt.input, mode='tensor') + image = image.permute(2, 0, 1).unsqueeze(0).contiguous().to(device) + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + + print(f'[INFO] loading model ...') + + if opt.stable: + zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers') + else: + zero123 = Zero123(device, model_key='ashawkey/zero123-xl-diffusers') + + print(f'[INFO] running model ...') + zero123.get_img_embeds(image) + + azimuth = opt.azimuth + while True: + outputs = zero123.refine(image, elevation=[opt.elevation], azimuth=[azimuth], radius=[opt.radius], strength=0) + plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0]) + plt.show() + azimuth = (azimuth + 10) % 360 diff --git a/lgm/core/attention.py b/lgm/core/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9dddadfc7591ed3a3844853d7e44996e98de599d --- /dev/null +++ b/lgm/core/attention.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import os +import warnings + +from torch import Tensor +from torch import nn + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim: int, + dim_q: int, + dim_k: int, + dim_v: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) + self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) + self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # q: [B, N, Cq] + # k: [B, M, Ck] + # v: [B, M, Cv] + # return: [B, N, C] + + B, N, _ = q.shape + M = k.shape[1] + + q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] + k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] + v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] + + attn = q @ k.transpose(-2, -1) # [B, nh, N, M] + + attn = attn.softmax(dim=-1) # [B, nh, N, M] + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffCrossAttention(CrossAttention): + def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, _ = q.shape + M = k.shape[1] + + q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] + k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/lgm/core/gs.py b/lgm/core/gs.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a56c664578f1e682bc9ceb8389ecab185b62cf --- /dev/null +++ b/lgm/core/gs.py @@ -0,0 +1,190 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) + +from core.options import Options + +import kiui + +class GaussianRenderer: + def __init__(self, opt: Options): + + self.opt = opt + self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") + + # intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): + # gaussians: [B, N, 14] + # cam_view, cam_view_proj: [B, V, 4, 4] + # cam_pos: [B, V, 3] + + device = gaussians.device + B, V = cam_view.shape[:2] + + # loop of loop... + images = [] + alphas = [] + for b in range(B): + + # pos, opacity, scale, rotation, shs + means3D = gaussians[b, :, 0:3].contiguous().float() + opacity = gaussians[b, :, 3:4].contiguous().float() + scales = gaussians[b, :, 4:7].contiguous().float() + rotations = gaussians[b, :, 7:11].contiguous().float() + rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] + + for v in range(V): + + # render novel views + view_matrix = cam_view[b, v].float() + view_proj_matrix = cam_view_proj[b, v].float() + campos = cam_pos[b, v].float() + + raster_settings = GaussianRasterizationSettings( + image_height=self.opt.output_size, + image_width=self.opt.output_size, + tanfovx=self.tan_half_fov, + tanfovy=self.tan_half_fov, + bg=self.bg_color if bg_color is None else bg_color, + scale_modifier=scale_modifier, + viewmatrix=view_matrix, + projmatrix=view_proj_matrix, + sh_degree=0, + campos=campos, + prefiltered=False, + debug=False, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D=means3D, + means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), + shs=None, + colors_precomp=rgbs, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=None, + ) + + rendered_image = rendered_image.clamp(0, 1) + + images.append(rendered_image) + alphas.append(rendered_alpha) + + images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) + alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) + + return { + "image": images, # [B, V, 3, H, W] + "alpha": alphas, # [B, V, 1, H, W] + } + + + def save_ply(self, gaussians, path, compatible=True): + # gaussians: [B, N, 14] + # compatible: save pre-activated gaussians as in the original paper + + assert gaussians.shape[0] == 1, 'only support batch size 1' + + from plyfile import PlyData, PlyElement + + means3D = gaussians[0, :, 0:3].contiguous().float() + opacity = gaussians[0, :, 3:4].contiguous().float() + scales = gaussians[0, :, 4:7].contiguous().float() + rotations = gaussians[0, :, 7:11].contiguous().float() + shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] + + # prune by opacity + mask = opacity.squeeze(-1) >= 0.005 + means3D = means3D[mask] + opacity = opacity[mask] + scales = scales[mask] + rotations = rotations[mask] + shs = shs[mask] + + # invert activation to make it compatible with the original ply format + if compatible: + opacity = kiui.op.inverse_sigmoid(opacity) + scales = torch.log(scales + 1e-8) + shs = (shs - 0.5) / 0.28209479177387814 + + xyzs = means3D.detach().cpu().numpy() + f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = opacity.detach().cpu().numpy() + scales = scales.detach().cpu().numpy() + rotations = rotations.detach().cpu().numpy() + + l = ['x', 'y', 'z'] + # All channels except the 3 DC + for i in range(f_dc.shape[1]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(scales.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(rotations.shape[1]): + l.append('rot_{}'.format(i)) + + dtype_full = [(attribute, 'f4') for attribute in l] + + elements = np.empty(xyzs.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + + PlyData([el]).write(path) + + def load_ply(self, path, compatible=True): + + from plyfile import PlyData, PlyElement + + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + print("Number of points at loading : ", xyz.shape[0]) + + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + shs = np.zeros((xyz.shape[0], 3)) + shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) + shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) + gaussians = torch.from_numpy(gaussians).float() # cpu + + if compatible: + gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) + gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) + gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 + + return gaussians \ No newline at end of file diff --git a/lgm/core/models.py b/lgm/core/models.py new file mode 100644 index 0000000000000000000000000000000000000000..3c58b33c42661d5c0fe53238e6c6988b8bccf48a --- /dev/null +++ b/lgm/core/models.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import kiui +from kiui.lpips import LPIPS + +from core.unet import UNet +from core.options import Options +from core.gs import GaussianRenderer + + +class LGM(nn.Module): + def __init__( + self, + opt: Options, + ): + super().__init__() + + self.opt = opt + + # unet + self.unet = UNet( + 9, 14, + down_channels=self.opt.down_channels, + down_attention=self.opt.down_attention, + mid_attention=self.opt.mid_attention, + up_channels=self.opt.up_channels, + up_attention=self.opt.up_attention, + ) + + # last conv + self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again + + # Gaussian Renderer + self.gs = GaussianRenderer(opt) + + # activations... + self.pos_act = lambda x: x.clamp(-1, 1) + self.scale_act = lambda x: 0.1 * F.softplus(x) + self.opacity_act = lambda x: torch.sigmoid(x) + self.rot_act = lambda x: F.normalize(x, dim=-1) + self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again + + # LPIPS loss + if self.opt.lambda_lpips > 0: + self.lpips_loss = LPIPS(net='vgg') + self.lpips_loss.requires_grad_(False) + + + def state_dict(self, **kwargs): + # remove lpips_loss + state_dict = super().state_dict(**kwargs) + for k in list(state_dict.keys()): + if 'lpips_loss' in k: + del state_dict[k] + return state_dict + + + def prepare_default_rays(self, device, elevation=0): + + from kiui.cam import orbit_camera + from core.utils import get_rays + + cam_poses = np.stack([ + orbit_camera(elevation, 0, radius=self.opt.cam_radius), + orbit_camera(elevation, 90, radius=self.opt.cam_radius), + orbit_camera(elevation, 180, radius=self.opt.cam_radius), + orbit_camera(elevation, 270, radius=self.opt.cam_radius), + ], axis=0) # [4, 4, 4] + cam_poses = torch.from_numpy(cam_poses) + + rays_embeddings = [] + for i in range(cam_poses.shape[0]): + rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] + rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] + rays_embeddings.append(rays_plucker) + + ## visualize rays for plotting figure + # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True) + + rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] + + return rays_embeddings + + + def forward_gaussians(self, images): + # images: [B, 4, 9, H, W] + # return: Gaussians: [B, dim_t] + + B, V, C, H, W = images.shape + images = images.view(B*V, C, H, W) + + x = self.unet(images) # [B*4, 14, h, w] + x = self.conv(x) # [B*4, 14, h, w] + + x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size) + + ## visualize multi-view gaussian features for plotting figure + # tmp_alpha = self.opacity_act(x[0, :, 3:4]) + # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha) + # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5 + # kiui.vis.plot_image(tmp_img_rgb, save=True) + # kiui.vis.plot_image(tmp_img_pos, save=True) + + x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) + + pos = self.pos_act(x[..., 0:3]) # [B, N, 3] + opacity = self.opacity_act(x[..., 3:4]) + scale = self.scale_act(x[..., 4:7]) + rotation = self.rot_act(x[..., 7:11]) + rgbs = self.rgb_act(x[..., 11:]) + + gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] + + return gaussians + + + def forward(self, data, step_ratio=1): + # data: output of the dataloader + # return: loss + + results = {} + loss = 0 + + images = data['input'] # [B, 4, 9, h, W], input features + + # use the first view to predict gaussians + gaussians = self.forward_gaussians(images) # [B, N, 14] + + results['gaussians'] = gaussians + + # always use white bg + bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) + + # use the other views for rendering and supervision + results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) + pred_images = results['image'] # [B, V, C, output_size, output_size] + pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] + + results['images_pred'] = pred_images + results['alphas_pred'] = pred_alphas + + gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views + gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks + + gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) + + loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) + loss = loss + loss_mse + + if self.opt.lambda_lpips > 0: + loss_lpips = self.lpips_loss( + # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, + # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, + # downsampled to at most 256 to reduce memory cost + F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), + ).mean() + results['loss_lpips'] = loss_lpips + loss = loss + self.opt.lambda_lpips * loss_lpips + + results['loss'] = loss + + # metric + with torch.no_grad(): + psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) + results['psnr'] = psnr + + return results \ No newline at end of file diff --git a/lgm/core/options.py b/lgm/core/options.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b2559fc63487cac918e463a3d6a016d875597e --- /dev/null +++ b/lgm/core/options.py @@ -0,0 +1,120 @@ +import tyro +from dataclasses import dataclass +from typing import Tuple, Literal, Dict, Optional + + +@dataclass +class Options: + ### model + # Unet image input size + input_size: int = 256 + # Unet definition + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) + down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) + mid_attention: bool = True + up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) + up_attention: Tuple[bool, ...] = (True, True, True, False) + # Unet output size, dependent on the input_size and U-Net structure! + splat_size: int = 64 + # gaussian render size + output_size: int = 256 + + ### dataset + # data mode (only support s3 now) + data_mode: Literal['s3'] = 's3' + # fovy of the dataset + fovy: float = 49.1 + # camera near plane + znear: float = 0.5 + # camera far plane + zfar: float = 2.5 + # number of all views (input + output) + num_views: int = 12 + # number of views + num_input_views: int = 4 + # camera radius + cam_radius: float = 1.5 # to better use [-1, 1]^3 space + # num workers + num_workers: int = 8 + + ### training + # workspace + workspace: str = './workspace' + # resume + resume: Optional[str] = 'pretrained/model_fp16_fixrot.safetensors' + # batch size (per-GPU) + batch_size: int = 8 + # gradient accumulation + gradient_accumulation_steps: int = 1 + # training epochs + num_epochs: int = 30 + # lpips loss weight + lambda_lpips: float = 1.0 + # gradient clip + gradient_clip: float = 1.0 + # mixed precision + mixed_precision: str = 'bf16' + # learning rate + lr: float = 4e-4 + # augmentation prob for grid distortion + prob_grid_distortion: float = 0.5 + # augmentation prob for camera jitter + prob_cam_jitter: float = 0.5 + + ### testing + # test image path + test_path: Optional[str] = None + + ### misc + # nvdiffrast backend setting + force_cuda_rast: bool = False + # render fancy video with gaussian scaling effect + fancy_video: bool = False + + +# all the default settings +config_defaults: Dict[str, Options] = {} +config_doc: Dict[str, str] = {} + +config_doc['lrm'] = 'the default settings for LGM' +config_defaults['lrm'] = Options() + +config_doc['small'] = 'small model with lower resolution Gaussians' +config_defaults['small'] = Options( + input_size=256, + splat_size=64, + output_size=256, + batch_size=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['big'] = 'big model with higher resolution Gaussians' +config_defaults['big'] = Options( + input_size=256, + up_channels=(1024, 1024, 512, 256, 128), # one more decoder + up_attention=(True, True, True, False, False), + splat_size=128, + output_size=512, # render & supervise Gaussians at a higher resolution. + batch_size=8, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['tiny'] = 'tiny model for ablation' +config_defaults['tiny'] = Options( + input_size=256, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + splat_size=64, + output_size=256, + batch_size=16, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) \ No newline at end of file diff --git a/lgm/core/unet.py b/lgm/core/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd38bc627c359b523a4e46cd91fd0add9036d6c --- /dev/null +++ b/lgm/core/unet.py @@ -0,0 +1,319 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from typing import Tuple, Literal +from functools import partial + +from core.attention import MemEffAttention + +class MVAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + groups: int = 32, + eps: float = 1e-5, + residual: bool = True, + skip_scale: float = 1, + num_frames: int = 4, # WARN: hardcoded! + ): + super().__init__() + + self.residual = residual + self.skip_scale = skip_scale + self.num_frames = num_frames + + self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) + self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) + + def forward(self, x): + # x: [B*V, C, H, W] + BV, C, H, W = x.shape + B = BV // self.num_frames # assert BV % self.num_frames == 0 + + res = x + x = self.norm(x) + + x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) + x = self.attn(x) + x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) + + if self.residual: + x = (x + res) * self.skip_scale + return x + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resample: Literal['default', 'up', 'down'] = 'default', + groups: int = 32, + eps: float = 1e-5, + skip_scale: float = 1, # multiplied to output + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.skip_scale = skip_scale + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.act = F.silu + + self.resample = None + if resample == 'up': + self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + elif resample == 'down': + self.resample = nn.AvgPool2d(kernel_size=2, stride=2) + + self.shortcut = nn.Identity() + if self.in_channels != self.out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) + + + def forward(self, x): + res = x + + x = self.norm1(x) + x = self.act(x) + + if self.resample: + res = self.resample(res) + x = self.resample(x) + + x = self.conv1(x) + x = self.norm2(x) + x = self.act(x) + x = self.conv2(x) + + x = (x + self.shortcut(res)) * self.skip_scale + + return x + +class DownBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample: bool = True, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + self.downsample = None + if downsample: + self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + xs = [] + + for attn, net in zip(self.attns, self.nets): + x = net(x) + if attn: + x = attn(x) + xs.append(x) + + if self.downsample: + x = self.downsample(x) + xs.append(x) + + return x, xs + + +class MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + # first layer + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + # more layers + for i in range(num_layers): + nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + def forward(self, x): + x = self.nets[0](x) + for attn, net in zip(self.attns, self.nets[1:]): + if attn: + x = attn(x) + x = net(x) + return x + + +class UpBlock(nn.Module): + def __init__( + self, + in_channels: int, + prev_out_channels: int, + out_channels: int, + num_layers: int = 1, + upsample: bool = True, + attention: bool = True, + attention_heads: int = 16, + skip_scale: float = 1, + ): + super().__init__() + + nets = [] + attns = [] + for i in range(num_layers): + cin = in_channels if i == 0 else out_channels + cskip = prev_out_channels if (i == num_layers - 1) else out_channels + + nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) + if attention: + attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) + else: + attns.append(None) + self.nets = nn.ModuleList(nets) + self.attns = nn.ModuleList(attns) + + self.upsample = None + if upsample: + self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x, xs): + + for attn, net in zip(self.attns, self.nets): + res_x = xs[-1] + xs = xs[:-1] + x = torch.cat([x, res_x], dim=1) + x = net(x) + if attn: + x = attn(x) + + if self.upsample: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.upsample(x) + + return x + + +# it could be asymmetric! +class UNet(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), + down_attention: Tuple[bool, ...] = (False, False, False, True, True), + mid_attention: bool = True, + up_channels: Tuple[int, ...] = (1024, 512, 256), + up_attention: Tuple[bool, ...] = (True, True, False), + layers_per_block: int = 2, + skip_scale: float = np.sqrt(0.5), + ): + super().__init__() + + # first + self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) + + # down + down_blocks = [] + cout = down_channels[0] + for i in range(len(down_channels)): + cin = cout + cout = down_channels[i] + + down_blocks.append(DownBlock( + cin, cout, + num_layers=layers_per_block, + downsample=(i != len(down_channels) - 1), # not final layer + attention=down_attention[i], + skip_scale=skip_scale, + )) + self.down_blocks = nn.ModuleList(down_blocks) + + # mid + self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) + + # up + up_blocks = [] + cout = up_channels[0] + for i in range(len(up_channels)): + cin = cout + cout = up_channels[i] + cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric + + up_blocks.append(UpBlock( + cin, cskip, cout, + num_layers=layers_per_block + 1, # one more layer for up + upsample=(i != len(up_channels) - 1), # not final layer + attention=up_attention[i], + skip_scale=skip_scale, + )) + self.up_blocks = nn.ModuleList(up_blocks) + + # last + self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5) + self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + + def forward(self, x): + # x: [B, Cin, H, W] + + # first + x = self.conv_in(x) + + # down + xss = [x] + for block in self.down_blocks: + x, xs = block(x) + xss.extend(xs) + + # mid + x = self.mid_block(x) + + # up + for block in self.up_blocks: + xs = xss[-len(block.nets):] + xss = xss[:-len(block.nets)] + x = block(x, xs) + + # last + x = self.norm_out(x) + x = F.silu(x) + x = self.conv_out(x) # [B, Cout, H', W'] + + return x \ No newline at end of file diff --git a/lgm/core/utils.py b/lgm/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2ef958a0796783ce8735112644be6b37f4e313 --- /dev/null +++ b/lgm/core/utils.py @@ -0,0 +1,108 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import roma +from kiui.op import safe_normalize + +def get_rays(pose, h, w, fovy, opengl=True): + + x, y = torch.meshgrid( + torch.arange(w, device=pose.device), + torch.arange(h, device=pose.device), + indexing="xy", + ) + x = x.flatten() + y = y.flatten() + + cx = w * 0.5 + cy = h * 0.5 + + focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) + + camera_dirs = F.pad( + torch.stack( + [ + (x - cx + 0.5) / focal, + (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if opengl else 1.0), + ) # [hw, 3] + + rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] + rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] + + rays_o = rays_o.view(h, w, 3) + rays_d = safe_normalize(rays_d).view(h, w, 3) + + return rays_o, rays_d + +def orbit_camera_jitter(poses, strength=0.1): + # poses: [B, 4, 4], assume orbit camera in opengl format + # random orbital rotate + + B = poses.shape[0] + rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) + rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) + + rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) + R = rot @ poses[:, :3, :3] + T = rot @ poses[:, :3, 3:] + + new_poses = poses.clone() + new_poses[:, :3, :3] = R + new_poses[:, :3, 3:] = T + + return new_poses + +def grid_distortion(images, strength=0.5): + # images: [B, C, H, W] + # num_steps: int, grid resolution for distortion + # strength: float in [0, 1], strength of distortion + + B, C, H, W = images.shape + + num_steps = np.random.randint(8, 17) + grid_steps = torch.linspace(-1, 1, num_steps) + + # have to loop batch... + grids = [] + for b in range(B): + # construct displacement + x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + x_steps = (x_steps * W).long() # [num_steps] + x_steps[0] = 0 + x_steps[-1] = W + xs = [] + for i in range(num_steps - 1): + xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) + xs = torch.cat(xs, dim=0) # [W] + + y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive + y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb + y_steps = (y_steps * H).long() # [num_steps] + y_steps[0] = 0 + y_steps[-1] = H + ys = [] + for i in range(num_steps - 1): + ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) + ys = torch.cat(ys, dim=0) # [H] + + # construct grid + grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] + grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] + + grids.append(grid) + + grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] + + # grid sample + images = F.grid_sample(images, grids, align_corners=False) + + return images diff --git a/lgm/infer.py b/lgm/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6ef106214ba23afaefef022c639e1a7e3d6cba --- /dev/null +++ b/lgm/infer.py @@ -0,0 +1,226 @@ + +import os +import tyro +import glob +import imageio +import numpy as np +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from safetensors.torch import load_file +import rembg + +import kiui +from kiui.op import recenter +from kiui.cam import orbit_camera + +from core.options import AllConfigs, Options +from core.models import LGM +from mvdream.pipeline_mvdream import MVDreamPipeline +import cv2 + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +opt = tyro.cli(AllConfigs) + +# model +model = LGM(opt) + +# resume pretrained checkpoint +if opt.resume is not None: + if opt.resume.endswith('safetensors'): + ckpt = load_file(opt.resume, device='cpu') + else: + ckpt = torch.load(opt.resume, map_location='cpu') + model.load_state_dict(ckpt, strict=False) + print(f'[INFO] Loaded checkpoint from {opt.resume}') +else: + print(f'[WARN] model randomly initialized, are you sure?') + +# device +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = model.half().to(device) +model.eval() + +rays_embeddings = model.prepare_default_rays(device) + +tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) +proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) +proj_matrix[0, 0] = 1 / tan_half_fov +proj_matrix[1, 1] = 1 / tan_half_fov +proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) +proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) +proj_matrix[2, 3] = 1 + +# load image dream +pipe = MVDreamPipeline.from_pretrained( + "ashawkey/imagedream-ipmv-diffusers", # remote weights + torch_dtype=torch.float16, + trust_remote_code=True, + # local_files_only=True, +) +pipe = pipe.to(device) + +# load rembg +bg_remover = rembg.new_session() + +# process function +def process(opt: Options, path): + name = os.path.splitext(os.path.basename(path))[0] + if 'CONSISTENT4D' in path: + name = path.split('/')[-2] + print(f'[INFO] Processing {path} --> {name}') + os.makedirs('vis_data', exist_ok=True) + os.makedirs('logs', exist_ok=True) + + input_image = kiui.read_image(path, mode='uint8') + + # bg removal + carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] + mask = carved_image[..., -1] > 0 + + # recenter + image = recenter(carved_image, mask, border_ratio=0.2) + + # generate mv + image = image.astype(np.float32) / 255.0 + + # rgba to rgb white bg + if image.shape[-1] == 4: + image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + + mv_image = pipe('', image, guidance_scale=5.0, num_inference_steps=30, elevation=0) + mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 + + # generate gaussians + input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) + input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] + + with torch.no_grad(): + ############## align azimuth ##################### + with torch.autocast(device_type='cuda', dtype=torch.float16): + # generate gaussians + gaussians = model.forward_gaussians(input_image) + + best_azi = 0 + best_diff = 1e8 + for v, azi in enumerate(np.arange(-180, 180, 1)): + cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + # scale = min(azi / 360, 1) + scale = 1 + + + result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale) + rendered_image = result['image'] + + rendered_image = rendered_image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy() + rendered_image = cv2.resize(rendered_image, (image.shape[0], image.shape[1]), interpolation=cv2.INTER_AREA) + + diff = np.mean((rendered_image- image) ** 2) + + if diff < best_diff: + best_diff = diff + best_azi = azi + print("Best aligned azimuth: ", best_azi) + + mv_image = [] + for v, azi in enumerate([0, 90, 180, 270]): + cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + # scale = min(azi / 360, 1) + scale = 1 + + + result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale) + rendered_image = result['image'] + rendered_image = rendered_image.squeeze(1) + rendered_image = F.interpolate(rendered_image, (256, 256)) + rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy() + mv_image.append(rendered_image) + mv_image = np.concatenate(mv_image, axis=0) + + input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256] + input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) + input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W] + + ################################ + + with torch.autocast(device_type='cuda', dtype=torch.float16): + # generate gaussians + gaussians = model.forward_gaussians(input_image) + + # save gaussians + model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply')) + + # render 360 video + images = [] + elevation = 0 + + if opt.fancy_video: + + azimuth = np.arange(0, 720, 4, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + scale = min(azi / 360, 1) + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + else: + azimuth = np.arange(0, 360, 2, dtype=np.int32) + for azi in tqdm.tqdm(azimuth): + + cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) + + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] + cam_pos = - cam_poses[:, :3, 3] # [V, 3] + + image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] + images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) + + images = np.concatenate(images, axis=0) + imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30) + + +assert opt.test_path is not None +if os.path.isdir(opt.test_path): + file_paths = glob.glob(os.path.join(opt.test_path, "*")) +else: + file_paths = [opt.test_path] +for path in file_paths: + process(opt, path) \ No newline at end of file diff --git a/lgm/mvdream/mv_unet.py b/lgm/mvdream/mv_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9ad4def5910394eb64b36f9f76c98e8eaf80ae --- /dev/null +++ b/lgm/mvdream/mv_unet.py @@ -0,0 +1,1005 @@ +import math +import numpy as np +from inspect import isfunction +from typing import Optional, Any, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin + +# require xformers! +import xformers +import xformers.ops + +from kiui.cam import orbit_camera + +def get_camera( + num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False, +): + angle_gap = azimuth_span / num_frames + cameras = [] + for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): + + pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4] + + # opengl to blender + if blender_coord: + pose[2] *= -1 + pose[[1, 2]] = pose[[2, 1]] + + cameras.append(pose.flatten()) + + if extra_view: + cameras.append(np.zeros_like(cameras[0])) + + return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16] + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + # import pdb; pdb.set_trace() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None): + q = self.to_q(x) + context = default(context, x) + + if self.ip_dim > 0: + # context: [B, 77 + 16(ip), 1024] + token_len = context.shape[1] + context_ip = context[:, -self.ip_dim :, :] + k_ip = self.to_k_ip(context_ip) + v_ip = self.to_v_ip(context_ip) + context = context[:, : (token_len - self.ip_dim), :] + + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if self.ip_dim > 0: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + # actually compute the attention, what we cannot get enough of + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=self.attention_op + ) + out = out + self.ip_weight * out_ip + + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock3D(nn.Module): + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim, + dropout=0.0, + gated_ff=True, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + self.attn1 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=None, # self-attention + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = MemoryEfficientCrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + # ip only applies to cross-attention + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + + def forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = self.attn1(self.norm1(x), context=None) + x + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + + def __init__( + self, + in_channels, + n_heads, + d_head, + context_dim, # cross attention input dim + depth=1, + dropout=0.0, + ip_dim=0, + ip_weight=1, + ): + super().__init__() + + if not isinstance(context_dim, list): + context_dim = [context_dim] + + self.in_channels = in_channels + + inner_dim = n_heads * d_head + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock3D( + inner_dim, + n_heads, + d_head, + context_dim=context_dim[d], + dropout=dropout, + ip_dim=ip_dim, + ip_weight=ip_weight, + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + return x + x_in + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q, k, v = map( + lambda t: t.reshape(b, t.shape[1], self.heads, -1) + .transpose(1, 2) + .reshape(b, self.heads, t.shape[1], -1) + .contiguous(), + (q, k, v), + ) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * ff_mult, bias=False), + nn.GELU(), + nn.Linear(dim * ff_mult, dim, bias=False), + ) + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class CondSequential(nn.Sequential): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, ResBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class MultiViewUNetModel(ModelMixin, ConfigMixin): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + transformer_depth=1, + context_dim=None, + n_embed=None, + num_attention_blocks=None, + adm_in_channels=None, + camera_dim=None, + ip_dim=0, # imagedream uses ip_dim > 0 + ip_weight=1.0, + **kwargs, + ): + super().__init__() + assert context_dim is not None + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.ip_dim = ip_dim + self.ip_weight = ip_weight + + if self.ip_dim > 0: + self.image_embed = Resampler( + dim=context_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=ip_dim, # num token + embedding_dim=1280, + output_dim=context_dim, + ff_mult=4, + ) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if camera_dim is not None: + time_embed_dim = model_channels * 4 + self.camera_embed = nn.Sequential( + nn.Linear(camera_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + # print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + CondSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers: List[Any] = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or nr < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + self.input_blocks.append(CondSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = CondSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if num_attention_blocks is None or i < num_attention_blocks[level]: + layers.append( + SpatialTransformer3D( + ch, + num_heads, + dim_head, + context_dim=context_dim, + depth=transformer_depth, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(CondSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + nn.GroupNorm(32, ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + nn.GroupNorm(32, ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + ip=None, + ip_img=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "input batch size must be dividable by num_frames!" + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y is not None + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + emb = emb + self.camera_embed(camera) + + # imagedream variant + if self.ip_dim > 0: + x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9] + ip_emb = self.image_embed(ip) + context = torch.cat((context, ip_emb), 1) + + h = x + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) \ No newline at end of file diff --git a/lgm/mvdream/pipeline_mvdream.py b/lgm/mvdream/pipeline_mvdream.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7b3a72558aa03f77a1626e26d4f87fd830dd90 --- /dev/null +++ b/lgm/mvdream/pipeline_mvdream.py @@ -0,0 +1,559 @@ +import torch +import torch.nn.functional as F +import inspect +import numpy as np +from typing import Callable, List, Optional, Union +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor +from diffusers import AutoencoderKL, DiffusionPipeline +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) +from diffusers.configuration_utils import FrozenDict +from diffusers.schedulers import DDIMScheduler +from diffusers.utils.torch_utils import randn_tensor + +from mvdream.mv_unet import MultiViewUNetModel, get_camera + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MVDreamPipeline(DiffusionPipeline): + + _optional_components = ["feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + unet: MultiViewUNetModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + scheduler: DDIMScheduler, + # imagedream variant + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + requires_safety_checker: bool = False, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + unet=unet, + scheduler=scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError( + "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError( + "`enable_model_offload` requires `accelerate v0.17.0` or higher." + ) + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook( + cpu_offloaded_model, device, prev_module_hook=hook + ) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance: bool, + negative_prompt=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` should be either a string or a list of strings, but got {type(prompt)}." + ) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=self.text_encoder.dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if image.dtype == np.float32: + image = (image * 255).astype(np.uint8) + + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + + image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(image_embeds), image_embeds + + def encode_image_latents(self, image, device, num_images_per_prompt): + + dtype = next(self.image_encoder.parameters()).dtype + + image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W] + image = 2 * image - 1 + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image = image.to(dtype=dtype) + + posterior = self.vae.encode(image).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + return torch.zeros_like(latents), latents + + @torch.no_grad() + def __call__( + self, + prompt: str = "", + image: Optional[np.ndarray] = None, + height: int = 256, + width: int = 256, + elevation: float = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "numpy", # pil, numpy, latents + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + num_frames: int = 4, + device=torch.device("cuda:0"), + ): + self.unet = self.unet.to(device=device) + self.vae = self.vae.to(device=device) + self.text_encoder = self.text_encoder.to(device=device) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # imagedream variant + if image is not None: + assert isinstance(image, np.ndarray) and image.dtype == np.float32 + self.image_encoder = self.image_encoder.to(device=device) + image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt) + image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt) + + _prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + ) # type: ignore + prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2) + + # Prepare latent variables + actual_num_frames = num_frames if image is None else num_frames + 1 + latents: torch.Tensor = self.prepare_latents( + actual_num_frames * num_images_per_prompt, + 4, + height, + width, + prompt_embeds_pos.dtype, + device, + generator, + None, + ) + + if image is not None: + camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device) + else: + camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device) + camera = camera.repeat_interleave(num_images_per_prompt, dim=0) + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + multiplier = 2 if do_classifier_free_guidance else 1 + latent_model_input = torch.cat([latents] * multiplier) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + unet_inputs = { + 'x': latent_model_input, + 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device), + 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames), + 'num_frames': actual_num_frames, + 'camera': torch.cat([camera] * multiplier), + } + + if image is not None: + unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames) + unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat + + # predict the noise residual + noise_pred = self.unet.forward(**unet_inputs) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents: torch.Tensor = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # type: ignore + + # Post-processing + if output_type == "latent": + image = latents + elif output_type == "pil": + image = self.decode_latents(latents) + image = self.numpy_to_pil(image) + else: # numpy + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image \ No newline at end of file diff --git a/main_4d.py b/main_4d.py new file mode 100644 index 0000000000000000000000000000000000000000..92948069d14a617d77b07ed838381a390e3ef34a --- /dev/null +++ b/main_4d.py @@ -0,0 +1,601 @@ +import os +import cv2 +import time +import tqdm +import numpy as np + +import torch +import torch.nn.functional as F + +import rembg + +from cam_utils import orbit_camera, OrbitCamera +from gs_renderer_4d import Renderer, MiniCam + +from grid_put import mipmap_linear_grid_put_2d +import imageio + +import copy + + +class GUI: + def __init__(self, opt): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.gui = opt.gui # enable gui + self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) + + self.mode = "image" + # self.seed = "random" + self.seed = 888 + + self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # update buffer_image + + # models + self.device = torch.device("cuda") + self.bg_remover = None + + self.guidance_sd = None + self.guidance_zero123 = None + self.guidance_svd = None + + + self.enable_sd = False + self.enable_zero123 = False + self.enable_svd = False + + + # renderer + self.renderer = Renderer(self.opt, sh_degree=self.opt.sh_degree) + self.gaussain_scale_factor = 1 + + # input image + self.input_img = None + self.input_mask = None + self.input_img_torch = None + self.input_mask_torch = None + self.overlay_input_img = False + self.overlay_input_img_ratio = 0.5 + + self.input_img_list = None + self.input_mask_list = None + self.input_img_torch_list = None + self.input_mask_torch_list = None + + # input text + self.prompt = "" + self.negative_prompt = "" + + # training stuff + self.training = False + self.optimizer = None + self.step = 0 + self.train_steps = 1 # steps per rendering loop + + # load input data from cmdline + if self.opt.input is not None: # True + self.load_input(self.opt.input) # load imgs, if has bg, then rm bg; or just load imgs + + # override prompt from cmdline + if self.opt.prompt is not None: # None + self.prompt = self.opt.prompt + + # override if provide a checkpoint + if self.opt.load is not None: # not None + self.renderer.initialize(self.opt.load) + # self.renderer.gaussians.load_model(opt.outdir, opt.save_path) + else: + # initialize gaussians to a blob + self.renderer.initialize(num_pts=self.opt.num_pts) + + self.seed_everything() + + def seed_everything(self): + try: + seed = int(self.seed) + except: + seed = np.random.randint(0, 1000000) + + print(f'Seed: {seed:d}') + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + self.last_seed = seed + + def prepare_train(self): + + self.step = 0 + + # setup training + self.renderer.gaussians.training_setup(self.opt) + + # # do not do progressive sh-level + self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree + self.optimizer = self.renderer.gaussians.optimizer + + # default camera + if self.opt.mvdream or self.opt.imagedream: + # the second view is the front view for mvdream/imagedream. + pose = orbit_camera(self.opt.elevation, 90, self.opt.radius) + else: + pose = orbit_camera(self.opt.elevation, 0, self.opt.radius) + self.fixed_cam = MiniCam( + pose, + self.opt.ref_size, + self.opt.ref_size, + self.cam.fovy, + self.cam.fovx, + self.cam.near, + self.cam.far, + ) + + self.enable_sd = self.opt.lambda_sd > 0 + self.enable_zero123 = self.opt.lambda_zero123 > 0 + self.enable_svd = self.opt.lambda_svd > 0 and self.input_img is not None + + # lazy load guidance model + if self.guidance_sd is None and self.enable_sd: + if self.opt.mvdream: + print(f"[INFO] loading MVDream...") + from guidance.mvdream_utils import MVDream + self.guidance_sd = MVDream(self.device) + print(f"[INFO] loaded MVDream!") + elif self.opt.imagedream: + print(f"[INFO] loading ImageDream...") + from guidance.imagedream_utils import ImageDream + self.guidance_sd = ImageDream(self.device) + print(f"[INFO] loaded ImageDream!") + else: + print(f"[INFO] loading SD...") + from guidance.sd_utils import StableDiffusion + self.guidance_sd = StableDiffusion(self.device) + print(f"[INFO] loaded SD!") + + if self.guidance_zero123 is None and self.enable_zero123: + print(f"[INFO] loading zero123...") + from guidance.zero123_utils import Zero123 + if self.opt.stable_zero123: + self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/stable-zero123-diffusers') + else: + self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers') + print(f"[INFO] loaded zero123!") + + if self.guidance_svd is None and self.enable_svd: # False + print(f"[INFO] loading SVD...") + from guidance.svd_utils import StableVideoDiffusion + self.guidance_svd = StableVideoDiffusion(self.device) + print(f"[INFO] loaded SVD!") + + # input image + if self.input_img is not None: + self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) + self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) + + self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) + self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) + + if self.input_img_list is not None: + self.input_img_torch_list = [torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_img in self.input_img_list] + self.input_img_torch_list = [F.interpolate(input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_img_torch in self.input_img_torch_list] + + self.input_mask_torch_list = [torch.from_numpy(input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) for input_mask in self.input_mask_list] + self.input_mask_torch_list = [F.interpolate(input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) for input_mask_torch in self.input_mask_torch_list] + # prepare embeddings + with torch.no_grad(): + + if self.enable_sd: + if self.opt.imagedream: + img_pos_list, img_neg_list, ip_pos_list, ip_neg_list, emb_pos_list, emb_neg_list = [], [], [], [], [], [] + for _ in range(self.opt.n_views): + for input_img_torch in self.input_img_torch_list: + img_pos, img_neg, ip_pos, ip_neg, emb_pos, emb_neg = self.guidance_sd.get_image_text_embeds(input_img_torch, [self.prompt], [self.negative_prompt]) + img_pos_list.append(img_pos) + img_neg_list.append(img_neg) + ip_pos_list.append(ip_pos) + ip_neg_list.append(ip_neg) + emb_pos_list.append(emb_pos) + emb_neg_list.append(emb_neg) + self.guidance_sd.image_embeddings['pos'] = torch.cat(img_pos_list, 0) + self.guidance_sd.image_embeddings['neg'] = torch.cat(img_pos_list, 0) + self.guidance_sd.image_embeddings['ip_img'] = torch.cat(ip_pos_list, 0) + self.guidance_sd.image_embeddings['neg_ip_img'] = torch.cat(ip_neg_list, 0) + self.guidance_sd.embeddings['pos'] = torch.cat(emb_pos_list, 0) + self.guidance_sd.embeddings['neg'] = torch.cat(emb_neg_list, 0) + else: + self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt]) + + if self.enable_zero123: + c_list, v_list = [], [] + for _ in range(self.opt.n_views): + for input_img_torch in self.input_img_torch_list: + c, v = self.guidance_zero123.get_img_embeds(input_img_torch) + c_list.append(c) + v_list.append(v) + self.guidance_zero123.embeddings = [torch.cat(c_list, 0), torch.cat(v_list, 0)] + + if self.enable_svd: + self.guidance_svd.get_img_embeds(self.input_img) + + def train_step(self): + starter = torch.cuda.Event(enable_timing=True) + ender = torch.cuda.Event(enable_timing=True) + starter.record() + + for _ in range(self.train_steps): # 1 + + self.step += 1 # self.step starts from 0 + step_ratio = min(1, self.step / self.opt.iters) # 1, step / 500 + + # update lr + self.renderer.gaussians.update_learning_rate(self.step) + + loss = 0 + + self.renderer.prepare_render() + + ### known view + if not self.opt.imagedream: + for b_idx in range(self.opt.batch_size): + cur_cam = copy.deepcopy(self.fixed_cam) + cur_cam.time = b_idx + out = self.renderer.render(cur_cam) + + # rgb loss + image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] + loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch_list[b_idx]) / self.opt.batch_size + + # mask loss + mask = out["alpha"].unsqueeze(0) # [1, 1, H, W] in [0, 1] + loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch_list[b_idx]) / self.opt.batch_size + + ### novel view (manual batch) + render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512) + # render_resolution = 512 + images = [] + poses = [] + vers, hors, radii = [], [], [] + # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30] + min_ver = max(min(self.opt.min_ver, self.opt.min_ver - self.opt.elevation), -80 - self.opt.elevation) + max_ver = min(max(self.opt.max_ver, self.opt.max_ver - self.opt.elevation), 80 - self.opt.elevation) + + for _ in range(self.opt.n_views): + for b_idx in range(self.opt.batch_size): + + # render random view + ver = np.random.randint(min_ver, max_ver) + hor = np.random.randint(-180, 180) + radius = 0 + + vers.append(ver) + hors.append(hor) + radii.append(radius) + + pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius) + poses.append(pose) + + cur_cam = MiniCam(pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=b_idx) + + bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda") + out = self.renderer.render(cur_cam, bg_color=bg_color) + + image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] + images.append(image) + + # enable mvdream training + if self.opt.mvdream or self.opt.imagedream: # False + for view_i in range(1, 4): + pose_i = orbit_camera(self.opt.elevation + ver, hor + 90 * view_i, self.opt.radius + radius) + poses.append(pose_i) + + cur_cam_i = MiniCam(pose_i, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far) + + # bg_color = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device="cuda") + out_i = self.renderer.render(cur_cam_i, bg_color=bg_color) + + image = out_i["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] + images.append(image) + + + + images = torch.cat(images, dim=0) + poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device) + + # guidance loss + if self.enable_sd: + if self.opt.mvdream or self.opt.imagedream: + loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, poses, step_ratio) + else: + loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio) + + if self.enable_zero123: + loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio) / (self.opt.batch_size * self.opt.n_views) + + if self.enable_svd: + loss = loss + self.opt.lambda_svd * self.guidance_svd.train_step(images, step_ratio) + + # optimize step + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + # densify and prune + if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter: + viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"] + self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if self.step % self.opt.densification_interval == 0: + # size_threshold = 20 if self.step > self.opt.opacity_reset_interval else None + self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=0.5, max_screen_size=1) + + if self.step % self.opt.opacity_reset_interval == 0: + self.renderer.gaussians.reset_opacity() + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + self.need_update = True + + + def load_input(self, file): + if self.opt.data_mode == 'c4d': + file_list = [os.path.join(file, f'{x * self.opt.downsample_rate}.png') for x in range(self.opt.batch_size)] + elif self.opt.data_mode == 'svd': + # file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}_rgba.png') for x in range(self.opt.batch_size)] + # file_list = [x if os.path.exists(x) else (x.replace('_rgba.png', '.png')) for x in file_list] + file_list = [file.replace('.png', f'_frames/{x* self.opt.downsample_rate:03d}.png') for x in range(self.opt.batch_size)] + else: + raise NotImplementedError + self.input_img_list, self.input_mask_list = [], [] + for file in file_list: + # load image + print(f'[INFO] load image from {file}...') + img = cv2.imread(file, cv2.IMREAD_UNCHANGED) + if img.shape[-1] == 3: + if self.bg_remover is None: + self.bg_remover = rembg.new_session() + img = rembg.remove(img, session=self.bg_remover) + # cv2.imwrite(file.replace('.png', '_rgba.png'), img) + img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) + img = img.astype(np.float32) / 255.0 + input_mask = img[..., 3:] + # white bg + input_img = img[..., :3] * input_mask + (1 - input_mask) + # bgr to rgb + input_img = input_img[..., ::-1].copy() + self.input_img_list.append(input_img) + self.input_mask_list.append(input_mask) + + @torch.no_grad() + def save_model(self, mode='geo', texture_size=1024, interp=1): + os.makedirs(self.opt.outdir, exist_ok=True) + if mode == 'geo': + path = f'logs/{opt.save_path}_mesh_{t:03d}.ply' + mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t) + mesh.write_ply(path) + + elif mode == 'geo+tex': + from mesh import Mesh, safe_normalize + os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_meshes'), exist_ok=True) + for t in range(self.opt.batch_size): + path = os.path.join(self.opt.outdir, self.opt.save_path+'_meshes', f'{t:03d}.obj') + mesh = self.renderer.gaussians.extract_mesh_t(path, self.opt.density_thresh, t=t) + + # perform texture extraction + print(f"[INFO] unwrap uv...") + h = w = texture_size + mesh.auto_uv() + mesh.auto_normal() + + albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32) + cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32) + + vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9] + hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0] + + render_resolution = 512 + + import nvdiffrast.torch as dr + + if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'): + glctx = dr.RasterizeGLContext() + else: + glctx = dr.RasterizeCudaContext() + + for ver, hor in zip(vers, hors): + # render image + pose = orbit_camera(ver, hor, self.cam.radius) + + cur_cam = MiniCam( + pose, + render_resolution, + render_resolution, + self.cam.fovy, + self.cam.fovx, + self.cam.near, + self.cam.far, + time=t + ) + + cur_out = self.renderer.render(cur_cam) + + rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] + + # get coordinate in texture image + pose = torch.from_numpy(pose.astype(np.float32)).to(self.device) + proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device) + + v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0) + v_clip = v_cam @ proj.T + rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution)) + + depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1] + depth = depth.squeeze(0) # [H, W, 1] + + alpha = (rast[0, ..., 3:] > 0).float() + + uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1] + + # use normal to produce a back-project mask + normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn) + normal = safe_normalize(normal[0]) + + # rotated normal (where [0, 0, 1] always faces camera) + rot_normal = normal @ pose[:3, :3] + viewcos = rot_normal[..., [2]] + + mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1] + mask = mask.view(-1) + + uvs = uvs.view(-1, 2).clamp(0, 1)[mask] + rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous() + + # update texture image + cur_albedo, cur_cnt = mipmap_linear_grid_put_2d( + h, w, + uvs[..., [1, 0]] * 2 - 1, + rgbs, + min_resolution=256, + return_count=True, + ) + + mask = cnt.squeeze(-1) < 0.1 + albedo[mask] += cur_albedo[mask] + cnt[mask] += cur_cnt[mask] + + mask = cnt.squeeze(-1) > 0 + albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3) + + mask = mask.view(h, w) + + albedo = albedo.detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + + # dilate texture + from sklearn.neighbors import NearestNeighbors + from scipy.ndimage import binary_dilation, binary_erosion + + inpaint_region = binary_dilation(mask, iterations=32) + inpaint_region[mask] = 0 + + search_region = mask.copy() + not_search_region = binary_erosion(search_region, iterations=3) + search_region[not_search_region] = 0 + + search_coords = np.stack(np.nonzero(search_region), axis=-1) + inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) + + knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit( + search_coords + ) + _, indices = knn.kneighbors(inpaint_coords) + + albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)] + + mesh.albedo = torch.from_numpy(albedo).to(self.device) + mesh.write(path) + + + elif mode == 'frames': + os.makedirs(os.path.join(self.opt.outdir, self.opt.save_path+'_frames'), exist_ok=True) + for t in range(self.opt.batch_size * interp): + tt = t / interp + path = os.path.join(self.opt.outdir, self.opt.save_path+'_frames', f'{t:03d}.ply') + self.renderer.gaussians.save_frame_ply(path, tt) + else: + path = os.path.join(self.opt.outdir, self.opt.save_path + '_4d_model.ply') + self.renderer.gaussians.save_ply(path) + self.renderer.gaussians.save_deformation(self.opt.outdir, self.opt.save_path) + + print(f"[INFO] save model to {path}.") + + # no gui mode + def train(self, iters=500, ui=False): + if self.gui: + from visualizer.visergui import ViserViewer + self.viser_gui = ViserViewer(device="cuda", viewer_port=8080) + if iters > 0: + self.prepare_train() + if self.gui: + self.viser_gui.set_renderer(self.renderer, self.fixed_cam) + + for i in tqdm.trange(iters): + self.train_step() + if self.gui: + self.viser_gui.update() + if self.opt.mesh_format == 'frames': + self.save_model(mode='frames', interp=4) + elif self.opt.mesh_format == 'obj': + self.save_model(mode='geo+tex') + + if self.opt.save_model: + self.save_model(mode='model') + + # render eval + image_list =[] + nframes = self.opt.batch_size * 7 + 15 * 7 + hor = 180 + delta_hor = 45 / 15 + delta_time = 1 + for i in range(8): + time = 0 + for j in range(self.opt.batch_size + 15): + pose = orbit_camera(self.opt.elevation, hor-180, self.opt.radius) + cur_cam = MiniCam( + pose, + 512, + 512, + self.cam.fovy, + self.cam.fovx, + self.cam.near, + self.cam.far, + time=time + ) + with torch.no_grad(): + outputs = self.renderer.render(cur_cam) + + out = outputs["image"].cpu().detach().numpy().astype(np.float32) + out = np.transpose(out, (1, 2, 0)) + out = np.uint8(out*255) + image_list.append(out) + + time = (time + delta_time) % self.opt.batch_size + if j >= self.opt.batch_size: + hor = (hor+delta_hor) % 360 + + + imageio.mimwrite(f'vis_data/{opt.save_path}.mp4', image_list, fps=7) + + if self.gui: + while True: + self.viser_gui.update() + +if __name__ == "__main__": + import argparse + from omegaconf import OmegaConf + + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="path to the yaml config file") + args, extras = parser.parse_known_args() + + # override default config from cli + opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras)) + opt.save_path = os.path.splitext(os.path.basename(opt.input))[0] if opt.save_path == '' else opt.save_path + + + # auto find mesh from stage 1 + opt.load = os.path.join(opt.outdir, opt.save_path + '_model.ply') + + gui = GUI(opt) + + gui.train(opt.iters) + + +# python main_4d.py --config configs/4d_low.yaml input=data/CONSISTENT4D_DATA/in-the-wild/blooming_rose \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b5c573c6571515fe12657c0b357ee33759b7bcf4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +tqdm +rich +ninja +numpy +pandas +scipy +scikit-learn +matplotlib +opencv-python +imageio +imageio-ffmpeg +omegaconf + +torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118 +torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 +torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 +xformer --index-url https://download.pytorch.org/whl/cu118 --no-deps +einops +plyfile +pygltflib +torchvision + +# for stable-diffusion +huggingface_hub +diffusers +accelerate +transformers + +rembg[gpu,cli] + +# gradio demo +gradio +gradio-model4dgs + +-e git+https://github.com/ashawkey/kiuikit.git@main#egg=kiui \ No newline at end of file diff --git a/scene/deformation.py b/scene/deformation.py new file mode 100644 index 0000000000000000000000000000000000000000..972738e328fa9398fe2d7878aac2f59dea19aa4a --- /dev/null +++ b/scene/deformation.py @@ -0,0 +1,241 @@ +import functools +import math +import os +import time +from tkinter import W + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.cpp_extension import load +import torch.nn.init as init +from scene.hexplane import HexPlaneField + + +class Linear_Res(nn.Module): + def __init__(self, W): + super(Linear_Res, self).__init__() + self.main_stream = nn.Linear(W, W) + + def forward(self, x): + x = F.relu(x) + return x + self.main_stream(x) + + +class Head_Res_Net(nn.Module): + def __init__(self, W, H): + super(Head_Res_Net, self).__init__() + self.W = W + self.H = H + + self.feature_out = [Linear_Res(self.W)] + self.feature_out.append(nn.Linear(W, self.H)) + self.feature_out = nn.Sequential(*self.feature_out) + + def initialize_weights(self,): + for m in self.feature_out.modules(): + if isinstance(m, nn.Linear): + init.constant_(m.weight, 0) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, x): + return self.feature_out(x) + + + +class Deformation(nn.Module): + def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None, use_res=False): + super(Deformation, self).__init__() + self.D = D + self.W = W + self.input_ch = input_ch + self.input_ch_time = input_ch_time + self.skips = skips + + self.no_grid = args.no_grid + self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires) + + self.use_res = use_res + if not self.use_res: + self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net() + else: + self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_res_net() + self.args = args + + def create_net(self): + + mlp_out_dim = 0 + if self.no_grid: + self.feature_out = [nn.Linear(4,self.W)] + else: + self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)] + + for i in range(self.D-1): + self.feature_out.append(nn.ReLU()) + self.feature_out.append(nn.Linear(self.W,self.W)) + self.feature_out = nn.Sequential(*self.feature_out) + output_dim = self.W + return \ + nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ + nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ + nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \ + nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1)) + + def create_res_net(self,): + + mlp_out_dim = 0 + + if self.no_grid: + self.feature_out = [nn.Linear(4,self.W)] + else: + self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)] + + for i in range(self.D-1): + self.feature_out.append(nn.ReLU()) + self.feature_out.append(nn.Linear(self.W,self.W)) + self.feature_out = nn.Sequential(*self.feature_out) + + output_dim = self.W + return \ + Head_Res_Net(self.W, 3), \ + Head_Res_Net(self.W, 3), \ + Head_Res_Net(self.W, 4), \ + Head_Res_Net(self.W, 1) + + + def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb): + if self.args.no_mlp: + assert not self.no_grid + grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) + h = grid_feature + elif not self.use_res: + if self.no_grid: + h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) + else: + grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) + + h = grid_feature + + h = self.feature_out(h) + else: + if self.no_grid: + h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) + h = self.feature_out(h) + else: + grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) + h = self.feature_out(grid_feature) + return h + + def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None): + if time_emb is None: + return self.forward_static(rays_pts_emb[:,:3]) + else: + return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb) + + def forward_static(self, rays_pts_emb): + grid_feature = self.grid(rays_pts_emb[:,:3]) + dx = self.static_mlp(grid_feature) + return rays_pts_emb[:, :3] + dx + + def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb): + hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float() + if self.args.no_mlp: + return hidden[:, :3], hidden[:, 3:6], hidden[:, 6:10], hidden[:, 10:11] + dx = self.pos_deform(hidden) + pts = dx + if self.args.no_ds: + scales = scales_emb[:,:3] + else: + ds = self.scales_deform(hidden) + scales = ds + if self.args.no_dr: + rotations = rotations_emb[:,:4] + else: + dr = self.rotations_deform(hidden) + rotations = dr + if self.args.no_do: + opacity = opacity_emb[:,:1] + else: + do = self.opacity_deform(hidden) + opacity = do + + return pts, scales, rotations, opacity + def get_mlp_parameters(self): + parameter_list = [] + for name, param in self.named_parameters(): + if "grid" not in name: + parameter_list.append(param) + return parameter_list + def get_grid_parameters(self): + return list(self.grid.parameters() ) + + +class deform_network(nn.Module): + def __init__(self, args) : + super(deform_network, self).__init__() + net_width = args.net_width + timebase_pe = args.timebase_pe + defor_depth= args.defor_depth + posbase_pe= args.posebase_pe + scale_rotation_pe = args.scale_rotation_pe + opacity_pe = args.opacity_pe + timenet_width = args.timenet_width + timenet_output = args.timenet_output + times_ch = 2*timebase_pe+1 + self.timenet = nn.Sequential( + nn.Linear(times_ch, timenet_width), nn.ReLU(), + nn.Linear(timenet_width, timenet_output)) + + self.use_res = args.use_res + if self.use_res: + print("Using zero-init and residual") + self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=args, use_res=self.use_res) + self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)])) + self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)])) + self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)])) + self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)])) + self.apply(initialize_weights) + + if self.use_res: + self.deformation_net.pos_deform.initialize_weights() + self.deformation_net.scales_deform.initialize_weights() + self.deformation_net.rotations_deform.initialize_weights() + self.deformation_net.opacity_deform.initialize_weights() + + + def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None): + if times_sel is not None: + return self.forward_dynamic(point, scales, rotations, opacity, times_sel) + else: + return self.forward_static(point) + + + def forward_static(self, points): + points = self.deformation_net(points) + return points + def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None): + means3D, scales, rotations, opacity = self.deformation_net( point, + scales, + rotations, + opacity, + times_sel) + return means3D, scales, rotations, opacity + def get_mlp_parameters(self): + return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters()) + def get_grid_parameters(self): + return self.deformation_net.get_grid_parameters() + + +def initialize_weights(m): + if isinstance(m, nn.Linear): + init.xavier_uniform_(m.weight,gain=1) + if m.bias is not None: + init.xavier_uniform_(m.weight,gain=1) + +def initialize_zeros_weights(m): + if isinstance(m, nn.Linear): + init.constant_(m.weight, 0) + if m.bias is not None: + init.constant_(m.bias, 0) diff --git a/scene/hexplane.py b/scene/hexplane.py new file mode 100644 index 0000000000000000000000000000000000000000..82d44f4bac1ea8b608ab031bcb5f7d0511a5a993 --- /dev/null +++ b/scene/hexplane.py @@ -0,0 +1,182 @@ +import itertools +import logging as log +from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_normalized_directions(directions): + """SH encoding must be in the range [0, 1] + + Args: + directions: batch of directions + """ + return (directions + 1.0) / 2.0 + + +def normalize_aabb(pts, aabb): + return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 +def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: + grid_dim = coords.shape[-1] + + if grid.dim() == grid_dim + 1: + # no batch dimension present, need to add it + grid = grid.unsqueeze(0) + if coords.dim() == 2: + coords = coords.unsqueeze(0) + + if grid_dim == 2 or grid_dim == 3: + grid_sampler = F.grid_sample + else: + raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " + f"implemented for 2 and 3D data.") + + coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) + B, feature_dim = grid.shape[:2] + n = coords.shape[-2] + interp = grid_sampler( + grid, # [B, feature_dim, reso, ...] + coords, # [B, 1, ..., n, grid_dim] + align_corners=align_corners, + mode='bilinear', padding_mode='border') + interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] + interp = interp.squeeze() # [B?, n, feature_dim?] + return interp + +def init_grid_param( + grid_nd: int, + in_dim: int, + out_dim: int, + reso: Sequence[int], + a: float = 0.1, + b: float = 0.5): + assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" + has_time_planes = in_dim == 4 + assert grid_nd <= in_dim + coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) + grid_coefs = nn.ParameterList() + for ci, coo_comb in enumerate(coo_combs): + new_grid_coef = nn.Parameter(torch.empty( + [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] + )) + if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 + nn.init.ones_(new_grid_coef) + else: + nn.init.uniform_(new_grid_coef, a=a, b=b) + grid_coefs.append(new_grid_coef) + + return grid_coefs + + +def interpolate_ms_features(pts: torch.Tensor, + ms_grids: Collection[Iterable[nn.Module]], + grid_dimensions: int, + concat_features: bool, + num_levels: Optional[int], + ) -> torch.Tensor: + coo_combs = list(itertools.combinations( + range(pts.shape[-1]), grid_dimensions) + ) + if num_levels is None: + num_levels = len(ms_grids) + multi_scale_interp = [] if concat_features else 0. + grid: nn.ParameterList + for scale_id, grid in enumerate(ms_grids[:num_levels]): + interp_space = 1. + for ci, coo_comb in enumerate(coo_combs): + # interpolate in plane + feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso + interp_out_plane = ( + grid_sample_wrapper(grid[ci], pts[..., coo_comb]) + .view(-1, feature_dim) + ) + # compute product over planes + interp_space = interp_space * interp_out_plane + + # combine over scales + if concat_features: + multi_scale_interp.append(interp_space) + else: + multi_scale_interp = multi_scale_interp + interp_space + + if concat_features: + multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) + return multi_scale_interp + + +class HexPlaneField(nn.Module): + def __init__( + self, + + bounds, + planeconfig, + multires + ) -> None: + super().__init__() + aabb = torch.tensor([[bounds,bounds,bounds], + [-bounds,-bounds,-bounds]]) + self.aabb = nn.Parameter(aabb, requires_grad=False) + self.grid_config = [planeconfig] + self.multiscale_res_multipliers = multires + self.concat_features = True + + # 1. Init planes + self.grids = nn.ModuleList() + self.feat_dim = 0 + for res in self.multiscale_res_multipliers: + # initialize coordinate grid + config = self.grid_config[0].copy() + # Resolution fix: multi-res only on spatial planes + config["resolution"] = [ + r * res for r in config["resolution"][:3] + ] + config["resolution"][3:] + gp = init_grid_param( + grid_nd=config["grid_dimensions"], + in_dim=config["input_coordinate_dim"], + out_dim=config["output_coordinate_dim"], + reso=config["resolution"], + ) + # shape[1] is out-dim - Concatenate over feature len for each scale + if self.concat_features: + self.feat_dim += gp[-1].shape[1] + else: + self.feat_dim = gp[-1].shape[1] + self.grids.append(gp) + # print(f"Initialized model grids: {self.grids}") + print("feature_dim:",self.feat_dim) + + + def set_aabb(self,xyz_max, xyz_min): + aabb = torch.tensor([ + xyz_max, + xyz_min + ]) + self.aabb = nn.Parameter(aabb,requires_grad=True) + print("Voxel Plane: set aabb=",self.aabb) + + def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): + """Computes and returns the densities.""" + + pts = normalize_aabb(pts, self.aabb) + pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] + + pts = pts.reshape(-1, pts.shape[-1]) + features = interpolate_ms_features( + pts, ms_grids=self.grids, # noqa + grid_dimensions=self.grid_config[0]["grid_dimensions"], + concat_features=self.concat_features, num_levels=None) + if len(features) < 1: + features = torch.zeros((0, 1)).to(features.device) + + + return features + + def forward(self, + pts: torch.Tensor, + timestamps: Optional[torch.Tensor] = None): + + features = self.get_density(pts, timestamps) + + return features diff --git a/scene/regulation.py b/scene/regulation.py new file mode 100644 index 0000000000000000000000000000000000000000..80583a32519ca929cd70e1aa360e85443f561139 --- /dev/null +++ b/scene/regulation.py @@ -0,0 +1,176 @@ +import abc +import os +from typing import Sequence + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.optim.lr_scheduler +from torch import nn + + + +def compute_plane_tv(t): + batch_size, c, h, w = t.shape + count_h = batch_size * c * (h - 1) * w + count_w = batch_size * c * h * (w - 1) + h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() + w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() + return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg + + +def compute_plane_smoothness(t): + batch_size, c, h, w = t.shape + # Convolve with a second derivative filter, in the time dimension which is dimension 2 + first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] + second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] + # Take the L2 norm of the result + return torch.square(second_difference).mean() + + +class Regularizer(): + def __init__(self, reg_type, initialization): + self.reg_type = reg_type + self.initialization = initialization + self.weight = float(self.initialization) + self.last_reg = None + + def step(self, global_step): + pass + + def report(self, d): + if self.last_reg is not None: + d[self.reg_type].update(self.last_reg.item()) + + def regularize(self, *args, **kwargs) -> torch.Tensor: + out = self._regularize(*args, **kwargs) * self.weight + self.last_reg = out.detach() + return out + + @abc.abstractmethod + def _regularize(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError() + + def __str__(self): + return f"Regularizer({self.reg_type}, weight={self.weight})" + + +class PlaneTV(Regularizer): + def __init__(self, initial_value, what: str = 'field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + name = f'planeTV-{what[:2]}' + super().__init__(name, initial_value) + self.what = what + + def step(self, global_step): + pass + + def _regularize(self, model, **kwargs): + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + total = 0 + # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] + for grids in multi_res_grids: + if len(grids) == 3: + spatial_grids = [0, 1, 2] + else: + spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal + for grid_id in spatial_grids: + total += compute_plane_tv(grids[grid_id]) + for grid in grids: + # grid: [1, c, h, w] + total += compute_plane_tv(grid) + return total + + +class TimeSmoothness(Regularizer): + def __init__(self, initial_value, what: str = 'field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + name = f'time-smooth-{what[:2]}' + super().__init__(name, initial_value) + self.what = what + + def _regularize(self, model, **kwargs) -> torch.Tensor: + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids = [2, 4, 5] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return torch.as_tensor(total) + + + +class L1ProposalNetwork(Regularizer): + def __init__(self, initial_value): + super().__init__('l1-proposal-network', initial_value) + + def _regularize(self, model, **kwargs) -> torch.Tensor: + grids = [p.grids for p in model.proposal_networks] + total = 0.0 + for pn_grids in grids: + for grid in pn_grids: + total += torch.abs(grid).mean() + return torch.as_tensor(total) + + +class DepthTV(Regularizer): + def __init__(self, initial_value): + super().__init__('tv-depth', initial_value) + + def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: + depth = model_out['depth'] + tv = compute_plane_tv( + depth.reshape(64, 64)[None, None, :, :] + ) + return tv + + +class L1TimePlanes(Regularizer): + def __init__(self, initial_value, what='field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + super().__init__(f'l1-time-{what[:2]}', initial_value) + self.what = what + + def _regularize(self, model, **kwargs) -> torch.Tensor: + # model.grids is 6 x [1, rank * F_dim, reso, reso] + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + + total = 0.0 + for grids in multi_res_grids: + if len(grids) == 3: + continue + else: + # These are the spatiotemporal grids + spatiotemporal_grids = [2, 4, 5] + for grid_id in spatiotemporal_grids: + total += torch.abs(1 - grids[grid_id]).mean() + return torch.as_tensor(total) + diff --git a/scene/utils.py b/scene/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d6edf7e576f140a3c8f230c6d8271d278dc2151b --- /dev/null +++ b/scene/utils.py @@ -0,0 +1,429 @@ +import copy +import json +import math +import os +import pathlib +from typing import Any, Callable, List, Optional, Text, Tuple, Union + +import numpy as np +import scipy.signal +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +PRNGKey = Any +Shape = Tuple[int] +Dtype = Any # this could be a real type? +Array = Any +Activation = Callable[[Array], Array] +Initializer = Callable[[PRNGKey, Shape, Dtype], Array] +Normalizer = Callable[[], Callable[[Array], Array]] +PathType = Union[Text, pathlib.PurePosixPath] + +from pathlib import PurePosixPath as GPath + + +def _compute_residual_and_jacobian( + x: np.ndarray, + y: np.ndarray, + xd: np.ndarray, + yd: np.ndarray, + k1: float = 0.0, + k2: float = 0.0, + k3: float = 0.0, + p1: float = 0.0, + p2: float = 0.0, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, + np.ndarray]: + """Auxiliary function of radial_and_tangential_undistort().""" + + r = x * x + y * y + d = 1.0 + r * (k1 + r * (k2 + k3 * r)) + + fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd + fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd + + # Compute derivative of d over [x, y] + d_r = (k1 + r * (2.0 * k2 + 3.0 * k3 * r)) + d_x = 2.0 * x * d_r + d_y = 2.0 * y * d_r + + # Compute derivative of fx over x and y. + fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x + fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y + + # Compute derivative of fy over x and y. + fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x + fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y + + return fx, fy, fx_x, fx_y, fy_x, fy_y + + +def _radial_and_tangential_undistort( + xd: np.ndarray, + yd: np.ndarray, + k1: float = 0, + k2: float = 0, + k3: float = 0, + p1: float = 0, + p2: float = 0, + eps: float = 1e-9, + max_iterations=10) -> Tuple[np.ndarray, np.ndarray]: + """Computes undistorted (x, y) from (xd, yd).""" + # Initialize from the distorted point. + x = xd.copy() + y = yd.copy() + + for _ in range(max_iterations): + fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( + x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2) + denominator = fy_x * fx_y - fx_x * fy_y + x_numerator = fx * fy_y - fy * fx_y + y_numerator = fy * fx_x - fx * fy_x + step_x = np.where( + np.abs(denominator) > eps, x_numerator / denominator, + np.zeros_like(denominator)) + step_y = np.where( + np.abs(denominator) > eps, y_numerator / denominator, + np.zeros_like(denominator)) + + x = x + step_x + y = y + step_y + + return x, y + + +class Camera: + """Class to handle camera geometry.""" + + def __init__(self, + orientation: np.ndarray, + position: np.ndarray, + focal_length: Union[np.ndarray, float], + principal_point: np.ndarray, + image_size: np.ndarray, + skew: Union[np.ndarray, float] = 0.0, + pixel_aspect_ratio: Union[np.ndarray, float] = 1.0, + radial_distortion: Optional[np.ndarray] = None, + tangential_distortion: Optional[np.ndarray] = None, + dtype=np.float32): + """Constructor for camera class.""" + if radial_distortion is None: + radial_distortion = np.array([0.0, 0.0, 0.0], dtype) + if tangential_distortion is None: + tangential_distortion = np.array([0.0, 0.0], dtype) + + self.orientation = np.array(orientation, dtype) + self.position = np.array(position, dtype) + self.focal_length = np.array(focal_length, dtype) + self.principal_point = np.array(principal_point, dtype) + self.skew = np.array(skew, dtype) + self.pixel_aspect_ratio = np.array(pixel_aspect_ratio, dtype) + self.radial_distortion = np.array(radial_distortion, dtype) + self.tangential_distortion = np.array(tangential_distortion, dtype) + self.image_size = np.array(image_size, np.uint32) + self.dtype = dtype + + @classmethod + def from_json(cls, path: PathType): + """Loads a JSON camera into memory.""" + path = GPath(path) + # with path.open('r') as fp: + with open(path, 'r') as fp: + camera_json = json.load(fp) + + # Fix old camera JSON. + if 'tangential' in camera_json: + camera_json['tangential_distortion'] = camera_json['tangential'] + + return cls( + orientation=np.asarray(camera_json['orientation']), + position=np.asarray(camera_json['position']), + focal_length=camera_json['focal_length'], + principal_point=np.asarray(camera_json['principal_point']), + skew=camera_json['skew'], + pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], + radial_distortion=np.asarray(camera_json['radial_distortion']), + tangential_distortion=np.asarray(camera_json['tangential_distortion']), + image_size=np.asarray(camera_json['image_size']), + ) + + def to_json(self): + return { + k: (v.tolist() if hasattr(v, 'tolist') else v) + for k, v in self.get_parameters().items() + } + + def get_parameters(self): + return { + 'orientation': self.orientation, + 'position': self.position, + 'focal_length': self.focal_length, + 'principal_point': self.principal_point, + 'skew': self.skew, + 'pixel_aspect_ratio': self.pixel_aspect_ratio, + 'radial_distortion': self.radial_distortion, + 'tangential_distortion': self.tangential_distortion, + 'image_size': self.image_size, + } + + @property + def scale_factor_x(self): + return self.focal_length + + @property + def scale_factor_y(self): + return self.focal_length * self.pixel_aspect_ratio + + @property + def principal_point_x(self): + return self.principal_point[0] + + @property + def principal_point_y(self): + return self.principal_point[1] + + @property + def has_tangential_distortion(self): + return any(self.tangential_distortion != 0.0) + + @property + def has_radial_distortion(self): + return any(self.radial_distortion != 0.0) + + @property + def image_size_y(self): + return self.image_size[1] + + @property + def image_size_x(self): + return self.image_size[0] + + @property + def image_shape(self): + return self.image_size_y, self.image_size_x + + @property + def optical_axis(self): + return self.orientation[2, :] + + @property + def translation(self): + return -np.matmul(self.orientation, self.position) + + def pixel_to_local_rays(self, pixels: np.ndarray): + """Returns the local ray directions for the provided pixels.""" + y = ((pixels[..., 1] - self.principal_point_y) / self.scale_factor_y) + x = ((pixels[..., 0] - self.principal_point_x - y * self.skew) / + self.scale_factor_x) + + if self.has_radial_distortion or self.has_tangential_distortion: + x, y = _radial_and_tangential_undistort( + x, + y, + k1=self.radial_distortion[0], + k2=self.radial_distortion[1], + k3=self.radial_distortion[2], + p1=self.tangential_distortion[0], + p2=self.tangential_distortion[1]) + + dirs = np.stack([x, y, np.ones_like(x)], axis=-1) + return dirs / np.linalg.norm(dirs, axis=-1, keepdims=True) + + def pixels_to_rays(self, pixels: np.ndarray) -> np.ndarray: + """Returns the rays for the provided pixels. + + Args: + pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions. + + Returns: + An array containing the normalized ray directions in world coordinates. + """ + if pixels.shape[-1] != 2: + raise ValueError('The last dimension of pixels must be 2.') + if pixels.dtype != self.dtype: + raise ValueError(f'pixels dtype ({pixels.dtype!r}) must match camera ' + f'dtype ({self.dtype!r})') + + batch_shape = pixels.shape[:-1] + pixels = np.reshape(pixels, (-1, 2)) + + local_rays_dir = self.pixel_to_local_rays(pixels) + rays_dir = np.matmul(self.orientation.T, local_rays_dir[..., np.newaxis]) + rays_dir = np.squeeze(rays_dir, axis=-1) + + # Normalize rays. + rays_dir /= np.linalg.norm(rays_dir, axis=-1, keepdims=True) + rays_dir = rays_dir.reshape((*batch_shape, 3)) + return rays_dir + + def pixels_to_points(self, pixels: np.ndarray, depth: np.ndarray): + rays_through_pixels = self.pixels_to_rays(pixels) + cosa = np.matmul(rays_through_pixels, self.optical_axis) + points = ( + rays_through_pixels * depth[..., np.newaxis] / cosa[..., np.newaxis] + + self.position) + return points + + def points_to_local_points(self, points: np.ndarray): + translated_points = points - self.position + local_points = (np.matmul(self.orientation, translated_points.T)).T + return local_points + + def project(self, points: np.ndarray): + """Projects a 3D point (x,y,z) to a pixel position (x,y).""" + batch_shape = points.shape[:-1] + points = points.reshape((-1, 3)) + local_points = self.points_to_local_points(points) + + # Get normalized local pixel positions. + x = local_points[..., 0] / local_points[..., 2] + y = local_points[..., 1] / local_points[..., 2] + r2 = x**2 + y**2 + + # Apply radial distortion. + distortion = 1.0 + r2 * ( + self.radial_distortion[0] + r2 * + (self.radial_distortion[1] + self.radial_distortion[2] * r2)) + + # Apply tangential distortion. + x_times_y = x * y + x = ( + x * distortion + 2.0 * self.tangential_distortion[0] * x_times_y + + self.tangential_distortion[1] * (r2 + 2.0 * x**2)) + y = ( + y * distortion + 2.0 * self.tangential_distortion[1] * x_times_y + + self.tangential_distortion[0] * (r2 + 2.0 * y**2)) + + # Map the distorted ray to the image plane and return the depth. + pixel_x = self.focal_length * x + self.skew * y + self.principal_point_x + pixel_y = (self.focal_length * self.pixel_aspect_ratio * y + + self.principal_point_y) + + pixels = np.stack([pixel_x, pixel_y], axis=-1) + return pixels.reshape((*batch_shape, 2)) + + def get_pixel_centers(self): + """Returns the pixel centers.""" + xx, yy = np.meshgrid(np.arange(self.image_size_x, dtype=self.dtype), + np.arange(self.image_size_y, dtype=self.dtype)) + return np.stack([xx, yy], axis=-1) + 0.5 + + def scale(self, scale: float): + """Scales the camera.""" + if scale <= 0: + raise ValueError('scale needs to be positive.') + + new_camera = Camera( + orientation=self.orientation.copy(), + position=self.position.copy(), + focal_length=self.focal_length * scale, + principal_point=self.principal_point.copy() * scale, + skew=self.skew, + pixel_aspect_ratio=self.pixel_aspect_ratio, + radial_distortion=self.radial_distortion.copy(), + tangential_distortion=self.tangential_distortion.copy(), + image_size=np.array((int(round(self.image_size[0] * scale)), + int(round(self.image_size[1] * scale)))), + ) + return new_camera + + def look_at(self, position, look_at, up, eps=1e-6): + """Creates a copy of the camera which looks at a given point. + + Copies the provided vision_sfm camera and returns a new camera that is + positioned at `camera_position` while looking at `look_at_position`. + Camera intrinsics are copied by this method. A common value for the + up_vector is (0, 1, 0). + + Args: + position: A (3,) numpy array representing the position of the camera. + look_at: A (3,) numpy array representing the location the camera + looks at. + up: A (3,) numpy array representing the up direction, whose + projection is parallel to the y-axis of the image plane. + eps: a small number to prevent divides by zero. + + Returns: + A new camera that is copied from the original but is positioned and + looks at the provided coordinates. + + Raises: + ValueError: If the camera position and look at position are very close + to each other or if the up-vector is parallel to the requested optical + axis. + """ + + look_at_camera = self.copy() + optical_axis = look_at - position + norm = np.linalg.norm(optical_axis) + if norm < eps: + raise ValueError('The camera center and look at position are too close.') + optical_axis /= norm + + right_vector = np.cross(optical_axis, up) + norm = np.linalg.norm(right_vector) + if norm < eps: + raise ValueError('The up-vector is parallel to the optical axis.') + right_vector /= norm + + # The three directions here are orthogonal to each other and form a right + # handed coordinate system. + camera_rotation = np.identity(3) + camera_rotation[0, :] = right_vector + camera_rotation[1, :] = np.cross(optical_axis, right_vector) + camera_rotation[2, :] = optical_axis + + look_at_camera.position = position + look_at_camera.orientation = camera_rotation + return look_at_camera + + def crop_image_domain( + self, left: int = 0, right: int = 0, top: int = 0, bottom: int = 0): + """Returns a copy of the camera with adjusted image bounds. + + Args: + left: number of pixels by which to reduce (or augment, if negative) the + image domain at the associated boundary. + right: likewise. + top: likewise. + bottom: likewise. + + The crop parameters may not cause the camera image domain dimensions to + become non-positive. + + Returns: + A camera with adjusted image dimensions. The focal length is unchanged, + and the principal point is updated to preserve the original principal + axis. + """ + + crop_left_top = np.array([left, top]) + crop_right_bottom = np.array([right, bottom]) + new_resolution = self.image_size - crop_left_top - crop_right_bottom + new_principal_point = self.principal_point - crop_left_top + if np.any(new_resolution <= 0): + raise ValueError('Crop would result in non-positive image dimensions.') + + new_camera = self.copy() + new_camera.image_size = np.array([int(new_resolution[0]), + int(new_resolution[1])]) + new_camera.principal_point = np.array([new_principal_point[0], + new_principal_point[1]]) + return new_camera + + def copy(self): + return copy.deepcopy(self) + + +''' Misc +''' +mse2psnr = lambda x : -10. * torch.log10(x) +to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) + + + +''' Checkpoint utils +''' \ No newline at end of file diff --git a/scripts/add_bg_to_gt.py b/scripts/add_bg_to_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..16cf1400ca8d75ea145a1cc9b44aac92c1ddaead --- /dev/null +++ b/scripts/add_bg_to_gt.py @@ -0,0 +1,18 @@ +import os +import cv2 + +os.makedirs('data/CONSISTENT4D_DATA/test_dataset/eval_gt_rgb', exist_ok=True) +file_list = [] +for img_name in ['aurorus', 'crocodile', 'guppie', 'monster', 'pistol', 'skull', 'trump']: + os.makedirs(f'data/CONSISTENT4D_DATA/test_dataset/eval_gt_rgb/{img_name}', exist_ok=True) + for view in range(4): + os.makedirs(f'datdata/CONSISTENT4D_DATAa/test_dataset/eval_gt_rgb/{img_name}/eval_{view}', exist_ok=True) + for t in range(32): + file_list.append(f'data/CONSISTENT4D_DATA/test_dataset/eval_gt/{img_name}/eval_{view}/{t}.png') +for file in file_list: + img = cv2.imread(file, cv2.IMREAD_UNCHANGED) + input_mask = img[..., 3:] + input_mask = input_mask / 255. + input_img = img[..., :3] * input_mask + (1 - input_mask) * 255 + fpath = file.replace('eval_gt', 'eval_gt_rgb') + cv2.imwrite(fpath, input_img) \ No newline at end of file diff --git a/scripts/convert_obj_to_video.py b/scripts/convert_obj_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..cf26559e918713cae01a98bb43b39366759e3f29 --- /dev/null +++ b/scripts/convert_obj_to_video.py @@ -0,0 +1,20 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--dir', default='logs', type=str, help='Directory where obj files are stored') +parser.add_argument('--out', default='videos', type=str, help='Directory where videos will be saved') +args = parser.parse_args() + +out = args.out +os.makedirs(out, exist_ok=True) + +files = glob.glob(f'{args.dir}/*.obj') +for f in files: + name = os.path.basename(f) + # first stage model, ignore + if name.endswith('_mesh.obj'): + continue + print(f'[INFO] process {name}') + os.system(f"python -m kiui.render {f} --save_video {os.path.join(out, name.replace('.obj', '.mp4'))} ") \ No newline at end of file diff --git a/scripts/gen_vid.py b/scripts/gen_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..b065990136e038cb9dc76bec3098dd50be119e0d --- /dev/null +++ b/scripts/gen_vid.py @@ -0,0 +1,117 @@ +import torch + +from diffusers import StableVideoDiffusionPipeline + +from PIL import Image +import numpy as np + +import cv2 +import rembg + +import argparse +import imageio +import os + +def add_margin(pil_img, top, right, bottom, left, color): + width, height = pil_img.size + new_width = width + right + left + new_height = height + top + bottom + result = Image.new(pil_img.mode, (new_width, new_height), color) + result.paste(pil_img, (left, top)) + return result + +def resize_image(image, output_size=(1024, 576)): + image = image.resize((output_size[1],output_size[1])) + pad_size = (output_size[0]-output_size[1]) //2 + image = add_margin(image, 0, pad_size, 0, pad_size, tuple(np.array(image)[0,0])) + return image + + +def load_image(file, W, H, bg='white'): + # load image + print(f'[INFO] load image from {file}...') + img = cv2.imread(file, cv2.IMREAD_UNCHANGED) + bg_remover = rembg.new_session() + img = rembg.remove(img, session=bg_remover) + img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + img = img.astype(np.float32) / 255.0 + input_mask = img[..., 3:] + # white bg + if bg == 'white': + input_img = img[..., :3] * input_mask + (1 - input_mask) + elif bg == 'black': + input_img = img[..., :3] + else: + raise NotImplementedError + # bgr to rgb + input_img = input_img[..., ::-1].copy() + input_img = Image.fromarray(np.uint8(input_img*255)) + return input_img + +def load_image_w_bg(file, W, H): + # load image + print(f'[INFO] load image from {file}...') + img = cv2.imread(file, cv2.IMREAD_UNCHANGED) + img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + img = img.astype(np.float32) / 255.0 + input_img = img[..., :3] + # bgr to rgb + input_img = input_img[..., ::-1].copy() + input_img = Image.fromarray(np.uint8(input_img*255)) + return input_img + +def gen_vid(input_path, seed, bg, is_pad): + name = input_path.split('/')[-1].split('.')[0] + input_dir = os.path.dirname(input_path) + pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16" + ) + # pipe.enable_model_cpu_offload() + pipe.to("cuda") + + if is_pad: + height, width = 576, 1024 + else: + height, width = 512, 512 + + if seed is None: + for bg in ['white', 'black', 'orig']: + if bg == 'orig': + if 'rgba' in name: + continue + image = load_image_w_bg(input_path, width, height) + else: + image = load_image(input_path, width, height, bg) + if is_pad: + image = resize_image(image, output_size=(width, height)) + for seed in range(20): + generator = torch.manual_seed(seed) + frames = pipe(image, height, width, generator=generator).frames[0] + imageio.mimwrite(f"{input_dir}/videos/{name}_{bg}_{seed:03}.mp4", frames, fps=7) + else: + if bg == 'orig': + if 'rgba' in name: + raise ValueError + image = load_image_w_bg(input_path, width, height) + else: + image = load_image(input_path, width, height, bg) + if is_pad: + image = resize_image(image, output_size=(width, height)) + generator = torch.manual_seed(seed) + frames = pipe(image, height, width, generator=generator).frames[0] + + imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7) + os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True) + for idx, img in enumerate(frames): + if is_pad: + img = img.crop(((width-height) //2, 0, width - (width-height) //2, height)) + img.save(f"{input_dir}/{name}_frames/{idx:03}.png") + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, required=True) + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--bg", type=str, default='white') + parser.add_argument("--is_pad", type=bool, default=False) + args, extras = parser.parse_known_args() + gen_vid(args.path, args.seed, args.bg, args.is_pad) diff --git a/scripts/process.py b/scripts/process.py new file mode 100644 index 0000000000000000000000000000000000000000..e867db93bf7c7ab76a347f3e45522f70d47b0fed --- /dev/null +++ b/scripts/process.py @@ -0,0 +1,92 @@ +import os +import glob +import sys +import cv2 +import argparse +import numpy as np +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image +import rembg + +class BLIP2(): + def __init__(self, device='cuda'): + self.device = device + from transformers import AutoProcessor, Blip2ForConditionalGeneration + self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device) + + @torch.no_grad() + def __call__(self, image): + image = Image.fromarray(image) + inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) + + generated_ids = self.model.generate(**inputs, max_new_tokens=20) + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + + return generated_text + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)") + parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models") + parser.add_argument('--size', default=256, type=int, help="output resolution") + parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio") + parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123") + opt = parser.parse_args() + + session = rembg.new_session(model_name=opt.model) + + if os.path.isdir(opt.path): + print(f'[INFO] processing directory {opt.path}...') + files = glob.glob(f'{opt.path}/*') + out_dir = opt.path + else: # isfile + files = [opt.path] + out_dir = os.path.dirname(opt.path) + + for file in files: + + out_base = os.path.basename(file).split('.')[0] + out_rgba = os.path.join(out_dir, out_base + '_rgba.png') + + # load image + print(f'[INFO] loading image {file}...') + image = cv2.imread(file, cv2.IMREAD_UNCHANGED) + + # carve background + print(f'[INFO] background removal...') + carved_image = rembg.remove(image, session=session) # [H, W, 4] + mask = carved_image[..., -1] > 0 + + # recenter + if opt.recenter: + print(f'[INFO] recenter...') + final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) + + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() + h = x_max - x_min + w = y_max - y_min + desired_size = int(opt.size * (1 - opt.border_ratio)) + scale = desired_size / max(h, w) + h2 = int(h * scale) + w2 = int(w * scale) + x2_min = (opt.size - h2) // 2 + x2_max = x2_min + h2 + y2_min = (opt.size - w2) // 2 + y2_max = y2_min + w2 + final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + + else: + final_rgba = carved_image + + # write image + cv2.imwrite(out_rgba, final_rgba) \ No newline at end of file diff --git a/scripts/runall.py b/scripts/runall.py new file mode 100644 index 0000000000000000000000000000000000000000..0e40bd34663abf6c3d02ab802bd60681d537cef1 --- /dev/null +++ b/scripts/runall.py @@ -0,0 +1,48 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--dir', default='data', type=str, help='Directory where processed images are stored') +parser.add_argument('--out', default='logs', type=str, help='Directory where obj files will be saved') +parser.add_argument('--video-out', default='videos', type=str, help='Directory where videos will be saved') +parser.add_argument('--gpu', default=0, type=int, help='ID of GPU to use') +parser.add_argument('--elevation', default=0, type=int, help='Elevation angle of view in degrees') +parser.add_argument('--config', default='configs', type=str, help='Path to config directory, which contains image.yaml') +args = parser.parse_args() + +files = glob.glob(f'{args.dir}/*_rgba.png') +configs_dir = args.config + +# check if image.yaml exists +if not os.path.exists(os.path.join(configs_dir, 'image.yaml')): + raise FileNotFoundError( + f'image.yaml not found in {configs_dir} directory. Please check if the directory is correct.' + ) + +# create output directories if not exists +out_dir = args.out +os.makedirs(out_dir, exist_ok=True) +video_dir = args.video_out +os.makedirs(video_dir, exist_ok=True) + + +for file in files: + name = os.path.basename(file).replace("_rgba.png", "") + print(f'======== processing {name} ========') + # first stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main.py ' + f'--config {configs_dir}/image.yaml ' + f'input={file} ' + f'save_path={name} elevation={args.elevation}') + # second stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main2.py ' + f'--config {configs_dir}/image.yaml ' + f'input={file} ' + f'save_path={name} elevation={args.elevation}') + # export video + mesh_path = os.path.join(out_dir, f'{name}.obj') + os.system(f'python -m kiui.render {mesh_path} ' + f'--save_video {video_dir}/{name}.mp4 ' + f'--wogui ' + f'--elevation {args.elevation}') diff --git a/scripts/runall_mvdream.py b/scripts/runall_mvdream.py new file mode 100644 index 0000000000000000000000000000000000000000..ef30468d46d71a92f8658defee5416aaa59056ea --- /dev/null +++ b/scripts/runall_mvdream.py @@ -0,0 +1,44 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', default=0, type=int) +args = parser.parse_args() + +prompts = [ + # ('butterfly', 'a beautiful, intricate butterfly'), + # ('boy', 'a nendoroid of a chibi cute boy'), + # ('axe', 'a viking axe, fantasy, blender'), + # ('dog_rocket', 'corgi riding a rocket'), + ('teapot', 'a chinese teapot'), + ('squirrel_guitar', 'a DSLR photo of a squirrel playing guitar'), + # ('house', 'fisherman house, cute, cartoon, blender, stylized'), + # ('ship', 'Higly detailed, majestic royal tall ship, realistic painting'), + ('einstein', 'Albert Einstein with grey suit is riding a bicycle'), + # ('angle', 'a statue of an angle'), + ('lion', 'A 3D model of Simba, the lion cub from The Lion King, standing majestically on Pride Rock, character'), + # ('paris', 'mini Paris, highly detailed 3d model'), + # ('pig_backpack', 'a pig wearing a backpack'), + ('pisa_tower', 'Picture of the Leaning Tower of Pisa, featuring its tilted structure and marble facade'), + # ('robot', 'a human-like full body robot'), + ('coin', 'a golden coin'), + # ('cake', 'a delicious and beautiful cake'), + # ('horse', 'a DSLR photo of a horse'), + # ('cat', 'a photo of a cat'), + ('cat_hat', 'a photo of a cat wearing a wizard hat'), + # ('cat_ball', 'a photo of a cat playing with a red ball'), + # ('nendoroid', 'a nendoroid of a chibi girl'), + +] + +for name, prompt in prompts: + print(f'======== processing {name} ========') + # first stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main.py --config configs/text_mv.yaml prompt="{prompt}" save_path={name}') + # second stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main2.py --config configs/text_mv.yaml prompt="{prompt}" save_path={name}') + # export video + mesh_path = os.path.join('logs', f'{name}.obj') + os.makedirs('videos', exist_ok=True) + os.system(f'python -m kiui.render {mesh_path} --save_video videos/{name}.mp4 --wogui') \ No newline at end of file diff --git a/scripts/runall_sd.py b/scripts/runall_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..306a5b0d25f2650b778d5edc932c287972dad3b9 --- /dev/null +++ b/scripts/runall_sd.py @@ -0,0 +1,45 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', default=0, type=int) +args = parser.parse_args() + +prompts = [ + ('strawberry', 'a ripe strawberry'), + ('cactus_pot', 'a small saguaro cactus planted in a clay pot'), + ('hamburger', 'a delicious hamburger'), + ('icecream', 'an icecream'), + ('tulip', 'a blue tulip'), + ('pineapple', 'a ripe pineapple'), + ('goblet', 'a golden goblet'), + # ('squitopus', 'a squirrel-octopus hybrid'), + # ('astronaut', 'Michelangelo style statue of an astronaut'), + # ('teddy_bear', 'a teddy bear'), + # ('corgi_nurse', 'a plush toy of a corgi nurse'), + # ('teapot', 'a blue and white porcelain teapot'), + # ('skull', "a human skull"), + # ('penguin', 'a penguin'), + # ('campfire', 'a campfire'), + # ('donut', 'a donut with pink icing'), + # ('cupcake', 'a birthday cupcake'), + # ('pie', 'shepherds pie'), + # ('cone', 'a traffic cone'), + # ('schoolbus', 'a schoolbus'), + # ('avocado_chair', 'a chair that looks like an avocado'), + # ('glasses', 'a pair of sunglasses') + # ('potion', 'a bottle of green potion'), + # ('chalice', 'a delicate chalice'), +] + +for name, prompt in prompts: + print(f'======== processing {name} ========') + # first stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main.py --config configs/text.yaml prompt="{prompt}" save_path={name}') + # second stage + os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python main2.py --config configs/text.yaml prompt="{prompt}" save_path={name}') + # export video + mesh_path = os.path.join('logs', f'{name}.obj') + os.makedirs('videos', exist_ok=True) + os.system(f'python -m kiui.render {mesh_path} --save_video videos/{name}.mp4 --wogui') \ No newline at end of file diff --git a/sh_utils.py b/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/utils/camera_utils.py b/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a23c1e15a3447e843df932abcd05894073c6a8d --- /dev/null +++ b/utils/camera_utils.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + +def loadCam(args, id, cam_info, resolution_scale): + + + # resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + # gt_image = resized_image_rgb[:3, ...] + # loaded_mask = None + + # if resized_image_rgb.shape[1] == 4: + # loaded_mask = resized_image_rgb[3:4, ...] + + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=cam_info.image, gt_alpha_mask=None, + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + time = cam_info.time, +) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : camera.image_name, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/utils/general_utils.py b/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d4972ca2cc25660e5c543e48ca75deab90fd8b47 --- /dev/null +++ b/utils/general_utils.py @@ -0,0 +1,136 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + if resolution is not None: + resized_image_PIL = pil_image.resize(resolution) + else: + resized_image_PIL = pil_image + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + w = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - w*z) + R[:, 0, 2] = 2 * (x*z + w*y) + R[:, 1, 0] = 2 * (x*y + w*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - w*x) + R[:, 2, 0] = 2 * (x*z - w*y) + R[:, 2, 1] = 2 * (y*z + w*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) \ No newline at end of file diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4627d837c74fcdffc898fa0c3071cb7b316802b --- /dev/null +++ b/utils/graphics_utils.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b150699799529bbab7dd1f57d76224f218fe7a1d --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) +@torch.no_grad() +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6543f32064cc35db5e47bc20d59707e31741b559 --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,67 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +import lpips +def lpips_loss(img1, img2, lpips_model): + loss = lpips_model(img1,img2) + return loss.mean() +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + diff --git a/utils/params_utils.py b/utils/params_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2ea6446bfa026becc1c74db9872be4a1d7b6b7 --- /dev/null +++ b/utils/params_utils.py @@ -0,0 +1,9 @@ +def merge_hparams(args, config): + params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] + for param in params: + if param in config.keys(): + for key, value in config[param].items(): + if hasattr(args, key): + setattr(args, key, value) + + return args \ No newline at end of file diff --git a/utils/scene_utils.py b/utils/scene_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3224d6392a5397999dd117a964483fe220b71f64 --- /dev/null +++ b/utils/scene_utils.py @@ -0,0 +1,95 @@ +import torch +import os +from PIL import Image, ImageDraw, ImageFont +from matplotlib import pyplot as plt +plt.rcParams['font.sans-serif'] = ['Times New Roman'] + +import numpy as np + +import copy +@torch.no_grad() +def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now): + def render(gaussians, viewpoint, path, scaling): + # scaling_copy = gaussians._scaling + render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage) + label1 = f"stage:{stage},iter:{iteration}" + times = time_now/60 + if times < 1: + end = "min" + else: + end = "mins" + label2 = "time:%.2f" % times + end + image = render_pkg["render"] + depth = render_pkg["depth"] + image_np = image.permute(1, 2, 0).cpu().numpy() # 转换通道顺序为 (H, W, 3) + depth_np = depth.permute(1, 2, 0).cpu().numpy() + depth_np /= depth_np.max() + depth_np = np.repeat(depth_np, 3, axis=2) + image_np = np.concatenate((image_np, depth_np), axis=1) + image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) # 转换为8位图像 + # 创建PIL图像对象的副本以绘制标签 + draw1 = ImageDraw.Draw(image_with_labels) + + # 选择字体和字体大小 + font = ImageFont.truetype('./utils/TIMES.TTF', size=40) # 请将路径替换为您选择的字体文件路径 + + # 选择文本颜色 + text_color = (255, 0, 0) # 白色 + + # 选择标签的位置(左上角坐标) + label1_position = (10, 10) + label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # 右上角坐标 + + # 在图像上添加标签 + draw1.text(label1_position, label1, fill=text_color, font=font) + draw1.text(label2_position, label2, fill=text_color, font=font) + + image_with_labels.save(path) + render_base_path = os.path.join(scene.model_path, f"{stage}_render") + point_cloud_path = os.path.join(render_base_path,"pointclouds") + image_path = os.path.join(render_base_path,"images") + if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): + os.makedirs(render_base_path) + if not os.path.exists(point_cloud_path): + os.makedirs(point_cloud_path) + if not os.path.exists(image_path): + os.makedirs(image_path) + # image:3,800,800 + + # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg") + for idx in range(len(viewpoints)): + image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg") + render(gaussians,viewpoints[idx],image_save_path,scaling = 1) + # render(gaussians,point_save_path,scaling = 0.1) + # 保存带有标签的图像 + + + + pc_mask = gaussians.get_opacity + pc_mask = pc_mask > 0.1 + xyz = gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1,0).numpy() + # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path) + # 如果需要,您可以将PIL图像转换回PyTorch张量 + # return image + # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0 +def visualize_and_save_point_cloud(point_cloud, R, T, filename): + # 创建3D散点图 + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + R = R.T + # 应用旋转和平移变换 + T = -R.dot(T) + transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) + # pcd = o3d.geometry.PointCloud() + # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # 转置点云数据以匹配Open3D的格式 + # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:] + # 可视化点云 + ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o') + ax.axis("off") + # ax.set_xlabel('X Label') + # ax.set_ylabel('Y Label') + # ax.set_zlabel('Z Label') + + # 保存渲染结果为图片 + plt.savefig(filename) + diff --git a/utils/sh_utils.py b/utils/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/utils/system_utils.py b/utils/system_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e --- /dev/null +++ b/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters) diff --git a/utils/timer.py b/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..c01ff93c1bdc94a07f1c20a07fb0a983dc496e62 --- /dev/null +++ b/utils/timer.py @@ -0,0 +1,24 @@ +import time +class Timer: + def __init__(self): + self.start_time = None + self.elapsed = 0 + self.paused = False + + def start(self): + if self.start_time is None: + self.start_time = time.time() + elif self.paused: + self.start_time = time.time() - self.elapsed + self.paused = False + + def pause(self): + if not self.paused: + self.elapsed = time.time() - self.start_time + self.paused = True + + def get_elapsed_time(self): + if self.paused: + return self.elapsed + else: + return time.time() - self.start_time \ No newline at end of file diff --git a/zero123.py b/zero123.py new file mode 100644 index 0000000000000000000000000000000000000000..158e31ee4f877c11dda9118b382b2e226bf45e3a --- /dev/null +++ b/zero123.py @@ -0,0 +1,666 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +import torchvision.transforms.functional as TF +from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, is_accelerate_available, logging +from diffusers.utils.torch_utils import randn_tensor +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CLIPCameraProjection(ModelMixin, ConfigMixin): + """ + A Projection layer for CLIP embedding and camera embedding. + + Parameters: + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + """ + + @register_to_config + def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4): + super().__init__() + self.embedding_dim = embedding_dim + self.additional_embeddings = additional_embeddings + + self.input_dim = self.embedding_dim + self.additional_embeddings + self.output_dim = self.embedding_dim + + self.proj = torch.nn.Linear(self.input_dim, self.output_dim) + + def forward( + self, + embedding: torch.FloatTensor, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`): + The currently input embeddings. + + Returns: + The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`). + """ + proj_embedding = self.proj(embedding) + return proj_embedding + + +class Zero123Pipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + clip_camera_projection: CLIPCameraProjection, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + clip_camera_projection=clip_camera_projection, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [ + self.unet, + self.image_encoder, + self.vae, + self.safety_checker, + ]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image( + self, + image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings=None, + image_camera_embeddings=None, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if image_camera_embeddings is None: + if image is None: + assert clip_image_embeddings is not None + image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype) + else: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor( + images=image, return_tensors="pt" + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + bs_embed, seq_len, _ = image_embeddings.shape + + if isinstance(elevation, float): + elevation = torch.as_tensor( + [elevation] * bs_embed, dtype=dtype, device=device + ) + if isinstance(azimuth, float): + azimuth = torch.as_tensor( + [azimuth] * bs_embed, dtype=dtype, device=device + ) + if isinstance(distance, float): + distance = torch.as_tensor( + [distance] * bs_embed, dtype=dtype, device=device + ) + + camera_embeddings = torch.stack( + [ + torch.deg2rad(elevation), + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + distance, + ], + dim=-1, + )[:, None, :] + + image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1) + + # project (image, camera) embeddings to the same dimension as clip embeddings + image_embeddings = self.clip_camera_projection(image_embeddings) + else: + image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype) + bs_embed, seq_len, _ = image_embeddings.shape + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + # TODO: check image size or adjust image size to (height, width) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_latent_model_input( + self, + latents: torch.FloatTensor, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + image_latents: Optional[torch.FloatTensor] = None, + ): + if isinstance(image, PIL.Image.Image): + image_pt = TF.to_tensor(image).unsqueeze(0).to(latents) + elif isinstance(image, list): + image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to( + latents + ) + elif isinstance(image, torch.Tensor): + image_pt = image + else: + image_pt = None + + if image_pt is None: + assert image_latents is not None + image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0) + else: + image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1] + # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor + # but zero123 was not trained this way + image_pt = self.vae.encode(image_pt).latent_dist.mode() + image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + latent_model_input = torch.cat( + [ + torch.cat([latents, latents], dim=0), + torch.cat([torch.zeros_like(image_pt), image_pt], dim=0), + ], + dim=1, + ) + else: + latent_model_input = torch.cat([latents, image_pt], dim=1) + + return latent_model_input + + @torch.no_grad() + def __call__( + self, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ] = None, + elevation: Optional[Union[float, torch.FloatTensor]] = None, + azimuth: Optional[Union[float, torch.FloatTensor]] = None, + distance: Optional[Union[float, torch.FloatTensor]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + clip_image_embeddings: Optional[torch.FloatTensor] = None, + image_camera_embeddings: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPImageProcessor` + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + # TODO: check input elevation, azimuth, and distance + # TODO: check image, clip_image_embeddings, image_latents + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + assert image_latents is not None + assert ( + clip_image_embeddings is not None or image_camera_embeddings is not None + ) + batch_size = image_latents.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + if isinstance(image, PIL.Image.Image) or isinstance(image, list): + pil_image = image + elif isinstance(image, torch.Tensor): + pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + else: + pil_image = None + image_embeddings = self._encode_image( + pil_image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings, + image_camera_embeddings, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 # FIXME: hard-coded + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = self._get_latent_model_input( + latents, + image, + num_images_per_prompt, + do_classifier_free_guidance, + image_latents, + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, image_embeddings.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) \ No newline at end of file