# spad/pipeline_spad.py
import math

import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from tqdm import tqdm

from .geometry import get_batch_from_spherical

class SPADPipeline(DiffusionPipeline):
def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler
)
self.cfg_conds = ["txt", "cam", "epi", "plucker"]
self.cfg_scales = [7.5, 1.0, 1.0, 1.0] # Default scales, adjust as needed
self.use_abs_extrinsics = False
self.use_intrinsic = False
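        # Project per-pair relative camera vectors (d_elev, sin/cos of d_azim,
        # d_radius; 8-D if intrinsics are appended) into the UNet's 1280-D
        # conditioning space. The last layer is zero-initialized so camera
        # conditioning starts as a no-op and is learned gradually.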
self.cc_projection = nn.Sequential(
nn.Linear(4 if not self.use_intrinsic else 8, 1280),
nn.SiLU(),
nn.Linear(1280, 1280),
)
nn.init.zeros_(self.cc_projection[-1].weight)
nn.init.zeros_(self.cc_projection[-1].bias)
def generate_camera_batch(self, elevations, azimuths, use_abs=False):
batch = get_batch_from_spherical(elevations, azimuths)
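        # Absolute spherical pose per view: (elevation, azimuth, radius), with
        # the camera radius fixed at 3.5.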
abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]
debug_cams = [[] for _ in range(len(azimuths))]
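        # Pairwise relative poses: entry (i, j) encodes camera i relative to
        # camera j as (d_elev, sin(d_azim), cos(d_azim), d_radius).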
for i, icam in enumerate(abs_cams):
for j, jcam in enumerate(abs_cams):
if use_abs:
dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
else:
dcam = icam - jcam
dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
debug_cams[i].append(dcam)
batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])
        # Add normalized intrinsics to the batch; 0.702769935131073 rad appears
        # to be the camera field of view, giving focal = 1 / tan(fov / 2).
focal = 1 / np.tan(0.702769935131073 / 2)
intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
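        # Flatten the 3x3 matrix to (fx, fy, cx, cy) per view.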
batch["render_intrinsics_flat"] = intrinsics[:, [0,1,0,1], [0,1,-1,-1]]
return batch
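    # Shape sketch for the default 4-view ring,
    # generate_camera_batch([45] * 4, [0, 90, 180, 270]):
    #   batch["cam"]                    -> (4, 4, 4)  pairwise relative poses
    #   batch["render_intrinsics_flat"] -> (4, 4)     (fx, fy, cx, cy) per view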
def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
X = np.linspace(-1, 1, blob_width)[None, :]
Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_var = 1 / sigma**2
        gaussian_blob = np.exp(-0.5 * (X**2) * inv_var) * np.exp(-0.5 * (Y**2) * inv_var)
if gaussian_blob.max() > 0:
gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
gaussian_blob = 255.0 - gaussian_blob
gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3,-1)
gaussian_blob = torch.from_numpy(gaussian_blob)
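        # Result: a (blob_height, blob_width, 3) float tensor in [-1, 1]; the
        # inversion above leaves a dark centered blob on a white background.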
return gaussian_blob
@torch.no_grad()
def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
batch_size = len(prompt) if isinstance(prompt, list) else 1
device = self.device
# Generate camera batch
if elevations is None or azimuths is None:
elevations = [45] * 4
azimuths = [0, 90, 180, 270]
n_views = len(elevations)
camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
# Prepare gaussian blob initialization
blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
# Encode text
        text_input_ids = self.tokenizer(
            prompt, padding="max_length", max_length=self.tokenizer.model_max_length,
            truncation=True, return_tensors="pt",
        ).input_ids.to(device)
text_embeddings = self.text_encoder(text_input_ids)[0]
# Prepare unconditional embeddings for classifier-free guidance
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
# Encode camera data
camera_embeddings = self.cc_projection(camera_batch["cam"])
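        # camera_embeddings: (batch_size, n_views, n_views, 1280), one embedding
        # per ordered (i, j) camera pair.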
# Prepare latents
latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
latents = self.prepare_latents(
batch_size,
self.unet.in_channels,
n_views,
latent_height,
latent_width,
self.unet.dtype,
device,
generator=None,
)
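        # latents: (batch_size, n_views, C, latent_h, latent_w) Gaussian noise.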
# Prepare epi_constraint_masks (placeholder, replace with actual implementation)
epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
# Prepare plucker embeddings (placeholder, replace with actual implementation)
plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)
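        # NOTE: neither placeholder is wired into the UNet `context` below yet.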
        # EXPERIMENTAL: override the prepared latents with the hardcoded shape
        # the current UNet expects: (n_objects, n_views, 10, 32, 32).
        n_objects = 2
        latents = torch.randn(n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype)
        # Set up scheduler
        self.scheduler.set_timesteps(num_inference_steps)
        # Broadcast conditional and unconditional text embeddings across objects
        # and views: (batch, seq, dim) -> (n_objects * batch, n_views, seq, dim).
        text_embeddings = text_embeddings.repeat(n_objects, 1, 1)
        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
        uncond_embeddings = uncond_embeddings.repeat(n_objects, 1, 1)
        uncond_embeddings = uncond_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
        # Broadcast camera embeddings across objects once, before the loop.
        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
        # Denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # One scalar timestep per (object, view)
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
            # Conditioning context for the SPAD UNet
            context = [
                text_embeddings,    # (n_objects, n_views, max_seq_len, 768)
                camera_embeddings,  # (n_objects, n_views, n_views, 1280)
            ]
            # Predict noise residuals for the conditional and unconditional
            # contexts; the batch is not duplicated, so the UNet runs once per
            # context rather than chunking a single prediction.
            noise_pred_text = self.unet(latents, timesteps=timesteps, context=context)
            noise_pred_uncond = self.unet(
                latents,
                timesteps=timesteps,
                context=[uncond_embeddings, camera_embeddings],
            )
            # Classifier-free guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # Step to the previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # EXPERIMENTAL: the UNet runs on 10 latent channels but the VAE decodes
        # 4, so keep only the first 4 channels, then flatten objects and views
        # into the batch dimension for decoding.
        latents = latents[:, :, :4, :, :]
        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])
# Decode latents
images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# Post-process images
images = (images / 2 + 0.5).clamp(0, 1)
if images.dim() == 5:
images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy() # For 5D tensors
elif images.dim() == 4:
images = images.cpu().permute(0, 2, 3, 1).float().numpy() # For 4D tensors
else:
raise ValueError(f"Unexpected image dimensions: {images.shape}")
return {"images": images, "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)]}
def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
shape = (batch_size, num_views, num_channels, height, width)
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return latents
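

# ----------------------------------------------------------------------------
# Minimal usage sketch (assumptions: a SPAD-compatible checkpoint laid out for
# DiffusionPipeline.from_pretrained; "path/to/spad-checkpoint" is hypothetical).
#
#   pipe = SPADPipeline.from_pretrained("path/to/spad-checkpoint").to("cuda")
#   out = pipe(["a wooden chair"], num_inference_steps=50,
#              elevations=[45] * 4, azimuths=[0, 90, 180, 270])
#   print(out["images"].shape)  # one decoded image per (object, view)
# ----------------------------------------------------------------------------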