#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import einops
import numpy as np
import torch
import torch.utils.checkpoint
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import torchvision
import json
import cv2
from skimage.io import imsave
import matplotlib.pyplot as plt

# read .exr files for RTMV dataset
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a Zero123 evaluation script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="lambdalabs/sd-image-variations-diffusers",
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--T_in",
        type=int,
        default=1,
        help="Number of input views",
    )
    parser.add_argument(
        "--T_out",
        type=int,
        default=1,
        help="Number of output views",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.0,
        help="Unconditional guidance scale; if guidance_scale > 1.0, classifier-free guidance is used.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=".",
        help="The input data dir. Should contain the .png files (or other data files) for the task.",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        default="GSO25",
        help=(
            "The input data type. Chosen from GSO25, GSO25_6dof, GSO3D, GSO100, RTMV, NeRF, Franka, MVDream,"
            " Text2Img, dust3r."
        ),
    )
    parser.add_argument(
        "--cape_type",
        type=str,
        default="6DoF",
        help="The camera pose encoding (CaPE) type. Chosen from 4DoF, 6DoF.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="logs_eval",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        default=True,
        help="Whether or not to use xformers.",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images."
        )

    return args


# create angles on an archimedean spiral with T_out steps
import math


def get_archimedean_spiral(sphere_radius, num_steps=250):
    # x-z plane, around upper y
    '''
    https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
    '''
    a = 40
    r = sphere_radius

    translations = []
    angles = []

    # i = a / 2
    i = 0.01
    while i < a:
        theta = i / a * math.pi
        x = r * math.sin(theta) * math.cos(-i)
        z = r * math.sin(-theta + math.pi) * math.sin(-i)
        y = r * -math.cos(theta)

        # translations.append((x, y, z))  # origin
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-i), np.rad2deg(theta)])

        # i += a / (2 * num_steps)
        i += a / (1 * num_steps)

    return np.array(translations), np.stack(angles)


# 36 views around the circle, at a given elevation in degrees
def get_circle_traj(sphere_radius, elevation=0, num_steps=36):
    translations = []
    angles = []
    elevation = np.deg2rad(elevation)
    for i in range(num_steps):
        theta = i / num_steps * 2 * math.pi
        x = sphere_radius * math.sin(theta) * math.cos(elevation)
        z = sphere_radius * math.sin(-theta + math.pi) * math.sin(-elevation)
        y = sphere_radius * -math.cos(theta)
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-elevation), np.rad2deg(theta)])
    return np.array(translations), np.stack(angles)


def look_at(origin, target, up):
    forward = (target - origin)
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    rotation_matrix = np.column_stack((right, new_up, -forward, target))
    matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
    return matrix


# from carvekit.api.high import HiInterface
# def create_carvekit_interface():
#     # Check doc strings for more information
#     interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
#                            batch_size_seg=5,
#                            batch_size_matting=1,
#                            device='cuda' if torch.cuda.is_available() else 'cpu',
#                            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
#                            matting_mask_size=2048,
#                            trimap_prob_threshold=231,
#                            trimap_dilation=30,
#                            trimap_erosion_iters=5,
#                            fp16=False)
#
#     return interface


import rembg


def create_rembg_interface():
    rembg_session = rembg.new_session()
    return rembg_session


def main(args):
    if args.seed is not None:
        set_seed(args.seed)

    CaPE_TYPE = args.cape_type
    if CaPE_TYPE == "6DoF":
        import sys
        sys.path.insert(0, "./6DoF/")
        # use the customized diffusers modules
        from diffusers import DDIMScheduler
        from dataset import get_pose
        from CN_encoder import CN_encoder
        from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
    elif CaPE_TYPE == "4DoF":
        import sys
        sys.path.insert(0, "./4DoF/")
        # use the customized diffusers modules
        from diffusers import DDIMScheduler
        from dataset import get_pose
        from CN_encoder import CN_encoder
        from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
    else:
        raise ValueError("CaPE_TYPE must be chosen from 4DoF, 6DoF")

    # from dataset import get_pose
    # from CN_encoder import CN_encoder
    # from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline

    DATA_DIR = args.data_dir
    DATA_TYPE = args.data_type

    if DATA_TYPE == "GSO25":
        T_in_DATA_TYPE = "render_mvs_25"   # same condition for GSO
        T_out_DATA_TYPE = "render_mvs_25"  # for 2D metrics
        T_out = 25
    elif DATA_TYPE == "GSO25_6dof":
        T_in_DATA_TYPE = "render_6dof_25"   # same condition for GSO
        T_out_DATA_TYPE = "render_6dof_25"  # for 2D metrics
        T_out = 25
    elif DATA_TYPE == "GSO3D":
        T_in_DATA_TYPE = "render_mvs_25"           # same condition for GSO
        T_out_DATA_TYPE = "render_sync_36_single"  # for 3D metrics
        T_out = 36
    elif DATA_TYPE == "GSO100":
        T_in_DATA_TYPE = "render_mvs_25"       # same condition for GSO
        T_out_DATA_TYPE = "render_spiral_100"  # for 360 gif
        T_out = 100
    elif DATA_TYPE == "NeRF":
        T_out = 200
    elif DATA_TYPE == "RTMV":
        T_out = 20
    elif DATA_TYPE == "Franka":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "MVDream":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "Text2Img":
        T_out = 100  # do a 360 gif
    elif DATA_TYPE == "dust3r":
        # carvekit = create_carvekit_interface()
        rembg_session = create_rembg_interface()
        T_out = 50  # do a 360 gif
        # get the number of .png files in the folder
        obj_names = [f for f in os.listdir(DATA_DIR + "/user_object") if f.endswith('.png')]
        args.T_in = len(obj_names)
    else:
        raise NotImplementedError

    T_in = args.T_in

    OUTPUT_DIR = f"logs_{CaPE_TYPE}/{DATA_TYPE}/N{T_in}M{T_out}"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # get all folders in DATA_DIR
    if DATA_TYPE == "Text2Img":
        # get all rgba_png in DATA_DIR
        obj_names = [f for f in os.listdir(DATA_DIR) if f.endswith('rgba.png')]
    else:
        obj_names = [f for f in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, f))]

    weight_dtype = torch.float16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    h, w = args.resolution, args.resolution
    bg_color = [1., 1., 1., 1.]
    radius = 2.2  # 1.5 # 1.8  # Objaverse training radius [1.5, 2.2]
    # radius_4dof = np.pi * (np.log(radius) - np.log(1.5)) / (np.log(2.2) - np.log(1.5))

    # Init Dataset
    image_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((args.resolution, args.resolution)),  # 256, 256
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]
    )

    # Init pipeline
    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision)
    image_encoder = CN_encoder.from_pretrained(args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision)
    pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        scheduler=scheduler,
        image_encoder=None,
        safety_checker=None,
        feature_extractor=None,
        torch_dtype=weight_dtype,
    )
    pipeline.image_encoder = image_encoder
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=False)
    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()
    # enable vae slicing
    pipeline.enable_vae_slicing()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    for obj_name in tqdm(obj_names):
        print(f"Processing {obj_name}")
        if DATA_TYPE == "NeRF":
            if os.path.exists(os.path.join(args.output_dir, obj_name, "output.gif")):
                continue
            # load train info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_train.json"), "r") as f:
                train_info = json.load(f)["frames"]
            # load test info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_test.json"), "r") as f:
                test_info = json.load(f)["frames"]

            # find the radius [min_t, max_t] of the object, we later scale it to training radius [1.5, 2.2]
            max_t = 0
            min_t = 100
            for i in range(len(train_info)):
                pose = np.array(train_info[i]["transform_matrix"]).reshape(4, 4)
                translation = pose[:3, -1]
                radii = np.linalg.norm(translation)
                if max_t < radii:
                    max_t = radii
                if min_t > radii:
                    min_t = radii

            info_dir = os.path.join("metrics/NeRF_idx", obj_name)
            assert os.path.exists(info_dir)
            # use fixed train index
            train_index = np.load(os.path.join(info_dir, f"train_N{T_in}M20_random.npy"))
            test_index = np.arange(len(test_info))  # use all test views
        elif DATA_TYPE == "Franka":
            angles_in = np.load(os.path.join(DATA_DIR, obj_name, "angles.npy"))  # azimuth, elevation in radians
            assert T_in <= len(angles_in)
            total_index = np.arange(0, len(angles_in))  # num of input views
            # random shuffle total_index
            np.random.shuffle(total_index)
            train_index = total_index[:T_in]

            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            test_index = np.arange(len(angles_out))  # use all 100 test views
        elif DATA_TYPE == "MVDream":
            # 4 input views: front, right, back, left
            angles_in = []
            for polar in [90]:  # 1
                for azimu in np.arange(0, 360, 90):  # 4
                    angles_in.append(np.array([azimu, polar]))
            assert T_in == len(angles_in)

            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)
        elif DATA_TYPE == "Text2Img":
            # 1 input view
            angles_in = []
            angles_in.append(np.array([0, 90]))
            assert T_in == len(angles_in)

            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)
        elif DATA_TYPE == "dust3r":
            # TODO full archimedean spiral traj
            # xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            # TODO only top circle traj
            xyzs, angles_out = get_archimedean_spiral(1.5, 100)
            xyzs = xyzs[:T_out]
            angles_out = angles_out[:T_out]
            # # TODO circle traj
            # xyzs, angles_out = get_circle_traj(radius, elevation=30, num_steps=T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)

            # get the max_t
            radii = np.load(os.path.join(DATA_DIR, obj_name, "radii.npy"))
            max_t = np.max(radii)
            min_t = np.min(radii)
        else:
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)

        # prepare input img + pose, output pose
        input_image = []
        pose_in = []
        pose_out = []
        gt_image = []
        for T_in_index in train_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_in_index)
                input_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(input_im, cv2.COLOR_BGR2RGB, input_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))
                # load input pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_in_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                input_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                input_RT = np.linalg.inv(input_RT)[:3]
                pose_in.append(get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, train_info[T_in_index]["file_path"] + ".png")
                    pose = np.array(train_info[T_in_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_in.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "Franka":
                    img_path = os.path.join(DATA_DIR, obj_name, "images_rgba", f"frame{T_in_index:06d}.png")
                    azimuth, elevation = np.rad2deg(angles_in[T_in_index])
                    print("input angles index", T_in_index, "azimuth", azimuth, "elevation", elevation)
                    if CaPE_TYPE == "4DoF":
                        # torch.from_numpy expects an ndarray, so build the 4DoF pose with torch.tensor
                        pose_in.append(torch.tensor([np.deg2rad(90. - elevation), np.deg2rad(azimuth - 180), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        neg_i = np.deg2rad(azimuth - 180)
                        neg_theta = np.deg2rad(90. - elevation)
                        xyz = np.array([np.sin(neg_theta) * np.cos(neg_i),
                                        np.sin(-neg_theta + np.pi) * np.sin(neg_i),
                                        np.cos(neg_theta)]) * radius
                        pose = look_at(origin, xyz, up)
                        pose = np.linalg.inv(pose)
                        pose[2, :] *= -1
                        pose_in.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img":
                    if DATA_TYPE == "MVDream":
                        img_path = os.path.join(DATA_DIR, obj_name, f"{T_in_index}_rgba.png")
                    elif DATA_TYPE == "Text2Img":
                        img_path = os.path.join(DATA_DIR, obj_name)
                    azimuth, polar = angles_in[T_in_index]
                    if CaPE_TYPE == "4DoF":
                        pose_in.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        neg_theta = np.deg2rad(polar)
                        neg_i = np.deg2rad(azimuth)
                        xyz = np.array([np.sin(neg_theta) * np.cos(neg_i),
                                        np.sin(-neg_theta + np.pi) * np.sin(neg_i),
                                        np.cos(neg_theta)]) * radius
                        pose = look_at(origin, xyz, up)
                        pose = np.linalg.inv(pose)
                        pose[2, :] *= -1
                        pose_in.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "dust3r":
                    # TODO get the object coordinate, now one of the cameras is the center
                    img_path = os.path.join(DATA_DIR, obj_name, "%03d.png" % T_in_index)
                    pose = get_pose(np.linalg.inv(np.load(os.path.join(DATA_DIR, obj_name, "%03d.npy" % T_in_index))))
                    pose[1:3, :] *= -1
                    # scale radius to [1.5, 2.2]
                    pose[:3, 3] *= 1. / max_t * radius
                    pose_in.append(torch.from_numpy(pose))
                else:  # GSO
                    img_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.png" % T_in_index)
                    pose_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.npy" % T_in_index)
                    if T_in_DATA_TYPE == "render_mvs_25" or T_in_DATA_TYPE == "render_6dof_25":
                        # blender coordinate
                        pose_in.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
                    else:
                        # opencv coordinate
                        pose = get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0))
                        pose[1:3, :] *= -1  # pose out 36 is in opencv coordinate, pose in 25 is in blender coordinate
                        pose_in.append(torch.from_numpy(pose))
                    # pose_in.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))

                # load image
                img = plt.imread(img_path)
                if (img.shape[-1] == 3 or (img[:, :, -1] == 1).all()) and DATA_TYPE == "dust3r":
                    img_pil = Image.fromarray(np.uint8(img * 255.)).convert("RGB")  # to PIL image
                    ## use carvekit
                    # image_without_background = carvekit([img_pil])[0]
                    # image_without_background = np.array(image_without_background)
                    # est_seg = image_without_background > 127
                    # foreground = est_seg[:, :, -1].astype(np.bool_)
                    # img = np.concatenate([img[:, :, :3], foreground[:, :, np.newaxis]], axis=-1)
                    # use rembg
                    image = rembg.remove(img_pil, session=rembg_session)
                    foreground = np.array(image)[:, :, -1] > 127
                    img = np.concatenate([img[:, :, :3], foreground[:, :, np.newaxis]], axis=-1)
                img[img[:, :, -1] == 0.] = bg_color
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))

        for T_out_index in test_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_out_index)
                gt_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(gt_im, cv2.COLOR_BGR2RGB, gt_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                gt_image.append(image_transforms(img))
                # load pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_out_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                output_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                output_RT = np.linalg.inv(output_RT)[:3]
                pose_out.append(get_pose(np.concatenate([output_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, test_info[T_out_index]["file_path"] + ".png")
                    pose = np.array(test_info[T_out_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_out.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "Franka":
                    img_path = None
                    azimuth, polar = angles_out[T_out_index]
                    if CaPE_TYPE == "4DoF":
                        # torch.from_numpy expects an ndarray, so build the 4DoF pose with torch.tensor
                        pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        pose = look_at(origin, xyzs[T_out_index], up)
                        neg_theta = np.deg2rad(polar)
                        neg_i = np.deg2rad(azimuth)
                        xyz = np.array([np.sin(neg_theta) * np.cos(neg_i),
                                        np.sin(-neg_theta + np.pi) * np.sin(neg_i),
                                        np.cos(neg_theta)]) * radius
                        assert np.allclose(xyzs[T_out_index], xyz)
                        pose = np.linalg.inv(pose)
                        pose[2, :] *= -1
                        pose_out.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img" or DATA_TYPE == "dust3r":
                    img_path = None
                    azimuth, polar = angles_out[T_out_index]
                    if CaPE_TYPE == "4DoF":
                        pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        pose = look_at(origin, xyzs[T_out_index], up)
                        pose = np.linalg.inv(pose)
                        pose[2, :] *= -1
                        pose_out.append(torch.from_numpy(get_pose(pose)))
                else:  # GSO
                    img_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.png" % T_out_index)
                    pose_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.npy" % T_out_index)
                    if T_out_DATA_TYPE == "render_mvs_25" or T_out_DATA_TYPE == "render_6dof_25":
                        # blender coordinate
                        pose_out.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
                    else:
                        # opencv coordinate
                        pose = get_pose(np.concatenate([np.load(pose_path)[:3, :], np.array([[0, 0, 0, 1]])], axis=0))
                        pose[1:3, :] *= -1  # pose out 36 is in opencv coordinate, pose in 25 is in blender coordinate
                        pose_out.append(torch.from_numpy(pose))

                # load image
                if img_path is not None:  # sometimes don't have GT target view image
                    img = plt.imread(img_path)
                    img[img[:, :, -1] == 0.] = bg_color
                    img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                    gt_image.append(image_transforms(img))

        # [B, T, C, H, W]
        input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
        if len(gt_image) > 0:
            gt_image = torch.stack(gt_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
        # [B, T, 4] for 4DoF pose encodings, [B, T, 4, 4] for 6DoF pose matrices
        pose_in = np.stack(pose_in)
        pose_out = np.stack(pose_out)
        if CaPE_TYPE == "6DoF":
            pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
            pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
            pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
            pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
        pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
        pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

        input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
        if len(gt_image) > 0:
            gt_image = einops.rearrange(gt_image, "b t c h w -> (b t) c h w")

        assert T_in == input_image.shape[0]
        assert T_in == pose_in.shape[1]
        assert T_out == pose_out.shape[1]

        # run inference
        if CaPE_TYPE == "6DoF":
            with torch.autocast("cuda"):
                image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                                 poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
                                 height=h, width=w, T_in=T_in, T_out=T_out,
                                 guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator,
                                 output_type="numpy").images
        elif CaPE_TYPE == "4DoF":
            with torch.autocast("cuda"):
                image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                                 poses=[pose_out, pose_in],
                                 height=h, width=w, T_in=T_in, T_out=T_out,
                                 guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator,
                                 output_type="numpy").images

        # save results
        output_dir = os.path.join(OUTPUT_DIR, obj_name)
        os.makedirs(output_dir, exist_ok=True)
        # save input image for visualization
        imsave(os.path.join(output_dir, 'input.png'),
               ((np.concatenate(input_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))
        # save output image
        if T_out >= 30:
            # save to N imgs
            for i in range(T_out):
                imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
            # make a gif
            frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
            frame_one = frames[0]
            frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
                           save_all=True, duration=50, loop=1)
        else:
            imsave(os.path.join(output_dir, '0.png'), (np.concatenate(image, 1) * 255).astype(np.uint8))
        # save gt for visualization
        if len(gt_image) > 0:
            imsave(os.path.join(output_dir, 'gt.png'),
                   ((np.concatenate(gt_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))


if __name__ == "__main__":
    args = parse_args()
    main(args)
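
# Example invocation (a minimal sketch, not a documented interface: the script filename,
# checkpoint location, and data layout below are assumptions):
#   python <this_script>.py \
#       --pretrained_model_name_or_path <path_or_hub_id_of_trained_checkpoint> \
#       --data_dir ./GSO \
#       --data_type GSO25 \
#       --cape_type 6DoF \
#       --T_in 1
# Outputs are written under logs_<cape_type>/<data_type>/N<T_in>M<T_out>/<object_name>/.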