import math
import os
from itertools import chain

import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from einops import rearrange, repeat
from tqdm import tqdm

from .geometry import get_batch_from_spherical


class SPADPipeline(DiffusionPipeline):
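    """Multi-view text-to-image diffusion pipeline.

    Wraps a multi-view UNet, VAE, text encoder, tokenizer, and scheduler, and
    conditions generation on relative camera poses projected through a small
    MLP (``cc_projection``). Parts of ``__call__`` are still experimental and
    rely on placeholder epipolar masks and Plucker embeddings.
    """
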
    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
        )

        self.cfg_conds = ["txt", "cam", "epi", "plucker"]
        self.cfg_scales = [7.5, 1.0, 1.0, 1.0]  # Default scales, adjust as needed
        self.use_abs_extrinsics = False
        self.use_intrinsic = False

        # Camera-conditioning projection: maps the 4-dim relative pose vector
        # (8-dim if intrinsics are appended) to a 1280-dim camera embedding.
        self.cc_projection = nn.Sequential(
            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        # Zero-init the last layer so camera conditioning starts as a no-op.
        nn.init.zeros_(self.cc_projection[-1].weight)
        nn.init.zeros_(self.cc_projection[-1].bias)
    def generate_camera_batch(self, elevations, azimuths, use_abs=False):
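        """Build per-view camera conditioning from spherical coordinates.

        For every ordered pair of views, either the absolute pose (``use_abs``)
        or the pose of view ``i`` relative to view ``j`` is encoded as
        ``[d_elevation, sin(d_azimuth), cos(d_azimuth), d_radius]``. Flattened
        intrinsics are added under ``render_intrinsics_flat``.
        """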
        batch = get_batch_from_spherical(elevations, azimuths)

        abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]

        debug_cams = [[] for _ in range(len(azimuths))]
        for i, icam in enumerate(abs_cams):
            for j, jcam in enumerate(abs_cams):
                if use_abs:
                    dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
                else:
                    dcam = icam - jcam
                    dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
                debug_cams[i].append(dcam)
        batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])

        # Add intrinsics to the batch (fixed pinhole camera, FOV of ~0.7028 rad / 40.3 deg)
        focal = 1 / np.tan(0.702769935131073 / 2)
        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]

        return batch
    def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
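        """Return an inverted Gaussian blob image in [-1, 1] with shape (H, W, 3).

        Used in ``__call__`` as the initial per-view image.
        """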
        X = np.linspace(-1, 1, blob_width)[None, :]
        Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_dev = 1 / sigma ** 2
        gaussian_blob = np.exp(-0.5 * (X ** 2) * inv_dev) * np.exp(-0.5 * (Y ** 2) * inv_dev)
        if gaussian_blob.max() > 0:
            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
        gaussian_blob = 255.0 - gaussian_blob

        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)
        gaussian_blob = torch.from_numpy(gaussian_blob)

        return gaussian_blob
    @torch.no_grad()
    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
                 elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
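        """Generate ``n_views`` images per prompt, conditioned on camera poses.

        Note: the denoising loop below still contains experimental, partly
        hard-coded logic (fixed step count, fixed latent shapes, placeholder
        epipolar masks and Plucker embeddings).
        """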
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        device = self.device

        # Generate camera batch
        if elevations is None or azimuths is None:
            elevations = [45] * 4
            azimuths = [0, 90, 180, 270]
        n_views = len(elevations)
        camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
        camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}

        # Prepare gaussian blob initialization
        blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
        camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)

        # Encode text
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = self.text_encoder(text_input_ids)[0]

        # Prepare unconditional embeddings for classifier-free guidance
        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

        # Encode camera data
        camera_embeddings = self.cc_projection(camera_batch["cam"])

        # Prepare latents
        latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
        latents = self.prepare_latents(
            batch_size,
            self.unet.in_channels,
            n_views,
            latent_height,
            latent_width,
            self.unet.dtype,
            device,
            generator=None,
        )
        # Prepare epi_constraint_masks (placeholder, replace with actual implementation)
        epi_constraint_masks = torch.ones(
            batch_size, n_views, latent_height, latent_width,
            n_views, latent_height, latent_width,
            dtype=torch.bool, device=device,
        )

        # Prepare plucker embeddings (placeholder, replace with actual implementation)
        plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)

        # EXPERIMENTAL: override the shapes and latents prepared above with
        # hard-coded values expected by the current multi-view UNet.
        latent_height, latent_width = 64, 64  # fixed to match the required shape [batch_size, 1, 4, 64, 64]
        n_objects = 2
        latents = torch.randn(n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype)

        # Set up scheduler (EXPERIMENTAL: hard-coded to 10 steps instead of num_inference_steps)
        self.scheduler.set_timesteps(10)

        # Expand conditioning once, before the loop, to match the latent batch:
        # text_embeddings -> [n_objects * batch_size, n_views, max_seq_len, emb_dim]
        text_embeddings = text_embeddings.repeat(n_objects, 1, 1)
        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
        # camera_embeddings -> [n_objects * batch_size, n_views, n_views, 1280]
        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
        # Denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # Per-view timesteps, shape [n_objects, n_views]
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)

            # Conditioning context for the multi-view UNet.
            # NOTE: epi_constraint_masks and plucker_embeds are not passed yet.
            context = [
                text_embeddings,    # [n_objects * batch_size, n_views, max_seq_len, emb_dim]
                camera_embeddings,  # [n_objects * batch_size, n_views, n_views, 1280]
            ]
            # Predict noise residual
            noise_pred = self.unet(
                latents,               # [n_objects, n_views, C, h, w]
                timesteps=timesteps,   # [n_objects, n_views]
                context=context,
            )

            # Perform classifier-free guidance (EXPERIMENTAL: splits the prediction along the first batch dim)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # EXPERIMENTAL: keep only the 4 channels the VAE expects and flatten
        # the (object, view) dims into a single batch dim before decoding.
        latents = latents[:, :, :4, :, :]
        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])
        # Decode latents
        images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # Post-process images
        images = (images / 2 + 0.5).clamp(0, 1)
        if images.dim() == 5:
            images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # For 5D tensors
        elif images.dim() == 4:
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # For 4D tensors
        else:
            raise ValueError(f"Unexpected image dimensions: {images.shape}")

        return {"images": images, "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)]}
    def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
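        """Sample initial noise of shape (batch_size, num_views, num_channels, height, width)."""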
        shape = (batch_size, num_views, num_channels, height, width)
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents
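

# Example usage (illustrative sketch only): the component classes, checkpoint
# path, and the multi-view UNet instance below are assumptions, not part of
# this module.
#
#     from diffusers import AutoencoderKL, DDIMScheduler
#     from transformers import CLIPTextModel, CLIPTokenizer
#
#     pipe = SPADPipeline(
#         unet=my_multiview_unet,  # hypothetical multi-view UNet instance
#         vae=AutoencoderKL.from_pretrained("some/checkpoint", subfolder="vae"),
#         text_encoder=CLIPTextModel.from_pretrained("some/checkpoint", subfolder="text_encoder"),
#         tokenizer=CLIPTokenizer.from_pretrained("some/checkpoint", subfolder="tokenizer"),
#         scheduler=DDIMScheduler.from_pretrained("some/checkpoint", subfolder="scheduler"),
#     ).to("cuda")
#
#     out = pipe(["a wooden chair"], elevations=[45] * 4, azimuths=[0, 90, 180, 270])
#     images = out["images"]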